Beispiel #1
0
def atom_adj_mat(mol, conformer_i, **kwargs):
    """
    Build a per-atom structured feature array for one conformer.

    OUTPUT IS ATOM_N x (adj_mat, tgt_atom, atomic_nos, dists)

    Each of the ATOM_N records holds:
      * ``adj``      -- the (MAX_ATOM_N x MAX_ATOM_N) adjacency matrix returned
                        by ``mol_to_nums_adj``, multiplied by 2 and stored as
                        uint8 (original note: "adj mat is valence number * 2";
                        unverified here). The same matrix is repeated for
                        every atom.
      * ``my_idx``   -- index of the atom this record describes.
      * ``atomicno`` -- the padded atomic-number vector from ``mol_to_nums_adj``.
      * ``pos``      -- coordinates of every atom relative to this atom; rows
                        ATOM_N..MAX_ATOM_N-1 stay zero.

    This is really inefficient given that we explicitly return the same adj
    matrix for each atom, and index into it.

    Parameters
    ----------
    mol : molecule accepted by ``get_nos_coords`` and ``mol_to_nums_adj``.
    conformer_i : int
        Conformer index passed to ``get_nos_coords``.
    MAX_ATOM_N : int, optional keyword (default 64)
        Padding size of the per-atom arrays.

    Returns
    -------
    numpy structured array of length ATOM_N.

    Raises
    ------
    ValueError
        If the molecule has more than MAX_ATOM_N atoms.
    """
    MAX_ATOM_N = kwargs.get('MAX_ATOM_N', 64)
    atomic_nos, coords = get_nos_coords(mol, conformer_i)
    ATOM_N = len(atomic_nos)
    if ATOM_N > MAX_ATOM_N:
        raise ValueError("molecule has {} atoms, more than MAX_ATOM_N={}".format(
            ATOM_N, MAX_ATOM_N))

    atomic_nos_pad, adj = mol_to_nums_adj(mol, MAX_ATOM_N)
    adj_scaled = adj * 2  # identical for every atom, so compute it once

    # Builtin ``int`` replaces the ``np.int`` alias (removed in NumPy 1.24);
    # ``np.int`` was literally ``int``, so the resulting dtype is unchanged.
    features = np.zeros((ATOM_N,),
                        dtype=[('adj', np.uint8, (MAX_ATOM_N, MAX_ATOM_N)),
                               ('my_idx', int),
                               ('atomicno', np.uint8, MAX_ATOM_N),
                               ('pos', np.float32, (MAX_ATOM_N, 3))])

    for atom_i in range(ATOM_N):
        features[atom_i]['adj'] = adj_scaled
        features[atom_i]['my_idx'] = atom_i
        features[atom_i]['atomicno'] = atomic_nos_pad
        # positions of all atoms relative to atom_i
        features[atom_i]['pos'][:ATOM_N] = coords - coords[atom_i]
    return features
Beispiel #2
0
def feat_tensor_mol(mol,
                    feat_distances=True,
                    feat_r_pow=None,
                    MAX_POW_M=2.0,
                    conf_idx=0):
    """
    Return matrix (pairwise) features for a molecule.

    Parameters
    ----------
    mol : molecule accepted by ``get_nos_coords``.
    feat_distances : bool
        Add three channels holding the signed per-axis coordinate
        differences, ``out[i, j, k] = coords[j, k] - coords[i, k]``.
    feat_r_pow : iterable of exponents or None
        For each exponent ``p`` add one channel ``(I + D) ** p``, where D is
        the Euclidean distance matrix; values above ``MAX_POW_M`` are
        clipped, and a warning is printed when clipping happens.
    MAX_POW_M : float
        Upper bound for the power channels.
    conf_idx : int
        Conformer passed to ``get_nos_coords``.

    Returns
    -------
    numpy.ndarray of shape (ATOM_N, ATOM_N, C). C is 0 (float32) when no
    feature is enabled.
    """
    atomic_nos, coords = get_nos_coords(mol, conf_idx)
    n_atoms = len(atomic_nos)

    blocks = []

    if feat_distances or feat_r_pow is not None:
        # (1, 3, N) minus its (N, 3, 1) transpose -> (N, 3, N); swapping the
        # last two axes yields pair_diff[i, j, k] = coords[j, k] - coords[i, k].
        cols = coords.T.reshape(1, 3, -1)
        pair_diff = np.swapaxes(cols - cols.T, 2, 1)

    if feat_distances:
        blocks.append(pair_diff)

    if feat_r_pow is not None:
        dist = np.sqrt(np.sum(pair_diff ** 2, axis=2))  # (N, N) distances
        # ones on the diagonal instead of zeros before exponentiation
        shifted = (np.eye(dist.shape[0]) + dist)[:, :, np.newaxis]
        for power in feat_r_pow:
            scaled = shifted ** power
            if np.any(scaled > MAX_POW_M):
                print("WARNING: max(M) = {:3.1f}".format(np.max(scaled)))
                scaled = np.minimum(scaled, MAX_POW_M)
            blocks.append(scaled)

    if not blocks:
        # no channels requested: empty feature axis
        return np.zeros((n_atoms, n_atoms, 0), dtype=np.float32)
    return np.concatenate(blocks, axis=2)
