def atom_adj_mat(mol, conformer_i, **kwargs):
    """
    Build one structured record per atom of `mol`.

    OUTPUT IS ATOM_N x (adj_mat, tgt_atom, atomic_nos, dists)

    This is really inefficient given that we explicitly return the same
    adj matrix for each atom, and index into it.

    Adj mat is valence number * 2

    mol : molecule handed to get_nos_coords / mol_to_nums_adj
    conformer_i : conformer index forwarded to get_nos_coords
    MAX_ATOM_N (keyword, default 64) : padded size of the per-atom fields

    Returns a structured numpy array of length ATOM_N with fields
        'adj'      : uint8 (MAX_ATOM_N, MAX_ATOM_N), the adjacency from
                     mol_to_nums_adj multiplied by 2
        'my_idx'   : index of the atom this record describes
        'atomicno' : uint8 (MAX_ATOM_N,), padded atomic numbers from
                     mol_to_nums_adj
        'pos'      : float32 (MAX_ATOM_N, 3), coordinates of every atom
                     relative to this atom; rows past ATOM_N stay zero
    """
    MAX_ATOM_N = kwargs.get('MAX_ATOM_N', 64)

    atomic_nos, coords = get_nos_coords(mol, conformer_i)
    ATOM_N = len(atomic_nos)

    atomic_nos_pad, adj = mol_to_nums_adj(mol, MAX_ATOM_N)

    # The builtin `int` replaces the `np.int` alias, which was removed from
    # NumPy (1.24+); the alias was the builtin, so the dtype is unchanged.
    features = np.zeros((ATOM_N,),
                        dtype=[('adj', np.uint8, (MAX_ATOM_N, MAX_ATOM_N)),
                               ('my_idx', int),
                               ('atomicno', np.uint8, MAX_ATOM_N),
                               ('pos', np.float32, (MAX_ATOM_N, 3,))])

    # The scaled adjacency is identical for every atom; compute it once.
    adj_scaled = adj * 2

    for atom_i in range(ATOM_N):
        # displacement of every atom from the current one
        vects = coords - coords[atom_i]
        features[atom_i]['adj'] = adj_scaled
        features[atom_i]['my_idx'] = atom_i
        features[atom_i]['atomicno'] = atomic_nos_pad
        features[atom_i]['pos'][:ATOM_N] = vects

    return features
def feat_tensor_mol(mol, feat_distances=True, feat_r_pow=None,
                    MAX_POW_M=2.0, conf_idx=0):
    """
    Compute pairwise (atom x atom) matrix features for a molecule.

    feat_distances : when true, add the per-axis coordinate differences
        between every pair of atoms (an N x N x 3 block)
    feat_r_pow : optional iterable of exponents; for each exponent p one
        N x N x 1 block holding (identity + euclidean distance) ** p,
        clipped from above at MAX_POW_M
    MAX_POW_M : upper clip applied to every power block; a warning is
        printed whenever a value exceeds it before clipping
    conf_idx : conformer index forwarded to get_nos_coords

    Returns an N x N x C numpy array (C == 0 when nothing was requested).
    """
    atomic_nos, coords = get_nos_coords(mol, conf_idx)
    ATOM_N = len(atomic_nos)

    def _coords_1x3xn():
        # N x 3 coordinates viewed as 1 x 3 x N; subtracting the transpose
        # (N x 3 x 1) broadcasts to N x 3 x N
        return coords.T.reshape(1, 3, -1)

    channels = []

    if feat_distances:
        stacked = _coords_1x3xn()
        delta = stacked - stacked.T
        # move the xyz axis last -> N x N x 3
        channels.append(np.swapaxes(delta, 2, 1))

    if feat_r_pow is not None:
        stacked = _coords_1x3xn()
        sq_delta = np.swapaxes((stacked - stacked.T) ** 2, 2, 1)
        # euclidean distance, N x N
        dist = np.sqrt(np.sum(sq_delta, axis=2))
        # identity added on the diagonal, then a trailing channel axis
        base = (np.eye(dist.shape[0]) + dist)[:, :, np.newaxis]

        for p in feat_r_pow:
            powered = base ** p
            if (powered > MAX_POW_M).any():
                print("WARNING: max(M) = {:3.1f}".format(np.max(powered)))
                powered = np.minimum(powered, MAX_POW_M)
            channels.append(powered)

    if len(channels) == 0:
        # Empty matrix
        return np.zeros((ATOM_N, ATOM_N, 0), dtype=np.float32)
    return np.concatenate(channels, 2)
