def __getitem__(self, idx):
    """Assemble the graph-style sample for the molecule stored at ``idx``.

    A conformer index is drawn with ``np.random.randint`` on every call and
    forwarded to the atom featurizer, so repeated calls with the same
    ``idx`` may featurize different conformers.

    ``self.pred_vals[idx]`` is indexed by target number; each entry is
    iterated with ``.items()`` and its keys are cast to ``int`` and used as
    atom (row) indices.  According to the original comments the values are
    chemical-shift values keyed by atom number.

    Returns:
        The tuple ``(f_vect, atom_types, edge_index, edge_attr, mask,
        target)``.  The first two come from
        ``atom_features.feat_tensor_atom``, the next two from
        ``molecule_features.get_edge_attr_and_ind``.  ``mask`` and
        ``target`` are flattened ``torch.FloatTensor`` objects built from
        ``(num_atoms, 1)`` arrays; ``mask`` is 1.0 wherever a label was
        written and ``target`` holds the label (0.0 elsewhere).
    """
    molecule = self.mols[idx]
    n_atoms = molecule.GetNumAtoms()
    labels_by_target = self.pred_vals[idx]

    # Pick one conformer at random for this draw.
    conformer = np.random.randint(molecule.GetNumConformers())

    atom_feats, atom_types = atom_features.feat_tensor_atom(
        molecule, conf_idx=conformer, **self.feat_vert_args)
    edge_index, edge_attr = molecule_features.get_edge_attr_and_ind(molecule)

    label_shape = (n_atoms, 1)
    targets = np.zeros(label_shape, dtype=np.float32)
    observed = np.zeros(label_shape, dtype=np.float32)

    # NOTE: both arrays have exactly one column, so a label stored under a
    # target number greater than 0 raises IndexError on the write below.
    for column in range(self.PRED_N):
        for atom_key, value in labels_by_target[column].items():
            row = int(atom_key)
            targets[row, column] = value
            observed[row, column] = 1.0

    mask_t = torch.FloatTensor(observed).flatten()
    target_t = torch.FloatTensor(targets).flatten()
    return (atom_feats, atom_types, edge_index, edge_attr, mask_t, target_t)