Beispiel #3
0
def advanced_atom_props(mol, conformer_i, **kwargs):
    """
    Return a structured numpy array of per-atom chemical properties.

    The molecule is copied, sanitized (errors are not raised because
    ``catchErrors=True``), and Gasteiger partial charges are computed on the
    copy. The caller's molecule is therefore left untouched.

    Parameters
    ----------
    mol : RDKit molecule.
    conformer_i : int
        Conformer whose coordinates fill the ``pos`` field.
    **kwargs : accepted but unused.

    Returns
    -------
    numpy structured array of length ATOM_N with fields:
    total_valence, aromatic, hybridization (RDKit enum stored as int),
    partial_charge (Gasteiger), formal_charge, atomicno, r_covalent,
    r_vanderwals, default_valence, rings (5 flags: in a ring of size
    3, 4, 5, 6, 7), and pos (xyz coordinates).
    """
    import rdkit.Chem.rdPartialCharges
    pt = Chem.GetPeriodicTable()
    atomic_nos, coords = get_nos_coords(mol, conformer_i)
    mol = Chem.Mol(mol)  # work on a copy
    Chem.SanitizeMol(mol, Chem.rdmolops.SanitizeFlags.SANITIZE_ALL,
                     catchErrors=True)
    Chem.rdPartialCharges.ComputeGasteigerCharges(mol)
    ATOM_N = len(atomic_nos)
    # Builtin int/bool replace the np.int/np.bool aliases, which NumPy
    # deprecated in 1.20 and removed in 1.24; the resulting dtypes are the same.
    out = np.zeros(ATOM_N,
                   dtype=[('total_valence', int),
                          ('aromatic', bool),
                          ('hybridization', int),
                          ('partial_charge', np.float32),
                          ('formal_charge', np.float32),
                          ('atomicno', int),
                          ('r_covalent', np.float32),
                          ('r_vanderwals', np.float32),
                          ('default_valence', int),
                          ('rings', bool, 5),
                          ('pos', np.float32, 3)])

    for i in range(mol.GetNumAtoms()):
        a = mol.GetAtomWithIdx(i)
        atomic_num = int(atomic_nos[i])
        out[i]['total_valence'] = a.GetTotalValence()
        out[i]['aromatic'] = a.GetIsAromatic()
        out[i]['hybridization'] = a.GetHybridization()
        # the property is stored as a string; convert explicitly
        out[i]['partial_charge'] = float(a.GetProp('_GasteigerCharge'))
        out[i]['formal_charge'] = a.GetFormalCharge()
        out[i]['atomicno'] = atomic_nos[i]
        out[i]['r_covalent'] = pt.GetRcovalent(atomic_num)
        out[i]['r_vanderwals'] = pt.GetRvdw(atomic_num)
        out[i]['default_valence'] = pt.GetDefaultValence(atomic_num)
        # ring sizes 3..7 -> the 5 slots of the 'rings' field
        out[i]['rings'] = [a.IsInRingSize(r) for r in range(3, 8)]
        out[i]['pos'] = coords[i]

    return out