def advanced_atom_props(mol, conformer_i, **kwargs):
    """
    Collect per-atom chemical properties into a structured numpy array.

    Coordinates and atomic numbers are read from the molecule as passed in;
    the remaining properties are read from a sanitized copy on which
    Gasteiger charges have been computed. The caller's molecule is not
    modified by this function's own RDKit calls.

    mol : RDKit molecule
    conformer_i : conformer index forwarded to get_nos_coords
    **kwargs : accepted for call-compatibility; unused here

    Returns a structured array of length ATOM_N with fields
    total_valence, aromatic, hybridization, partial_charge, formal_charge,
    atomicno, r_covalent, r_vanderwals, default_valence,
    rings (membership in rings of size 3..7) and pos (xyz).
    """
    import rdkit.Chem.rdPartialCharges
    pt = Chem.GetPeriodicTable()

    atomic_nos, coords = get_nos_coords(mol, conformer_i)

    # Work on a copy so sanitization / charge properties stay local.
    mol = Chem.Mol(mol)
    Chem.SanitizeMol(mol, Chem.rdmolops.SanitizeFlags.SANITIZE_ALL,
                     catchErrors=True)
    Chem.rdPartialCharges.ComputeGasteigerCharges(mol)

    ATOM_N = len(atomic_nos)

    # Builtin int / bool replace the np.int / np.bool aliases, which were
    # removed from NumPy (1.24+). The aliases were the builtins, so the
    # resulting dtypes are unchanged.
    out = np.zeros(ATOM_N,
                   dtype=[('total_valence', int),
                          ('aromatic', bool),
                          ('hybridization', int),
                          ('partial_charge', np.float32),
                          ('formal_charge', np.float32),
                          ('atomicno', int),
                          ('r_covalent', np.float32),
                          ('r_vanderwals', np.float32),
                          ('default_valence', int),
                          ('rings', bool, 5),
                          ('pos', np.float32, 3)])

    for i in range(mol.GetNumAtoms()):
        a = mol.GetAtomWithIdx(i)
        atomic_num = int(atomic_nos[i])
        rec = out[i]  # record view: writes land in `out`

        rec['total_valence'] = a.GetTotalValence()
        rec['aromatic'] = a.GetIsAromatic()
        # explicit conversion of the RDKit enum value
        rec['hybridization'] = int(a.GetHybridization())
        # the charge is stored as a string property; parse it explicitly
        rec['partial_charge'] = float(a.GetProp('_GasteigerCharge'))
        rec['formal_charge'] = a.GetFormalCharge()
        rec['atomicno'] = atomic_nos[i]
        rec['r_covalent'] = pt.GetRcovalent(atomic_num)
        rec['r_vanderwals'] = pt.GetRvdw(atomic_num)
        rec['default_valence'] = pt.GetDefaultValence(atomic_num)
        rec['rings'] = [a.IsInRingSize(r) for r in range(3, 8)]
        rec['pos'] = coords[i]

    return out