def __getitem__(self, idx):
    """Return the zero-padded, single-target sample for molecule ``idx``.

    A conformer index is drawn with ``np.random.randint`` on every call.
    Samples are memoised in ``self.cache`` under
    ``self.cache_key(idx, conf_idx)``; both cached and freshly built samples
    are passed through ``self.mask_sel`` before being returned.

    Returns:
        ``self.mask_sel`` applied to the tuple
        ``(adj, vect_feat, mat_feat, vals, mask)`` where

        * ``adj`` is a torch tensor of shape ``(channels, MAX_N, MAX_N)``
          holding the output of ``molecule_features.feat_mol_adj`` in its
          top-left corner,
        * ``vect_feat`` is a float32 array ``(MAX_N, atom_channels)``,
        * ``mat_feat`` is a float32 array ``(MAX_N, MAX_N, channels)``,
        * ``vals`` / ``mask`` are float32 arrays of length ``MAX_N`` with
          the label value and 1.0 respectively at every key of
          ``self.pred_vals[idx]``.
    """
    molecule = self.mols[idx]
    labels = self.pred_vals[idx]
    conformer = np.random.randint(molecule.GetNumConformers())

    key = self.cache_key(idx, conformer)
    if key in self.cache:
        return self.mask_sel(self.cache[key])

    # --- per-atom features, padded along the atom axis ---
    atom_feats = atom_features.feat_tensor_atom(
        molecule, conf_idx=conformer, **self.feat_vert_args)
    n_atoms = atom_feats.shape[0]
    padded_atoms = np.zeros((self.MAX_N, atom_feats.shape[1]),
                            dtype=np.float32)
    padded_atoms[:n_atoms] = atom_feats

    # --- per-pair features, optionally widened to carry atom features ---
    pair_feats = molecule_features.feat_tensor_mol(
        molecule, conf_idx=conformer, **self.feat_edge_args)
    n_pair_chan = pair_feats.shape[2]
    if self.combine_mat_vect:
        n_chan = n_pair_chan + padded_atoms.shape[1]
    else:
        n_chan = n_pair_chan

    padded_pairs = np.zeros((self.MAX_N, self.MAX_N, n_chan),
                            dtype=np.float32)
    padded_pairs[:n_atoms, :n_atoms, :n_pair_chan] = pair_feats

    if self.combine_mat_vect == 'row':
        # Entry [i, j] receives the features of atom j for every row i
        # (plain broadcasting over the leading axis).
        padded_pairs[:n_atoms, :n_atoms, n_pair_chan:] = atom_feats
    elif self.combine_mat_vect == 'col':
        # Entry [i, j] receives the features of atom i for every column j.
        padded_pairs[:n_atoms, :n_atoms, n_pair_chan:] = \
            atom_feats[:, None, :]

    # --- adjacency, padded to MAX_N x MAX_N per channel ---
    adj_raw = molecule_features.feat_mol_adj(molecule, **self.adj_args)
    adj = torch.zeros((adj_raw.shape[0], self.MAX_N, self.MAX_N))
    adj[:, :adj_raw.shape[1], :adj_raw.shape[2]] = adj_raw

    # --- labels and the mask marking which atoms carry one ---
    mask = np.zeros(self.MAX_N, dtype=np.float32)
    vals = np.zeros(self.MAX_N, dtype=np.float32)
    for atom, value in labels.items():
        mask[atom] = 1.0
        vals[atom] = value

    sample = (adj, padded_atoms, padded_pairs, vals, mask)
    self.cache[key] = sample
    return self.mask_sel(sample)