Beispiel #4
0
def feat_tensor_atom(mol,
                     feat_atomicno=True, feat_pos=True,
                     feat_atomicno_onehot=[1, 6, 7, 8, 9],
                     feat_valence=True, aromatic=True, hybridization=True,
                     partial_charge=True, formal_charge=True, r_covalent=True,
                     r_vanderwals=True, default_valence=True, rings=False,
                     total_valence_onehot=False,
                     conf_idx=0):
    """
    Featurize a molecule on a per-atom basis.

    Parameters
    ----------
    mol : RDKit molecule. A copy is featurized, so Gasteiger charges are not
        written onto the caller's object.
    feat_atomicno_onehot : list of atomic numbers for the atomic-number
        one-hot block, or None to omit that block.
    conf_idx : conformer whose coordinates are used (default 0).
    The remaining boolean flags each toggle one block of features. Enabled
    blocks are concatenated in the order they appear in the code below.

    Returns
    -------
    (atom_features, atom_types) : tuple of float32 torch.Tensor
        atom_features is (ATOM_N x feature).
        atom_types is (ATOM_N x 8), a one-hot over H, C, N, O, F, P, S, Cl.
        Atoms of any other element get an all-zero row, so both tensors stay
        row-aligned.

    NOTE: Performs NO sanitization or cleanup of the molecule. It assumes
    all molecules have been sanitized ahead of time.
    """
    pt = Chem.GetPeriodicTable()
    mol = Chem.Mol(mol)  # copy molecule

    # (array of atomic numbers, array of 3D coords) of the conf_idx'th conformer
    atomic_nos, coords = get_nos_coords(mol, conf_idx)

    if partial_charge:
        Chem.rdPartialCharges.ComputeGasteigerCharges(mol)

    # Element classes for the atom_types one-hot. Previously, atoms of other
    # elements were skipped, which shortened atom_types and mis-aligned it
    # with atom_features; they now get an all-zero row instead.
    type_symbols = ('H', 'C', 'N', 'O', 'F', 'P', 'S', 'Cl')
    type_index = {sym: k for k, sym in enumerate(type_symbols)}

    atom_features = []
    atom_types = []

    for i in range(mol.GetNumAtoms()):
        a = mol.GetAtomWithIdx(i)
        atomic_num = int(atomic_nos[i])
        atom_feature = []

        type_row = [0] * len(type_symbols)
        k = type_index.get(a.GetSymbol())
        if k is not None:
            type_row[k] = 1
        atom_types.append(type_row)

        if feat_atomicno:
            atom_feature += [atomic_num]

        if feat_pos:
            atom_feature += coords[i].tolist()

        if feat_atomicno_onehot is not None:
            atom_feature += to_onehot(atomic_num, feat_atomicno_onehot)

        if feat_valence:
            atom_feature += [a.GetTotalValence()]
        if total_valence_onehot:
            atom_feature += to_onehot(a.GetTotalValence(), range(1, 7))

        if aromatic:
            atom_feature += [a.GetIsAromatic()]

        if hybridization:
            atom_feature += to_onehot(a.GetHybridization(), HYBRIDIZATIONS)

        if partial_charge:
            gc = float(a.GetProp('_GasteigerCharge'))
            if not np.isfinite(gc):  # replace non-finite charges with 0
                gc = 0.0
            atom_feature += [gc]

        if formal_charge:
            atom_feature += to_onehot(a.GetFormalCharge(), [-1, 0, 1])

        if r_covalent:
            atom_feature += [pt.GetRcovalent(atomic_num)]  # covalent radius
        if r_vanderwals:
            atom_feature += [pt.GetRvdw(atomic_num)]

        if default_valence:
            atom_feature += to_onehot(pt.GetDefaultValence(atomic_num), range(1, 7))

        if rings:
            atom_feature += [a.IsInRingSize(r) for r in range(3, 8)]

        atom_features.append(atom_feature)

    return torch.Tensor(atom_features), torch.Tensor(atom_types)