def feat_tensor_atom(mol, feat_atomicno = True, feat_pos=True,
                     feat_atomicno_onehot=[1, 6, 7, 8, 9],
                     feat_valence=True, aromatic=True, hybridization=True,
                     partial_charge=True, formal_charge=True,
                     r_covalent=True, r_vanderwals=True,
                     default_valence=True, rings=False,
                     total_valence_onehot=False,
                     conf_idx = 0):
    """
    Featurize a molecule on a per-atom basis.

    Each boolean flag switches one group of feature columns on or off;
    groups are emitted in the order the flags are tested below.

    feat_atomicno_onehot : list of atomic numbers to one-hot against, or
        None to skip that group
    conf_idx : conformer whose coordinates are used

    Returns a pair of float tensors:
        (ATOM_N x feature) atom features, and
        an 8-wide one-hot element type (H, C, N, O, F, P, S, Cl) with one
        row per atom whose symbol is in that list -- atoms of any other
        element contribute no row.

    NOTE: Performs NO sanitization or cleanup of molecule, assumes all
    molecules have sanitization calculated ahead of time.
    """
    # column of the element one-hot for each supported symbol
    type_column = {'H': 0, 'C': 1, 'N': 2, 'O': 3,
                   'F': 4, 'P': 5, 'S': 6, 'Cl': 7}

    periodic_table = Chem.GetPeriodicTable()
    mol = Chem.Mol(mol)  # copy molecule

    # (array of atomic numbers, array of 3d atom coords) for conf_idx
    atomic_nos, coords = get_nos_coords(mol, conf_idx)

    if partial_charge:
        Chem.rdPartialCharges.ComputeGasteigerCharges(mol)

    feature_rows = []
    type_rows = []
    for i in range(mol.GetNumAtoms()):
        atom = mol.GetAtomWithIdx(i)
        atomic_num = int(atomic_nos[i])

        column = type_column.get(atom.GetSymbol())
        if column is not None:
            one_hot = [0] * 8
            one_hot[column] = 1
            type_rows.append(one_hot)

        row = []
        if feat_atomicno:
            row.append(atomic_num)

        if feat_pos:
            row.extend(coords[i].tolist())

        if feat_atomicno_onehot is not None:
            row.extend(to_onehot(atomic_num, feat_atomicno_onehot))

        if feat_valence:
            row.append(atom.GetTotalValence())

        if total_valence_onehot:
            row.extend(to_onehot(atom.GetTotalValence(), range(1, 7)))

        if aromatic:
            row.append(atom.GetIsAromatic())

        if hybridization:
            row.extend(to_onehot(atom.GetHybridization(), HYBRIDIZATIONS))

        if partial_charge:
            charge = float(atom.GetProp('_GasteigerCharge'))
            # non-finite charges are replaced by zero
            row.append(charge if np.isfinite(charge) else 0.0)

        if formal_charge:
            row.extend(to_onehot(atom.GetFormalCharge(), [-1, 0, 1]))

        if r_covalent:
            # radius of atom in covalent bond
            row.append(periodic_table.GetRcovalent(atomic_num))

        if r_vanderwals:
            row.append(periodic_table.GetRvdw(atomic_num))

        if default_valence:
            row.extend(to_onehot(periodic_table.GetDefaultValence(atomic_num),
                                 range(1, 7)))

        if rings:
            row.extend([atom.IsInRingSize(r) for r in range(3, 8)])

        # one inner list per atom
        feature_rows.append(row)

    return torch.Tensor(feature_rows), torch.Tensor(type_rows)