def __getitem__(self, idx):
    """Return the padded sample with per-atom and per-atom-pair targets.

    A conformer index is drawn with ``np.random.randint`` on every call.
    Samples are memoised in ``self.cache`` under
    ``self.cache_key(idx, conf_idx)``; both cached and freshly built samples
    are passed through ``self.mask_sel`` before being returned.

    Label sources:
        * ``self.pred_vec_vals[idx][pn]`` is iterated with ``.items()``;
          keys are cast to ``int`` and used as atom indices.
        * ``self.pred_mat_vals[idx][pn]`` is iterated with ``.items()``;
          keys are unpacked as ``(k1, k2)`` pairs and each value is written
          symmetrically at ``[k1, k2]`` and ``[k2, k1]``.

    Returns:
        ``self.mask_sel`` applied to the tuple ``(adj, vect_feat, mat_feat,
        vect_vals, vect_mask, mat_vals, mat_mask)`` where

        * ``adj`` is a torch tensor with axes ``(MAX_N, MAX_N, channels)``,
        * ``vect_feat`` is float32 ``(MAX_N, atom_channels)``,
        * ``mat_feat`` is float32 ``(MAX_N, MAX_N, channels)`` with at least
          one channel,
        * ``vect_vals`` / ``vect_mask`` are float32 ``(MAX_N, PRED_N)``,
        * ``mat_vals`` / ``mat_mask`` are float32
          ``(MAX_N, MAX_N, PRED_N)``.
    """
    mol = self.mols[idx]
    vect_pred_val = self.pred_vec_vals[idx]
    mat_pred_val = self.pred_mat_vals[idx]

    conf_idx = np.random.randint(mol.GetNumConformers())

    # Compute the key once; it is needed for the lookup and the store.
    cache_key = self.cache_key(idx, conf_idx)
    if cache_key in self.cache:
        return self.mask_sel(self.cache[cache_key])

    # Per-atom features, zero-padded along the atom axis.
    f_vect = atom_features.feat_tensor_atom(mol, conf_idx=conf_idx,
                                            **self.feat_vert_args)
    DATA_N = f_vect.shape[0]
    vect_feat = np.zeros((self.MAX_N, f_vect.shape[1]), dtype=np.float32)
    vect_feat[:DATA_N] = f_vect

    # Per-pair features; channel count grows when atom features are merged.
    f_mat = molecule_features.feat_tensor_mol(mol, conf_idx=conf_idx,
                                              **self.feat_edge_args)
    pair_chan = f_mat.shape[2]
    if self.combine_mat_vect:
        MAT_CHAN = pair_chan + vect_feat.shape[1]
    else:
        MAT_CHAN = pair_chan
    if MAT_CHAN == 0:
        # Dataloader can't handle tensors with empty dimensions
        MAT_CHAN = 1

    mat_feat = np.zeros((self.MAX_N, self.MAX_N, MAT_CHAN),
                        dtype=np.float32)
    # do the padding
    mat_feat[:DATA_N, :DATA_N, :pair_chan] = f_mat

    if self.combine_mat_vect == 'row':
        # row-major: every row i gets the full per-atom feature table
        for i in range(DATA_N):
            mat_feat[i, :DATA_N, pair_chan:] = f_vect
    elif self.combine_mat_vect == 'col':
        # col-major: every column i gets the full per-atom feature table
        for i in range(DATA_N):
            mat_feat[:DATA_N, i, pair_chan:] = f_vect

    # Adjacency, padded to MAX_N x MAX_N for each channel.
    adj_nopad = molecule_features.feat_mol_adj(mol, **self.adj_args)
    adj = torch.zeros((adj_nopad.shape[0], self.MAX_N, self.MAX_N))
    adj[:, :adj_nopad.shape[1], :adj_nopad.shape[2]] = adj_nopad

    # Per-atom mask and targets.
    vect_mask = np.zeros((self.MAX_N, self.PRED_N), dtype=np.float32)
    vect_vals = np.zeros((self.MAX_N, self.PRED_N), dtype=np.float32)
    for pn in range(self.PRED_N):
        for k, v in vect_pred_val[pn].items():
            atom = int(k)
            vect_mask[atom, pn] = 1.0
            vect_vals[atom, pn] = v

    # Per-pair mask and targets, filled symmetrically.
    mat_mask = np.zeros((self.MAX_N, self.MAX_N, self.PRED_N),
                        dtype=np.float32)
    mat_vals = np.zeros((self.MAX_N, self.MAX_N, self.PRED_N),
                        dtype=np.float32)
    for pn in range(self.PRED_N):
        for (k1, k2), v in mat_pred_val[pn].items():
            a, b = int(k1), int(k2)
            mat_mask[a, b, pn] = 1.0
            mat_vals[a, b, pn] = v
            mat_mask[b, a, pn] = 1.0
            mat_vals[b, a, pn] = v

    # ADJ should be (N, N, features).  ``adj`` is a torch tensor (created
    # with torch.zeros above), so reorder its axes with the tensor's own
    # permute instead of routing it through a NumPy function.
    adj = adj.permute(1, 2, 0)

    v = (adj, vect_feat, mat_feat, vect_vals, vect_mask, mat_vals, mat_mask)
    self.cache[cache_key] = v
    return self.mask_sel(v)