Beispiel #5
0
def feat_tensor_atom(mol,
                     feat_atomicno=True,
                     feat_pos=True,
                     feat_atomicno_onehot=[1, 6, 7, 8, 9],
                     feat_valence=True,
                     aromatic=True,
                     hybridization=True,
                     partial_charge=True,
                     formal_charge=True,
                     r_covalent=True,
                     r_vanderwals=True,
                     default_valence=True,
                     rings=False,
                     total_valence_onehot=False,
                     mmff_atom_types_onehot=False,
                     max_ring_size=8,
                     rad_electrons=False,
                     chirality=False,
                     assign_stereo=False,
                     conf_idx=0):
    """
    Featurize a molecule on a per-atom basis.

    Each enabled flag appends one block of columns to every atom's row, in
    the order tested below. ``feat_atomicno_onehot`` is the list of atomic
    numbers for the one-hot block (None disables it). ``rings`` adds
    ``IsInRingSize(r)`` for r in ``range(3, max_ring_size)``.
    ``rad_electrons`` adds no columns; it raises ``ValueError("RADICAL")``
    for any atom carrying radical electrons. ``mmff_atom_types_onehot``
    one-hot encodes the MMFF94 atom type, or emits zeros when RDKit returns
    no MMFF properties. ``assign_stereo`` assigns stereochemistry from the 3D
    conformer before featurizing.

    Always assume using conf_idx unless otherwise passed.

    Returns an (ATOM_N x feature) float32 tensor.

    NOTE: Performs NO sanitization or cleanup of the molecule. It assumes
    all molecules have been sanitized ahead of time. All mutations
    (charges, stereo) are applied to a copy.
    """
    table = Chem.GetPeriodicTable()
    mol = Chem.Mol(mol)  # private copy

    atomic_nos, coords = get_nos_coords(mol, conf_idx)

    if partial_charge:
        Chem.rdPartialCharges.ComputeGasteigerCharges(mol)

    rows = []
    mmff_props = (Chem.rdForceFieldHelpers.MMFFGetMoleculeProperties(mol)
                  if mmff_atom_types_onehot else None)

    if assign_stereo:
        Chem.rdmolops.AssignStereochemistryFrom3D(mol)

    for idx, atom in enumerate(mol.GetAtoms()):
        z = int(atomic_nos[idx])
        row = []

        if feat_atomicno:
            row.append(z)
        if feat_pos:
            row.extend(coords[idx].tolist())
        if feat_atomicno_onehot is not None:
            row.extend(to_onehot(z, feat_atomicno_onehot))
        if feat_valence:
            row.append(atom.GetTotalValence())
        if total_valence_onehot:
            row.extend(to_onehot(atom.GetTotalValence(), range(1, 7)))
        if aromatic:
            row.append(atom.GetIsAromatic())
        if hybridization:
            row.extend(to_onehot(atom.GetHybridization(), HYBRIDIZATIONS))
        if partial_charge:
            charge = float(atom.GetProp('_GasteigerCharge'))
            # non-finite charges are replaced by 0.0
            row.append(charge if np.isfinite(charge) else 0.0)
        if formal_charge:
            row.extend(to_onehot(atom.GetFormalCharge(), [-1, 0, 1]))
        if r_covalent:
            row.append(table.GetRcovalent(z))
        if r_vanderwals:
            row.append(table.GetRvdw(z))
        if default_valence:
            row.extend(to_onehot(table.GetDefaultValence(z), range(1, 7)))
        if rings:
            row.extend(atom.IsInRingSize(size)
                       for size in range(3, max_ring_size))
        if rad_electrons and atom.GetNumRadicalElectrons() > 0:
            raise ValueError("RADICAL")
        if chirality:
            row.extend(to_onehot(atom.GetChiralTag(), CHI_TYPES))
        if mmff_atom_types_onehot:
            if mmff_props is None:
                row.extend([0] * len(MMFF94_ATOM_TYPES))
            else:
                row.extend(to_onehot(mmff_props.GetMMFFAtomType(idx),
                                     MMFF94_ATOM_TYPES))

        rows.append(row)

    return torch.Tensor(rows)