def feat_tensor_atom(mol, feat_atomicno=True, feat_pos=True,
                     feat_atomicno_onehot=[1, 6, 7, 8, 9],
                     feat_valence=True, aromatic=True, hybridization=True,
                     partial_charge=True, formal_charge=True,
                     r_covalent=True, r_vanderwals=True,
                     default_valence=True, rings=False,
                     total_valence_onehot=False,
                     mmff_atom_types_onehot=False,
                     max_ring_size=8,
                     rad_electrons=False,
                     chirality=False,
                     assign_stereo=False,
                     conf_idx=0):
    """
    Featurize a molecule on a per-atom basis.

    Every boolean flag enables one group of feature columns; groups appear
    in the order the flags are checked in the loop below.

    feat_atomicno_onehot : list of atomic numbers to one-hot against, or
        None to skip that group
    max_ring_size : exclusive upper bound of the ring sizes tested when
        `rings` is set (sizes 3 .. max_ring_size - 1)
    rad_electrons : adds no columns; when set, an atom with radical
        electrons raises ValueError("RADICAL")
    mmff_atom_types_onehot : one-hot MMFF atom type; all zeros when MMFF
        properties could not be obtained for the molecule
    assign_stereo : assign stereochemistry from 3D before featurizing
    conf_idx : conformer whose coordinates are used

    Returns an (ATOM_N x feature) float32 tensor

    NOTE: Performs NO sanitization or cleanup of molecule, assumes all
    molecules have sanitization calculated ahead of time.
    """
    periodic_table = Chem.GetPeriodicTable()
    mol = Chem.Mol(mol)  # copy molecule

    atomic_nos, coords = get_nos_coords(mol, conf_idx)

    if partial_charge:
        Chem.rdPartialCharges.ComputeGasteigerCharges(mol)

    if mmff_atom_types_onehot:
        mmff_p = Chem.rdForceFieldHelpers.MMFFGetMoleculeProperties(mol)

    if assign_stereo:
        Chem.rdmolops.AssignStereochemistryFrom3D(mol)

    all_rows = []
    for i in range(mol.GetNumAtoms()):
        atom = mol.GetAtomWithIdx(i)
        atomic_num = int(atomic_nos[i])

        row = []
        push = row.extend  # every feature group is appended via extend

        if feat_atomicno:
            push([atomic_num])

        if feat_pos:
            push(coords[i].tolist())

        if feat_atomicno_onehot is not None:
            push(to_onehot(atomic_num, feat_atomicno_onehot))

        if feat_valence:
            push([atom.GetTotalValence()])

        if total_valence_onehot:
            push(to_onehot(atom.GetTotalValence(), range(1, 7)))

        if aromatic:
            push([atom.GetIsAromatic()])

        if hybridization:
            push(to_onehot(atom.GetHybridization(), HYBRIDIZATIONS))

        if partial_charge:
            charge = float(atom.GetProp('_GasteigerCharge'))
            if not np.isfinite(charge):
                # non-finite charges are replaced by zero
                charge = 0.0
            push([charge])

        if formal_charge:
            push(to_onehot(atom.GetFormalCharge(), [-1, 0, 1]))

        if r_covalent:
            push([periodic_table.GetRcovalent(atomic_num)])

        if r_vanderwals:
            push([periodic_table.GetRvdw(atomic_num)])

        if default_valence:
            push(to_onehot(periodic_table.GetDefaultValence(atomic_num),
                           range(1, 7)))

        if rings:
            push([atom.IsInRingSize(r) for r in range(3, max_ring_size)])

        if rad_electrons and atom.GetNumRadicalElectrons() > 0:
            raise ValueError("RADICAL")

        if chirality:
            push(to_onehot(atom.GetChiralTag(), CHI_TYPES))

        if mmff_atom_types_onehot:
            if mmff_p is None:
                push([0] * len(MMFF94_ATOM_TYPES))
            else:
                push(to_onehot(mmff_p.GetMMFFAtomType(i), MMFF94_ATOM_TYPES))

        all_rows.append(row)

    return torch.Tensor(all_rows)