def __getitem__(self, idx):
    """Build (or fetch from cache) the dictionary sample for one molecule.

    When ``self.frac_per_epoch`` is below 1.0 the requested ``idx`` is
    ignored and an index is drawn with ``np.random.randint``.  The conformer
    index is always 0.  If ``self.cache`` is not ``None`` the sample is
    memoised under ``self.cache_key(idx, 0)`` and a cache hit is returned
    as stored.

    Returns:
        A dict with the keys ``adj``, ``vect_feat``, ``mat_feat``,
        ``mol_feat``, ``vals``, ``adj_oh``, ``pred_mask``, ``coords``,
        ``input_mask``, ``input_idx``, ``edge_edge``, ``edge_vert``,
        ``edge_feat`` plus one ``extra_data_{i}`` entry for every
        ``self.extra_npy_filenames[i]`` whose ``combine_with`` is absent or
        ``None``.  ``input_idx`` is the index actually used.  The three
        ``edge_*`` tensors are allocated as zeros with shapes derived from
        ``edge_features.feat_edges``; nothing is copied into them here.

    Raises:
        NotImplementedError: for an unsupported ``combine_with`` value, or
            when ``self.spect_assign`` is falsy and ``self.PRED_N > 1``.
        AssertionError: if any value in the sample is not finite.
    """
    if self.frac_per_epoch < 1.0:
        # randomly get an index each time
        idx = np.random.randint(len(self.mols))

    molecule = self.mols[idx]
    labels = self.pred_vals[idx]
    record = self.whole_records[idx]
    conf_idx = 0

    cache = self.cache
    key = None if cache is None else self.cache_key(idx, conf_idx)
    if cache is not None and key in cache:
        return cache[key]

    # ---- whole-molecule and per-atom features ----
    mol_feat = molecule_features.whole_molecule_features(
        record, **self.mol_args)
    atom_feats = atom_features.feat_tensor_atom(
        molecule, conf_idx=conf_idx, **self.feat_vert_args)
    if self.combine_mol_vect:
        # Repeat the molecule-level vector once per atom and append it.
        tiled = mol_feat.reshape(1, -1).expand(atom_feats.shape[0], -1)
        atom_feats = torch.cat([atom_feats, tiled], -1)

    # ---- extra npy data that is merged into the per-atom features ----
    for extra_data_config in self.extra_npy_filenames:
        filename = extra_data_config['filenames'][idx]
        combine_with = extra_data_config.get('combine_with', None)
        if combine_with is None:
            # Handled further down as a stand-alone entry of the sample.
            continue
        if combine_with != 'vert':
            raise NotImplementedError(
                f"the combinewith {combine_with} not working yet")
        loaded = np.load(filename)
        per_atom = loaded.reshape(atom_feats.shape[0], -1)
        atom_feats = torch.cat([atom_feats, torch.Tensor(per_atom)], dim=-1)

    n_atoms = atom_feats.shape[0]
    vect_feat = np.zeros((self.MAX_N, atom_feats.shape[1]),
                         dtype=np.float32)
    vect_feat[:n_atoms] = atom_feats

    # ---- per-pair features ----
    pair_feats = molecule_features.feat_tensor_mol(
        molecule, conf_idx=conf_idx, **self.feat_edge_args)
    n_pair_chan = pair_feats.shape[2]
    if self.combine_mat_vect:
        n_chan = n_pair_chan + vect_feat.shape[1]
    else:
        n_chan = n_pair_chan
    # Dataloader can't handle tensors with empty dimensions
    n_chan = max(n_chan, 1)

    mat_feat = np.zeros((self.MAX_N, self.MAX_N, n_chan), dtype=np.float32)
    mat_feat[:n_atoms, :n_atoms, :n_pair_chan] = pair_feats

    if self.combine_mat_vect == 'row':
        # row-major
        for row in range(n_atoms):
            mat_feat[row, :n_atoms, n_pair_chan:] = atom_feats
    elif self.combine_mat_vect == 'col':
        # col-major
        for col in range(n_atoms):
            mat_feat[:n_atoms, col, n_pair_chan:] = atom_feats

    # ---- adjacency (optionally stacked with the pair features) ----
    adj_raw = molecule_features.feat_mol_adj(molecule, **self.adj_args)
    adj = torch.zeros((adj_raw.shape[0], self.MAX_N, self.MAX_N))
    adj[:, :adj_raw.shape[1], :adj_raw.shape[2]] = adj_raw
    if self.combine_mat_feat_adj:
        pair_channels_first = torch.Tensor(mat_feat).permute(2, 0, 1)
        adj = torch.cat([adj, pair_channels_first], 0)

    # Simple one-hot encoding for reconstruction
    onehot_raw = molecule_features.feat_mol_adj(
        molecule,
        split_weights=[1.0, 1.5, 2.0, 3.0],
        edge_weighted=False,
        norm_adj=False,
        add_identity=False)
    adj_oh = torch.zeros((onehot_raw.shape[0], self.MAX_N, self.MAX_N))
    adj_oh[:, :onehot_raw.shape[1], :onehot_raw.shape[2]] = onehot_raw

    # ---- per-edge features: only zero placeholders are produced ----
    edge_parts = edge_features.feat_edges(molecule)
    edge_edge = torch.zeros(
        (edge_parts['edge_edge'].shape[0], self.MAX_N, self.MAX_N))
    edge_feat = torch.zeros((self.MAX_N, edge_parts['edge_feat'].shape[1]))
    edge_vert = torch.zeros(
        (edge_parts['edge_vert'].shape[0], self.MAX_N, self.MAX_N))

    # ---- coordinates ----
    _, coords = util.get_nos_coords(molecule, conf_idx)
    coords_t = torch.zeros((self.MAX_N, 3))
    coords_t[:len(coords), :] = torch.Tensor(coords)

    # ---- prediction mask and target values ----
    pred_mask = np.zeros((self.MAX_N, self.PRED_N), dtype=np.float32)
    vals = np.ones((self.MAX_N, self.PRED_N), dtype=np.float32) \
        * util.PERM_MISSING_VALUE

    if self.spect_assign:
        for pn in range(self.PRED_N):
            if len(labels) == 0:
                # when empty, there's nothing to predict
                continue
            per_atom_labels = labels[pn]
            atom_ids = [int(k) for k in per_atom_labels.keys()]
            observed = [per_atom_labels[a] for a in atom_ids]
            for atom, value in zip(atom_ids, observed):
                pred_mask[atom, pn] = 1.0
                vals[atom, pn] = value
    else:
        if self.PRED_N > 1:
            raise NotImplementedError()
        vals[:] = util.PERM_MISSING_VALUE  # sentinel value
        # labels[0][0] lists the atoms to mask in; labels[0][1] lists values
        # that are written by position rather than by atom index.
        for atom in labels[0][0]:
            pred_mask[atom, 0] = 1
        for slot, value in enumerate(labels[0][1]):
            vals[slot] = value

    # ---- mask marking the real (non-padding) atoms ----
    input_mask = torch.zeros(self.MAX_N)
    input_mask[:n_atoms] = 1.0

    sample = {
        'adj': adj,
        'vect_feat': vect_feat,
        'mat_feat': mat_feat,
        'mol_feat': mol_feat,
        'vals': vals,
        'adj_oh': adj_oh,
        'pred_mask': pred_mask,
        'coords': coords_t,
        'input_mask': input_mask,
        'input_idx': idx,
        'edge_edge': edge_edge,
        'edge_vert': edge_vert,
        'edge_feat': edge_feat,
    }

    # ---- stand-alone extra npy data, zero-padded on the first axis ----
    for ei, extra_data_config in enumerate(self.extra_npy_filenames):
        filename = extra_data_config['filenames'][idx]
        if extra_data_config.get('combine_with', None) is not None:
            continue
        npy_data = np.load(filename)
        padded_shape = list(npy_data.shape)
        padded_shape[0] = self.MAX_N
        padded = torch.zeros(padded_shape)
        padded[:npy_data.shape[0]] = torch.Tensor(npy_data)
        sample[f'extra_data_{ei}'] = padded

    # Every entry must be finite before the sample is handed out.
    for entry in sample.values():
        assert np.isfinite(entry).all()

    if cache is not None:
        cache[key] = sample
    return sample