Beispiel #6
0
def feat_tensor_atom(mol,
                     feat_atomicno=True,
                     feat_pos=True,
                     feat_atomicno_onehot=[1, 6, 7, 8, 9],
                     feat_valence=True,
                     aromatic=True,
                     hybridization=True,
                     partial_charge=True,
                     formal_charge=True,
                     r_covalent=True,
                     r_vanderwals=True,
                     default_valence=True,
                     rings=False,
                     total_valence_onehot=False,
                     conf_idx=0):
    """
    Featurize a molecule on a per-atom basis.

    The feature row of each atom is the concatenation, in a fixed order, of
    the blocks whose flag is enabled (see the ``extractors`` table below).
    ``feat_atomicno_onehot`` is the list of atomic numbers for the one-hot
    block, or None to omit it.

    Always assume using conf_idx unless otherwise passed.

    Returns an (ATOM_N x feature) float32 tensor.

    NOTE: Performs NO sanitization or cleanup of the molecule. It assumes
    all molecules have been sanitized ahead of time.
    """
    table = Chem.GetPeriodicTable()
    mol = Chem.Mol(mol)  # private copy; charges below are set on it

    atomic_nos, coords = get_nos_coords(mol, conf_idx)

    if partial_charge:
        Chem.rdPartialCharges.ComputeGasteigerCharges(mol)

    def _gasteiger(atom):
        charge = float(atom.GetProp('_GasteigerCharge'))
        # non-finite charges are replaced by 0.0
        return [charge] if np.isfinite(charge) else [0.0]

    # (enabled?, extractor(idx, atom, z) -> list of values), in output order
    extractors = [
        (feat_atomicno, lambda idx, atom, z: [z]),
        (feat_pos, lambda idx, atom, z: coords[idx].tolist()),
        (feat_atomicno_onehot is not None,
         lambda idx, atom, z: to_onehot(z, feat_atomicno_onehot)),
        (feat_valence, lambda idx, atom, z: [atom.GetTotalValence()]),
        (total_valence_onehot,
         lambda idx, atom, z: to_onehot(atom.GetTotalValence(), range(1, 7))),
        (aromatic, lambda idx, atom, z: [atom.GetIsAromatic()]),
        (hybridization,
         lambda idx, atom, z: to_onehot(atom.GetHybridization(), HYBRIDIZATIONS)),
        (partial_charge, lambda idx, atom, z: _gasteiger(atom)),
        (formal_charge,
         lambda idx, atom, z: to_onehot(atom.GetFormalCharge(), [-1, 0, 1])),
        (r_covalent, lambda idx, atom, z: [table.GetRcovalent(z)]),
        (r_vanderwals, lambda idx, atom, z: [table.GetRvdw(z)]),
        (default_valence,
         lambda idx, atom, z: to_onehot(table.GetDefaultValence(z), range(1, 7))),
        (rings,
         lambda idx, atom, z: [atom.IsInRingSize(r) for r in range(3, 8)]),
    ]
    active = [fn for enabled, fn in extractors if enabled]

    rows = []
    for idx, atom in enumerate(mol.GetAtoms()):
        z = int(atomic_nos[idx])
        row = []
        for fn in active:
            row += fn(idx, atom, z)
        rows.append(row)

    return torch.Tensor(rows)