def feat_tensor_atom(mol, feat_atomicno=True, feat_pos=True,
                     feat_atomicno_onehot=[1, 6, 7, 8, 9],
                     feat_valence=True, aromatic=True, hybridization=True,
                     partial_charge=True, formal_charge=True,
                     r_covalent=True, r_vanderwals=True,
                     default_valence=True, rings=False,
                     total_valence_onehot=False,
                     conf_idx=0):
    """
    Featurize a molecule on a per-atom basis.

    Each boolean flag toggles one group of feature columns; the groups are
    concatenated in the order they are handled in `_atom_row`.

    feat_atomicno_onehot : list of atomic numbers to one-hot against, or
        None to skip that group
    conf_idx : conformer whose coordinates are used

    Returns an (ATOM_N x feature) float32 tensor

    NOTE: Performs NO sanitization or cleanup of molecule, assumes all
    molecules have sanitization calculated ahead of time.
    """
    periodic_table = Chem.GetPeriodicTable()
    mol = Chem.Mol(mol)  # copy molecule

    atomic_nos, coords = get_nos_coords(mol, conf_idx)

    if partial_charge:
        Chem.rdPartialCharges.ComputeGasteigerCharges(mol)

    def _atom_row(i):
        # feature list for the atom at index i
        atom = mol.GetAtomWithIdx(i)
        atomic_num = int(atomic_nos[i])
        row = []

        if feat_atomicno:
            row.append(atomic_num)
        if feat_pos:
            row.extend(coords[i].tolist())
        if feat_atomicno_onehot is not None:
            row.extend(to_onehot(atomic_num, feat_atomicno_onehot))
        if feat_valence:
            row.append(atom.GetTotalValence())
        if total_valence_onehot:
            row.extend(to_onehot(atom.GetTotalValence(), range(1, 7)))
        if aromatic:
            row.append(atom.GetIsAromatic())
        if hybridization:
            row.extend(to_onehot(atom.GetHybridization(), HYBRIDIZATIONS))
        if partial_charge:
            charge = float(atom.GetProp('_GasteigerCharge'))
            # non-finite charges are replaced by zero
            row.append(charge if np.isfinite(charge) else 0.0)
        if formal_charge:
            row.extend(to_onehot(atom.GetFormalCharge(), [-1, 0, 1]))
        if r_covalent:
            row.append(periodic_table.GetRcovalent(atomic_num))
        if r_vanderwals:
            row.append(periodic_table.GetRvdw(atomic_num))
        if default_valence:
            row.extend(to_onehot(periodic_table.GetDefaultValence(atomic_num),
                                 range(1, 7)))
        if rings:
            row.extend([atom.IsInRingSize(r) for r in range(3, 8)])
        return row

    per_atom = [_atom_row(i) for i in range(mol.GetNumAtoms())]
    return torch.Tensor(per_atom)