Beispiel #7
0
    def __getitem__(self, idx):
        """
        Build the padded feature/target dict for molecule ``idx``.

        When ``self.frac_per_epoch < 1.0`` the requested ``idx`` is ignored
        and a random molecule index is drawn instead. Only conformer 0 is
        featurized. When ``self.cache`` is not None, results are memoized
        under ``self.cache_key(idx, conf_idx)`` and returned from there on
        later calls.

        Returns a dict (MAX_N = self.MAX_N, PRED_N = self.PRED_N):
          'adj'         torch tensor (C, MAX_N, MAX_N) from feat_mol_adj,
                        optionally followed by the mat_feat channels
          'vect_feat'   np.float32 (MAX_N, F) per-atom features, zero-padded
          'mat_feat'    np.float32 (MAX_N, MAX_N, MAT_CHAN) pairwise features
          'mol_feat'    whole-molecule features
          'vals'        np.float32 (MAX_N, PRED_N) targets, set to
                        util.PERM_MISSING_VALUE where unobserved
          'adj_oh'      torch tensor (K, MAX_N, MAX_N) bond adjacency split by
                        the weights [1.0, 1.5, 2.0, 3.0]
          'pred_mask'   np.float32 (MAX_N, PRED_N), 1 where a target exists
          'coords'      torch tensor (MAX_N, 3) conformer coordinates
          'input_mask'  torch tensor (MAX_N,), 1 for real atoms
          'input_idx'   the (possibly resampled) molecule index
          'edge_edge', 'edge_vert', 'edge_feat'
                        zero-filled placeholders (the copy from the
                        feat_edges output is currently disabled)
          'extra_data_{i}'  zero-padded arrays for extra .npy configs that
                        have no 'combine_with'
        """
        if self.frac_per_epoch < 1.0:
            # randomly get an index each time
            idx = np.random.randint(len(self.mols))

        mol = self.mols[idx]
        pred_val = self.pred_vals[idx]
        whole_record = self.whole_records[idx]

        conf_idx = 0

        if self.cache is not None and self.cache_key(idx,
                                                     conf_idx) in self.cache:
            return self.cache[self.cache_key(idx, conf_idx)]

        # mol features
        f_mol = molecule_features.whole_molecule_features(
            whole_record, **self.mol_args)

        f_vect = atom_features.feat_tensor_atom(mol,
                                                conf_idx=conf_idx,
                                                **self.feat_vert_args)
        if self.combine_mol_vect:
            # broadcast the molecule-level features onto every atom row
            f_vect = torch.cat(
                [f_vect,
                 f_mol.reshape(1, -1).expand(f_vect.shape[0], -1)], -1)
        # process extra data arguments that are appended to per-atom features
        for extra_data_config in self.extra_npy_filenames:
            filename = extra_data_config['filenames'][idx]
            combine_with = extra_data_config.get('combine_with', None)
            if combine_with == 'vert':
                npy_data = np.load(filename)
                npy_data_flatter = npy_data.reshape(f_vect.shape[0], -1)
                f_vect = torch.cat(
                    [f_vect, torch.Tensor(npy_data_flatter)], dim=-1)
            elif combine_with is None:
                continue  # handled below as a separate 'extra_data_{i}' entry
            else:

                raise NotImplementedError(
                    f"the combinewith {combine_with} not working yet")
        DATA_N = f_vect.shape[0]

        vect_feat = np.zeros((self.MAX_N, f_vect.shape[1]), dtype=np.float32)
        vect_feat[:DATA_N] = f_vect

        f_mat = molecule_features.feat_tensor_mol(mol,
                                                  conf_idx=conf_idx,
                                                  **self.feat_edge_args)

        if self.combine_mat_vect:
            MAT_CHAN = f_mat.shape[2] + vect_feat.shape[1]
        else:
            MAT_CHAN = f_mat.shape[2]
        if MAT_CHAN == 0:  # Dataloader can't handle tensors with empty dimensions
            MAT_CHAN = 1
        mat_feat = np.zeros((self.MAX_N, self.MAX_N, MAT_CHAN),
                            dtype=np.float32)
        # do the padding
        mat_feat[:DATA_N, :DATA_N, :f_mat.shape[2]] = f_mat

        if self.combine_mat_vect == 'row':
            # row-major
            for i in range(DATA_N):
                mat_feat[i, :DATA_N, f_mat.shape[2]:] = f_vect
        elif self.combine_mat_vect == 'col':
            # col-major
            for i in range(DATA_N):
                mat_feat[:DATA_N, i, f_mat.shape[2]:] = f_vect

        adj_nopad = molecule_features.feat_mol_adj(mol, **self.adj_args)
        adj = torch.zeros((adj_nopad.shape[0], self.MAX_N, self.MAX_N))
        adj[:, :adj_nopad.shape[1], :adj_nopad.shape[2]] = adj_nopad

        if self.combine_mat_feat_adj:
            adj = torch.cat([adj, torch.Tensor(mat_feat).permute(2, 0, 1)], 0)

        ### Simple one-hot encoding for reconstruction
        adj_oh_nopad = molecule_features.feat_mol_adj(
            mol,
            split_weights=[1.0, 1.5, 2.0, 3.0],
            edge_weighted=False,
            norm_adj=False,
            add_identity=False)

        adj_oh = torch.zeros((adj_oh_nopad.shape[0], self.MAX_N, self.MAX_N))
        adj_oh[:, :adj_oh_nopad.shape[1], :adj_oh_nopad.
               shape[2]] = adj_oh_nopad

        ## per-edge features
        feat_edge_dict = edge_features.feat_edges(mol, )

        # pad each of these (NOTE: the copies are disabled, so these stay zero)
        edge_edge_nopad = feat_edge_dict['edge_edge']
        edge_edge = torch.zeros(
            (edge_edge_nopad.shape[0], self.MAX_N, self.MAX_N))

        # edge_edge[:, :edge_edge_nopad.shape[1],
        #              :edge_edge_nopad.shape[2]] = torch.Tensor(edge_edge_nopad)

        edge_feat_nopad = feat_edge_dict['edge_feat']
        edge_feat = torch.zeros((self.MAX_N, edge_feat_nopad.shape[1]))
        # edge_feat[:edge_feat_nopad.shape[0]] = torch.Tensor(edge_feat_nopad)

        edge_vert_nopad = feat_edge_dict['edge_vert']
        edge_vert = torch.zeros(
            (edge_vert_nopad.shape[0], self.MAX_N, self.MAX_N))
        # edge_vert[:, :edge_vert_nopad.shape[1],
        #              :edge_vert_nopad.shape[2]] = torch.Tensor(edge_vert_nopad)

        atomicnos, coords = util.get_nos_coords(mol, conf_idx)
        coords_t = torch.zeros((self.MAX_N, 3))
        coords_t[:len(coords), :] = torch.Tensor(coords)

        # create mask and preds

        pred_mask = np.zeros((self.MAX_N, self.PRED_N), dtype=np.float32)
        vals = np.ones((self.MAX_N, self.PRED_N),
                       dtype=np.float32) * util.PERM_MISSING_VALUE
        if self.spect_assign:
            for pn in range(self.PRED_N):
                if len(pred_val) > 0:  # when empty, there's nothing to predict
                    # Iterate items so the lookup uses the dict's own keys;
                    # converting keys with int() and then indexing the dict
                    # with the int raised KeyError for non-int (e.g. str) keys.
                    for atom_key, obs_val in pred_val[pn].items():
                        atom_i = int(atom_key)
                        pred_mask[atom_i, pn] = 1.0
                        vals[atom_i, pn] = obs_val
        else:
            if self.PRED_N > 1:
                raise NotImplementedError()
            vals[:] = util.PERM_MISSING_VALUE  # sentinel value
            # pred_val[0][0]: atom indices that have a target;
            # pred_val[0][1]: observed values, stored in order (not per atom)
            for k in pred_val[0][0]:
                pred_mask[k, 0] = 1
            for vi, val in enumerate(pred_val[0][1]):
                vals[vi] = val

        # input mask
        input_mask = torch.zeros(self.MAX_N)
        input_mask[:DATA_N] = 1.0

        v = {
            'adj': adj,
            'vect_feat': vect_feat,
            'mat_feat': mat_feat,
            'mol_feat': f_mol,
            'vals': vals,
            'adj_oh': adj_oh,
            'pred_mask': pred_mask,
            'coords': coords_t,
            'input_mask': input_mask,
            'input_idx': idx,
            'edge_edge': edge_edge,
            'edge_vert': edge_vert,
            'edge_feat': edge_feat
        }

        ## add on extra args
        for ei, extra_data_config in enumerate(self.extra_npy_filenames):
            filename = extra_data_config['filenames'][idx]
            combine_with = extra_data_config.get('combine_with', None)
            if combine_with is None:
                ## this is an extra arg
                npy_data = np.load(filename)

                ## Zero pad along the atom axis (axis 0) to MAX_N
                npy_shape = list(npy_data.shape)
                npy_shape[0] = self.MAX_N
                t_pad = torch.zeros(npy_shape)
                t_pad[:npy_data.shape[0]] = torch.Tensor(npy_data)

                v[f'extra_data_{ei}'] = t_pad

        for k, kv in v.items():
            assert np.isfinite(kv).all()
        if self.cache is not None:
            self.cache[self.cache_key(idx, conf_idx)] = v

        return v
Beispiel #8
0
def feat_tensor_mol(mol,
                    feat_distances=False,
                    feat_r_pow=None,
                    mmff_opt_conf=False,
                    is_in_ring=False,
                    is_in_ring_size=None,
                    MAX_POW_M=2.0,
                    conf_idx=0,
                    add_identity=False,
                    edge_type_tuples=(),
                    norm_mat=False,
                    mat_power=1):
    """
    Return matrix (pairwise) features for a molecule.

    Parameters
    ----------
    mol : RDKit molecule. It is not modified: when ``mmff_opt_conf`` is set,
        a copy is embedded and MMFF-optimized instead.
    feat_distances : add 3 channels ``|coords[j, k] - coords[i, k]|``.
    feat_r_pow : iterable of exponents p or None. For each p, add one channel
        ``(I + D) ** p`` clipped to ``MAX_POW_M``, where D is the Euclidean
        distance matrix.
    mmff_opt_conf : re-embed and MMFF-optimize the geometry before use.
    is_in_ring : add one channel marking atom pairs joined by a ring bond.
    is_in_ring_size : iterable of ring sizes or None; one channel per size
        marks the bonds that lie in a ring of that size.
    edge_type_tuples : sequence of atomic-number pairs; one channel per pair
        marks the bonds between those elements (order-insensitive).
    conf_idx : conformer whose coordinates are used.
    add_identity : add I to every channel.
    norm_mat : normalize every channel as D^-1/2 A D^-1/2, with D taken from
        the column sums. Then raise it to ``mat_power``: an int (applied when
        > 1), or a list giving one output channel per power.

    Returns
    -------
    torch.Tensor of shape (ATOM_N, ATOM_N, channels).

    Raises
    ------
    AssertionError if ``norm_mat`` meets a channel with a non-positive column
    sum, or if the result contains non-finite values.
    """
    res_mats = []
    if mmff_opt_conf:
        # Embedding/optimizing rewrites the molecule's conformers in place.
        # Use a copy so the caller's geometry is not replaced on every call.
        mol = Chem.Mol(mol)
        Chem.AllChem.EmbedMolecule(mol)
        Chem.AllChem.MMFFOptimizeMolecule(mol)

    atomic_nos, coords = get_nos_coords(mol, conf_idx)
    ATOM_N = len(atomic_nos)

    if feat_distances:
        # (1, 3, N) - (N, 3, 1) -> (N, 3, N); swapaxes -> (N, N, 3)
        a = coords.T.reshape(1, 3, -1)
        b = np.abs((a - a.T))
        res_mats.append(np.swapaxes(b, 2, 1))
    if feat_r_pow is not None:
        a = coords.T.reshape(1, 3, -1)
        c = np.swapaxes((a - a.T)**2, 2, 1)
        d = np.sqrt(np.sum(c, axis=2))  # (N, N) Euclidean distances
        # ones on the diagonal avoid 0 ** negative-power
        e = (np.eye(d.shape[0]) + d)[:, :, np.newaxis]

        for p in feat_r_pow:
            res_mats.append(np.minimum(e**p, MAX_POW_M))

    if len(edge_type_tuples) > 0:
        a = np.zeros((ATOM_N, ATOM_N, len(edge_type_tuples)))
        for et_i, et in enumerate(edge_type_tuples):
            for b in mol.GetBonds():
                a_i = b.GetBeginAtomIdx()
                a_j = b.GetEndAtomIdx()
                if set(et) == set([atomic_nos[a_i], atomic_nos[a_j]]):
                    a[a_i, a_j, et_i] = 1
                    a[a_j, a_i, et_i] = 1
        res_mats.append(a)

    if is_in_ring:
        a = np.zeros((ATOM_N, ATOM_N, 1), dtype=np.float32)
        for b in mol.GetBonds():
            # previously every bond was marked, making this plain adjacency
            if b.IsInRing():
                a[b.GetBeginAtomIdx(), b.GetEndAtomIdx()] = 1
                a[b.GetEndAtomIdx(), b.GetBeginAtomIdx()] = 1
        res_mats.append(a)

    if is_in_ring_size is not None:
        for rs in is_in_ring_size:
            a = np.zeros((ATOM_N, ATOM_N, 1), dtype=np.float32)
            for b in mol.GetBonds():
                if b.IsInRingSize(rs):
                    a[b.GetBeginAtomIdx(), b.GetEndAtomIdx()] = 1
                    a[b.GetEndAtomIdx(), b.GetBeginAtomIdx()] = 1
            res_mats.append(a)

    if len(res_mats) > 0:
        M = np.concatenate(res_mats, 2)
    else:  # Empty matrix
        M = np.zeros((ATOM_N, ATOM_N, 0), dtype=np.float32)

    M = torch.Tensor(M).permute(2, 0, 1)  # -> (channels, N, N)

    if add_identity:
        M = M + torch.eye(ATOM_N).unsqueeze(0)

    if norm_mat:
        res = []
        for i in range(M.shape[0]):
            a = M[i]
            deg = torch.sum(a, dim=0)
            # Check the degrees themselves: 1/sqrt(0) is inf, which would
            # still pass a "> 0" test on D_12.
            assert (deg > 0).all(), "norm_mat: every column sum must be positive"
            D_12 = 1.0 / torch.sqrt(deg)
            s1 = D_12.reshape(ATOM_N, 1)
            s2 = D_12.reshape(1, ATOM_N)
            adj_i = s1 * a * s2

            if isinstance(mat_power, list):
                for p in mat_power:
                    res.append(torch.matrix_power(adj_i, p))
            else:
                if mat_power > 1:
                    adj_i = torch.matrix_power(adj_i, mat_power)
                res.append(adj_i)
        M = torch.stack(res, 0)

    assert torch.isfinite(M).all()
    return M.permute(1, 2, 0)