def __getitem__(self, idx):
    """
    Assemble the padded feature dictionary for one molecule.

    idx : index into self.mols / self.pred_vals / self.whole_records.
        When self.frac_per_epoch < 1.0 the passed value is ignored and a
        random index is drawn instead.

    Returns a dict with keys 'adj', 'vect_feat', 'mat_feat', 'mol_feat',
    'vals', 'adj_oh', 'pred_mask', 'coords', 'input_mask', 'input_idx',
    'edge_edge', 'edge_vert', 'edge_feat', plus one 'extra_data_{ei}'
    entry for every extra-npy config that has no 'combine_with'.
    Per-atom arrays are zero-padded up to self.MAX_N.

    When self.cache is not None the result is stored under
    self.cache_key(idx, conf_idx) and served from there on later calls.

    Raises NotImplementedError for an unsupported 'combine_with' value, or
    when self.spect_assign is false and self.PRED_N > 1.
    Raises AssertionError if any returned value is not finite.
    """
    if self.frac_per_epoch < 1.0:
        # randomly get an index each time
        idx = np.random.randint(len(self.mols))
    mol = self.mols[idx]
    pred_val = self.pred_vals[idx]
    whole_record = self.whole_records[idx]

    # only the first conformer is ever used here
    conf_idx = 0

    if self.cache is not None and self.cache_key(idx, conf_idx) in self.cache:
        return self.cache[self.cache_key(idx, conf_idx)]

    # mol features
    f_mol = molecule_features.whole_molecule_features(
        whole_record, **self.mol_args)

    f_vect = atom_features.feat_tensor_atom(mol,
                                            conf_idx=conf_idx,
                                            **self.feat_vert_args)
    if self.combine_mol_vect:
        # broadcast the whole-molecule features onto every atom row
        f_vect = torch.cat(
            [f_vect,
             f_mol.reshape(1, -1).expand(f_vect.shape[0], -1)], -1)

    # process extra data arguments
    for extra_data_config in self.extra_npy_filenames:
        filename = extra_data_config['filenames'][idx]
        combine_with = extra_data_config.get('combine_with', None)
        if combine_with == 'vert':
            # flatten to one row per atom and append as extra vertex columns
            npy_data = np.load(filename)
            npy_data_flatter = npy_data.reshape(f_vect.shape[0], -1)
            f_vect = torch.cat(
                [f_vect, torch.Tensor(npy_data_flatter)], dim=-1)
        elif combine_with is None:
            # handled after the result dict is built (see the loop near the end)
            continue
        else:
            raise NotImplementedError(
                f"the combinewith {combine_with} not working yet")

    # number of real (unpadded) atom rows
    DATA_N = f_vect.shape[0]

    vect_feat = np.zeros((self.MAX_N, f_vect.shape[1]), dtype=np.float32)
    vect_feat[:DATA_N] = f_vect

    f_mat = molecule_features.feat_tensor_mol(mol,
                                              conf_idx=conf_idx,
                                              **self.feat_edge_args)

    if self.combine_mat_vect:
        # reserve room for the vertex features after the matrix channels
        MAT_CHAN = f_mat.shape[2] + vect_feat.shape[1]
    else:
        MAT_CHAN = f_mat.shape[2]
    if MAT_CHAN == 0:  # Dataloader can't handle tensors with empty dimensions
        MAT_CHAN = 1
    mat_feat = np.zeros((self.MAX_N, self.MAX_N, MAT_CHAN), dtype=np.float32)
    # do the padding
    mat_feat[:DATA_N, :DATA_N, :f_mat.shape[2]] = f_mat

    if self.combine_mat_vect == 'row':
        # row-major
        for i in range(DATA_N):
            mat_feat[i, :DATA_N, f_mat.shape[2]:] = f_vect
    elif self.combine_mat_vect == 'col':
        # col-major
        for i in range(DATA_N):
            mat_feat[:DATA_N, i, f_mat.shape[2]:] = f_vect

    # zero-pad the last two axes of the adjacency to MAX_N x MAX_N
    adj_nopad = molecule_features.feat_mol_adj(mol, **self.adj_args)
    adj = torch.zeros((adj_nopad.shape[0], self.MAX_N, self.MAX_N))
    adj[:, :adj_nopad.shape[1], :adj_nopad.shape[2]] = adj_nopad

    if self.combine_mat_feat_adj:
        # stack the matrix features (channel axis moved first) onto adj
        adj = torch.cat([adj, torch.Tensor(mat_feat).permute(2, 0, 1)], 0)

    ### Simple one-hot encoding for reconstruction
    adj_oh_nopad = molecule_features.feat_mol_adj(
        mol,
        split_weights=[1.0, 1.5, 2.0, 3.0],
        edge_weighted=False,
        norm_adj=False,
        add_identity=False)

    adj_oh = torch.zeros((adj_oh_nopad.shape[0], self.MAX_N, self.MAX_N))
    adj_oh[:, :adj_oh_nopad.shape[1], :adj_oh_nopad.shape[2]] = adj_oh_nopad

    ## per-edge features
    feat_edge_dict = edge_features.feat_edges(mol, )

    # pad each of these
    # NOTE: the copy-in statements below are commented out, so edge_edge,
    # edge_feat and edge_vert are returned as all-zero tensors; the unpadded
    # arrays are only used for their shapes.
    edge_edge_nopad = feat_edge_dict['edge_edge']
    edge_edge = torch.zeros(
        (edge_edge_nopad.shape[0], self.MAX_N, self.MAX_N))

    # edge_edge[:, :edge_edge_nopad.shape[1],
    #              :edge_edge_nopad.shape[2]] = torch.Tensor(edge_edge_nopad)

    edge_feat_nopad = feat_edge_dict['edge_feat']
    edge_feat = torch.zeros((self.MAX_N, edge_feat_nopad.shape[1]))
    # edge_feat[:edge_feat_nopad.shape[0]] = torch.Tensor(edge_feat_nopad)

    edge_vert_nopad = feat_edge_dict['edge_vert']
    edge_vert = torch.zeros(
        (edge_vert_nopad.shape[0], self.MAX_N, self.MAX_N))
    # edge_vert[:, :edge_vert_nopad.shape[1],
    #              :edge_vert_nopad.shape[2]] = torch.Tensor(edge_vert_nopad)

    # atomicnos is not used below; only the coordinates are padded and returned
    atomicnos, coords = util.get_nos_coords(mol, conf_idx)
    coords_t = torch.zeros((self.MAX_N, 3))
    coords_t[:len(coords), :] = torch.Tensor(coords)

    # create mask and preds

    pred_mask = np.zeros((self.MAX_N, self.PRED_N), dtype=np.float32)
    # every entry starts as the "missing" sentinel
    vals = np.ones((self.MAX_N, self.PRED_N),
                   dtype=np.float32) * util.PERM_MISSING_VALUE
    #print(self.PRED_N, pred_val)
    if self.spect_assign:
        # pred_val[pn] is used as a mapping atom index -> observed value
        for pn in range(self.PRED_N):
            if len(pred_val) > 0:  # when empty, there's nothing to predict
                atom_idx = [int(k) for k in pred_val[pn].keys()]
                obs_vals = [pred_val[pn][i] for i in atom_idx]
                # if self.shuffle_observations:
                #     obs_vals = np.random.permutation(obs_vals)
                for k, v in zip(atom_idx, obs_vals):
                    pred_mask[k, pn] = 1.0
                    vals[k, pn] = v
    else:
        if self.PRED_N > 1:
            raise NotImplementedError()
        vals[:] = util.PERM_MISSING_VALUE  # sentinel value
        # pred_val[0][0] supplies the masked-in indices and pred_val[0][1]
        # the values, written row by row from the start of `vals`
        for k in pred_val[0][0]:
            pred_mask[k, 0] = 1
        for vi, v in enumerate(pred_val[0][1]):
            vals[vi] = v

    # input mask
    input_mask = torch.zeros(self.MAX_N)
    input_mask[:DATA_N] = 1.0

    v = {
        'adj': adj,
        'vect_feat': vect_feat,
        'mat_feat': mat_feat,
        'mol_feat': f_mol,
        'vals': vals,
        'adj_oh': adj_oh,
        'pred_mask': pred_mask,
        'coords': coords_t,
        'input_mask': input_mask,
        'input_idx': idx,
        'edge_edge': edge_edge,
        'edge_vert': edge_vert,
        'edge_feat': edge_feat
    }

    ## add on extra args
    for ei, extra_data_config in enumerate(self.extra_npy_filenames):
        filename = extra_data_config['filenames'][idx]
        combine_with = extra_data_config.get('combine_with', None)
        if combine_with is None:
            ## this is an extra arg
            npy_data = np.load(filename)

            ## Zero pad
            # only the first axis is padded, up to MAX_N
            npy_shape = list(npy_data.shape)
            npy_shape[0] = self.MAX_N
            t_pad = torch.zeros(npy_shape)
            t_pad[:npy_data.shape[0]] = torch.Tensor(npy_data)
            v[f'extra_data_{ei}'] = t_pad

    # sanity check: nothing handed to the loader may contain NaN / inf
    for k, kv in v.items():
        assert np.isfinite(kv).all()
    if self.cache is not None:
        self.cache[self.cache_key(idx, conf_idx)] = v

    return v
def feat_tensor_mol(mol, feat_distances=False, feat_r_pow=None,
                    mmff_opt_conf=False,
                    is_in_ring=False, is_in_ring_size=None,
                    MAX_POW_M=2.0, conf_idx=0,
                    add_identity=False,
                    edge_type_tuples=[], norm_mat=False, mat_power=1):
    """
    Return matrix (atom x atom) features for molecule.

    feat_distances : add absolute per-axis coordinate differences
        (3 channels)
    feat_r_pow : optional iterable of exponents; one channel per exponent
        holding (identity + euclidean distance) ** p, clipped at MAX_POW_M
    mmff_opt_conf : embed and MMFF-optimize `mol` first (modifies `mol`)
    is_in_ring : one channel marking bonded atom pairs whose bond is in a
        ring
    is_in_ring_size : optional iterable of ring sizes; one channel per
        size marking bonds that are in a ring of that size
    conf_idx : conformer index forwarded to get_nos_coords
    add_identity : add the identity matrix to every channel
    edge_type_tuples : pairs of atomic numbers; one channel per pair
        marking bonds joining exactly those elements (read-only here, so
        the shared list default is never modified)
    norm_mat : symmetrically normalize every channel by its column sums
    mat_power : with norm_mat, an int power applied to each normalized
        channel, or a list of powers producing one channel per power

    Returns a float32 torch tensor of shape (ATOM_N, ATOM_N, channels).

    Raises AssertionError if normalization meets a non-positive scale or
    the result contains non-finite values.
    """
    if mmff_opt_conf:
        Chem.AllChem.EmbedMolecule(mol)
        Chem.AllChem.MMFFOptimizeMolecule(mol)

    atomic_nos, coords = get_nos_coords(mol, conf_idx)
    ATOM_N = len(atomic_nos)

    def _bond_channel(keep):
        # symmetric 0/1 channel over the bonds accepted by `keep`
        chan = np.zeros((ATOM_N, ATOM_N, 1), dtype=np.float32)
        for bond in mol.GetBonds():
            if keep(bond):
                a_i = bond.GetBeginAtomIdx()
                a_j = bond.GetEndAtomIdx()
                chan[a_i, a_j] = 1
                chan[a_j, a_i] = 1
        return chan

    res_mats = []

    if feat_distances:
        # 1 x 3 x N minus its N x 3 x 1 transpose broadcasts to N x 3 x N
        a = coords.T.reshape(1, 3, -1)
        res_mats.append(np.swapaxes(np.abs(a - a.T), 2, 1))

    if feat_r_pow is not None:
        a = coords.T.reshape(1, 3, -1)
        sq = np.swapaxes((a - a.T) ** 2, 2, 1)
        dist = np.sqrt(np.sum(sq, axis=2))
        # the identity on the diagonal avoids 0 ** negative power
        e = (np.eye(dist.shape[0]) + dist)[:, :, np.newaxis]

        for p in feat_r_pow:
            # clipping is a no-op when nothing exceeds MAX_POW_M
            res_mats.append(np.minimum(e ** p, MAX_POW_M))

    if len(edge_type_tuples) > 0:
        a = np.zeros((ATOM_N, ATOM_N, len(edge_type_tuples)))
        # build each element-pair set once instead of once per bond
        edge_sets = [set(et) for et in edge_type_tuples]
        for bond in mol.GetBonds():
            a_i = bond.GetBeginAtomIdx()
            a_j = bond.GetEndAtomIdx()
            pair = set([atomic_nos[a_i], atomic_nos[a_j]])
            for et_i, et_set in enumerate(edge_sets):
                if et_set == pair:
                    a[a_i, a_j, et_i] = 1
                    a[a_j, a_i, et_i] = 1
        res_mats.append(a)

    if is_in_ring:
        # Only ring bonds are marked. Previously every bond was marked
        # regardless of ring membership, so the channel was just the
        # unweighted adjacency.
        res_mats.append(_bond_channel(lambda bond: bond.IsInRing()))

    if is_in_ring_size is not None:
        for rs in is_in_ring_size:
            res_mats.append(
                _bond_channel(lambda bond, rs=rs: bond.IsInRingSize(rs)))

    if len(res_mats) > 0:
        M = np.concatenate(res_mats, 2)
    else:  # Empty matrix
        M = np.zeros((ATOM_N, ATOM_N, 0), dtype=np.float32)

    # channels first for the per-channel operations below
    M = torch.Tensor(M).permute(2, 0, 1)

    if add_identity:
        M = M + torch.eye(ATOM_N).unsqueeze(0)

    if norm_mat:
        res = []
        for a in M:
            # D^-1/2 * A * D^-1/2 with D taken from the column sums
            D_12 = 1.0 / torch.sqrt(torch.sum(a, dim=0))
            assert np.min(D_12.numpy()) > 0
            s1 = D_12.reshape(ATOM_N, 1)
            s2 = D_12.reshape(1, ATOM_N)
            adj_i = s1 * a * s2

            if isinstance(mat_power, list):
                for p in mat_power:
                    res.append(torch.matrix_power(adj_i, p))
            else:
                if mat_power > 1:
                    adj_i = torch.matrix_power(adj_i, mat_power)
                res.append(adj_i)
        M = torch.stack(res, 0)

    assert torch.isfinite(M).all()
    # back to (atom, atom, channel)
    return M.permute(1, 2, 0)