def forward(self, mol_batch, beta=0):
    batch_size = len(mol_batch)

    tree_mess, tree_vec, mol_vec = self.encode(mol_batch)

    tree_mean = self.T_mean(tree_vec)
    tree_log_var = -torch.abs(self.T_var(tree_vec))  # Following Mueller et al.
    mol_mean = self.G_mean(mol_vec)
    mol_log_var = -torch.abs(self.G_var(mol_vec))  # Following Mueller et al.

    z_mean = torch.cat([tree_mean, mol_mean], dim=1)
    z_log_var = torch.cat([tree_log_var, mol_log_var], dim=1)
    kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size

    # Reparameterization trick for both halves of the latent code
    epsilon = create_var(torch.randn(batch_size, int(self.latent_size / 2)), False)
    tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
    epsilon = create_var(torch.randn(batch_size, int(self.latent_size / 2)), False)
    mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon

    word_loss, topo_loss, word_acc, topo_acc = self.decoder(mol_batch, tree_vec)
    assm_loss, assm_acc = self.assm(mol_batch, mol_vec, tree_mess)
    stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)

    loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss
    return loss, kl_loss.item(), word_acc, topo_acc, assm_acc, stereo_acc
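A minimal sketch of how this training objective might be driven, assuming pre-built batches of MolTree objects and a caller-managed KL weight; the optimizer choice, learning rate, and helper name are illustrative assumptions, not part of the original code.

# Hypothetical training loop (illustration only): `batches` is assumed to be
# an iterable of lists of MolTree objects, and `beta` is annealed by the caller.
import torch.optim as optim

def train_epoch(model, batches, optimizer, beta):
    model.train()
    for mol_batch in batches:
        model.zero_grad()
        loss, kl_div, word_acc, topo_acc, assm_acc, stereo_acc = model(mol_batch, beta)
        loss.backward()
        optimizer.step()

# optimizer = optim.Adam(model.parameters(), lr=1e-3)  # lr is an assumption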
def stereo(self, mol_batch, mol_vec):
    stereo_cands, batch_idx = [], []
    labels = []
    for i, mol_tree in enumerate(mol_batch):
        cands = mol_tree.stereo_cands
        if len(cands) == 1:
            continue
        if mol_tree.smiles3D not in cands:
            cands.append(mol_tree.smiles3D)
        stereo_cands.extend(cands)
        batch_idx.extend([i] * len(cands))
        labels.append((cands.index(mol_tree.smiles3D), len(cands)))

    if len(labels) == 0:
        return create_var(torch.zeros(1)), 1.0

    batch_idx = create_var(torch.LongTensor(batch_idx))
    stereo_cands = self.mpn(mol2graph(stereo_cands))
    stereo_cands = self.G_mean(stereo_cands)
    stereo_labels = mol_vec.index_select(0, batch_idx)
    scores = torch.nn.CosineSimilarity()(stereo_cands, stereo_labels)

    st, acc = 0, 0
    all_loss = []
    for label, le in labels:
        cur_scores = scores.narrow(0, st, le)
        if cur_scores.data[label].item() >= cur_scores.max().item():
            acc += 1
        label = create_var(torch.LongTensor([label]))
        all_loss.append(self.stereo_loss(cur_scores.view(1, -1), label))
        st += le

    all_loss = sum(all_loss) / len(labels)
    return all_loss, acc * 1.0 / len(labels)
def sample_eval(self):
    # Draw 100 stochastic decodings from a single prior sample, for evaluation
    tree_vec = create_var(torch.randn(1, int(self.latent_size / 2)), False)
    mol_vec = create_var(torch.randn(1, int(self.latent_size / 2)), False)
    all_smiles = []
    for i in range(100):
        s = self.decode(tree_vec, mol_vec, prob_decode=True)
        all_smiles.append(s)
    return all_smiles
def forward(self, root_batch):
    orders = []
    for root in root_batch:
        order = get_prop_order(root)
        orders.append(order)

    h = {}
    max_depth = max([len(x) for x in orders])
    padding = create_var(torch.zeros(self.hidden_size), False)

    for t in range(max_depth):
        prop_list = []
        for order in orders:
            if t < len(order):
                prop_list.extend(order[t])

        cur_x = []
        cur_h_nei = []
        for node_x, node_y in prop_list:
            x, y = node_x.idx, node_y.idx
            cur_x.append(node_x.wid)

            # Gather incoming messages, excluding the one from the target node
            h_nei = []
            for node_z in node_x.neighbors:
                z = node_z.idx
                if z == y:
                    continue
                h_nei.append(h[(z, x)])

            pad_len = MAX_NB - len(h_nei)
            h_nei.extend([padding] * pad_len)
            cur_h_nei.extend(h_nei)

        cur_x = create_var(torch.LongTensor(cur_x))
        cur_x = self.embedding(cur_x)
        cur_h_nei = torch.cat(cur_h_nei, dim=0).view(-1, MAX_NB, self.hidden_size)

        new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
        for i, m in enumerate(prop_list):
            x, y = m[0].idx, m[1].idx
            h[(x, y)] = new_h[i]

    root_vecs = node_aggregate(root_batch, h, self.embedding, self.W)
    return h, root_vecs
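The GRU helper is not defined in this section. A sketch consistent with its call sites here (GRU(cur_x, cur_h_nei, W_z, W_r, U_r, W_h)), following the gated tree-message update described in the JT-VAE paper, might look as follows; treat it as an illustration, not the repo's verbatim helper.

# Sketch of a tree-GRU over padded neighbor messages (assumption: all W_*
# arguments are nn.Linear modules with the shapes implied by the call sites).
import torch

def GRU(x, h_nei, W_z, W_r, U_r, W_h):
    # x: (batch, hidden) clique embeddings; h_nei: (batch, MAX_NB, hidden)
    sum_h = h_nei.sum(dim=1)
    z = torch.sigmoid(W_z(torch.cat([x, sum_h], dim=1)))          # update gate
    r = torch.sigmoid(W_r(x).unsqueeze(1) + U_r(h_nei))           # per-neighbor reset gates
    sum_gated_h = (r * h_nei).sum(dim=1)
    pre_h = torch.tanh(W_h(torch.cat([x, sum_gated_h], dim=1)))   # candidate state
    return (1.0 - z) * sum_h + z * pre_h                          # new outgoing message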
def reconstruct(self, smiles, prob_decode=False):
    mol_tree = MolTree(smiles)
    mol_tree.recover()
    _, tree_vec, mol_vec = self.encode([mol_tree])

    tree_mean = self.T_mean(tree_vec)
    tree_log_var = -torch.abs(self.T_var(tree_vec))  # Following Mueller et al.
    mol_mean = self.G_mean(mol_vec)
    mol_log_var = -torch.abs(self.G_var(mol_vec))  # Following Mueller et al.

    epsilon = create_var(torch.randn(1, int(self.latent_size / 2)), False)
    tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
    epsilon = create_var(torch.randn(1, int(self.latent_size / 2)), False)
    mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
    return self.decode(tree_vec, mol_vec, prob_decode)
def assm(self, mol_batch, mol_vec, tree_mess):
    cands = []
    batch_idx = []
    for i, mol_tree in enumerate(mol_batch):
        for node in mol_tree.nodes:
            # A leaf node's attachment is determined by its neighboring node's attachment
            if node.is_leaf or len(node.cands) == 1:
                continue
            cands.extend([(cand, mol_tree.nodes, node) for cand in node.cand_mols])
            batch_idx.extend([i] * len(node.cands))

    cand_vec = self.jtmpn(cands, tree_mess)
    cand_vec = self.G_mean(cand_vec)

    batch_idx = create_var(torch.LongTensor(batch_idx))
    mol_vec = mol_vec.index_select(0, batch_idx)

    # Score each candidate attachment by its dot product with the latent vector
    mol_vec = mol_vec.view(-1, 1, int(self.latent_size / 2))
    cand_vec = cand_vec.view(-1, int(self.latent_size / 2), 1)
    scores = torch.bmm(mol_vec, cand_vec).squeeze()

    cnt, tot, acc = 0, 0, 0
    all_loss = []
    for i, mol_tree in enumerate(mol_batch):
        comp_nodes = [node for node in mol_tree.nodes if len(node.cands) > 1 and not node.is_leaf]
        cnt += len(comp_nodes)
        for node in comp_nodes:
            label = node.cands.index(node.label)
            ncand = len(node.cands)
            cur_score = scores.narrow(0, tot, ncand)
            tot += ncand
            if cur_score.data[label].item() >= cur_score.max().item():
                acc += 1
            label = create_var(torch.LongTensor([label]))
            all_loss.append(self.assm_loss(cur_score.view(1, -1), label))

    all_loss = sum(all_loss) / len(mol_batch)
    return all_loss, acc * 1.0 / cnt
def node_aggregate(nodes, h, embedding, W):
    x_idx = []
    h_nei = []
    hidden_size = embedding.embedding_dim
    padding = create_var(torch.zeros(hidden_size), False)

    for node_x in nodes:
        x_idx.append(node_x.wid)
        nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        pad_len = MAX_NB - len(nei)
        nei.extend([padding] * pad_len)
        h_nei.extend(nei)

    h_nei = torch.cat(h_nei, dim=0).view(-1, MAX_NB, hidden_size)
    sum_h_nei = h_nei.sum(dim=1)
    x_vec = create_var(torch.LongTensor(x_idx))
    x_vec = embedding(x_vec)
    node_vec = torch.cat([x_vec, sum_h_nei], dim=1)
    return nn.ReLU()(W(node_vec))
def recon_eval(self, smiles):
    mol_tree = MolTree(smiles)
    mol_tree.recover()
    _, tree_vec, mol_vec = self.encode([mol_tree])

    tree_mean = self.T_mean(tree_vec)
    tree_log_var = -torch.abs(self.T_var(tree_vec))  # Following Mueller et al.
    mol_mean = self.G_mean(mol_vec)
    mol_log_var = -torch.abs(self.G_var(mol_vec))  # Following Mueller et al.

    all_smiles = []
    for i in range(10):
        epsilon = create_var(torch.randn(1, int(self.latent_size / 2)), False)
        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
        epsilon = create_var(torch.randn(1, int(self.latent_size / 2)), False)
        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
        for j in range(10):
            new_smiles = self.decode(tree_vec, mol_vec, prob_decode=True)
            all_smiles.append(new_smiles)
    return all_smiles
def forward(self, cand_batch, tree_mess):
    fatoms, fbonds = [], []
    in_bonds, all_bonds = [], []
    mess_dict, all_mess = {}, [create_var(torch.zeros(self.hidden_size))]  # Ensure index 0 is vec(0)
    total_atoms = 0
    scope = []

    for e, vec in tree_mess.items():
        mess_dict[e] = len(all_mess)
        all_mess.append(vec)

    for mol, all_nodes, ctr_node in cand_batch:
        n_atoms = mol.GetNumAtoms()

        for atom in mol.GetAtoms():
            fatoms.append(atom_features(atom))
            in_bonds.append([])

        for bond in mol.GetBonds():
            a1 = bond.GetBeginAtom()
            a2 = bond.GetEndAtom()
            x = a1.GetIdx() + total_atoms
            y = a2.GetIdx() + total_atoms
            # Here x_nid, y_nid could be 0
            x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()
            x_bid = all_nodes[x_nid - 1].idx if x_nid > 0 else -1
            y_bid = all_nodes[y_nid - 1].idx if y_nid > 0 else -1

            bfeature = bond_features(bond)

            b = len(all_mess) + len(all_bonds)  # Bond idx offset by len(all_mess)
            all_bonds.append((x, y))
            fbonds.append(torch.cat([fatoms[x], bfeature], 0))
            in_bonds[y].append(b)

            b = len(all_mess) + len(all_bonds)
            all_bonds.append((y, x))
            fbonds.append(torch.cat([fatoms[y], bfeature], 0))
            in_bonds[x].append(b)

            # Splice in tree messages across cluster boundaries
            if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
                if (x_bid, y_bid) in mess_dict:
                    mess_idx = mess_dict[(x_bid, y_bid)]
                    in_bonds[y].append(mess_idx)
                if (y_bid, x_bid) in mess_dict:
                    mess_idx = mess_dict[(y_bid, x_bid)]
                    in_bonds[x].append(mess_idx)

        scope.append((total_atoms, n_atoms))
        total_atoms += n_atoms

    total_bonds = len(all_bonds)
    total_mess = len(all_mess)
    fatoms = torch.stack(fatoms, 0)
    fbonds = torch.stack(fbonds, 0)
    agraph = torch.zeros(total_atoms, MAX_NB).long()
    bgraph = torch.zeros(total_bonds, MAX_NB).long()
    tree_message = torch.stack(all_mess, dim=0)

    for a in range(total_atoms):
        for i, b in enumerate(in_bonds[a]):
            agraph[a, i] = b

    for b1 in range(total_bonds):
        x, y = all_bonds[b1]
        for i, b2 in enumerate(in_bonds[x]):
            # b2 is offset by len(all_mess)
            if b2 < total_mess or all_bonds[b2 - total_mess][0] != y:
                bgraph[b1, i] = b2

    fatoms = create_var(fatoms)
    fbonds = create_var(fbonds)
    agraph = create_var(agraph)
    bgraph = create_var(bgraph)

    binput = self.W_i(fbonds)
    graph_message = nn.ReLU()(binput)

    for i in range(self.depth - 1):
        message = torch.cat([tree_message, graph_message], dim=0)
        nei_message = index_select_ND(message, 0, bgraph)
        nei_message = nei_message.sum(dim=1)
        nei_message = self.W_h(nei_message)
        graph_message = nn.ReLU()(binput + nei_message)

    message = torch.cat([tree_message, graph_message], dim=0)
    nei_message = index_select_ND(message, 0, agraph)
    nei_message = nei_message.sum(dim=1)
    ainput = torch.cat([fatoms, nei_message], dim=1)
    atom_hiddens = nn.ReLU()(self.W_o(ainput))

    # Mean-pool atom hiddens per candidate
    mol_vecs = []
    for st, le in scope:
        mol_vec = atom_hiddens.narrow(0, st, le).sum(dim=0) / le
        mol_vecs.append(mol_vec)

    mol_vecs = torch.stack(mol_vecs, dim=0)
    return mol_vecs
def forward(self, mol_batch, mol_vec):
    super_root = MolTreeNode("")
    super_root.idx = -1

    # Initialize
    pred_hiddens, pred_mol_vecs, pred_targets = [], [], []
    stop_hiddens, stop_targets = [], []
    traces = []
    for mol_tree in mol_batch:
        s = []
        dfs(s, mol_tree.nodes[0], super_root)
        traces.append(s)
        for node in mol_tree.nodes:
            node.neighbors = []

    # Predict root
    pred_hiddens.append(create_var(torch.zeros(len(mol_batch), self.hidden_size)))
    pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch])
    pred_mol_vecs.append(mol_vec)

    max_iter = max([len(tr) for tr in traces])
    padding = create_var(torch.zeros(self.hidden_size), False)
    h = {}

    for t in range(max_iter):
        prop_list = []
        batch_list = []
        for i, plist in enumerate(traces):
            if t < len(plist):
                prop_list.append(plist[t])
                batch_list.append(i)

        cur_x = []
        cur_h_nei, cur_o_nei = [], []

        for node_x, real_y, _ in prop_list:
            # Neighbors for message passing (target not included)
            cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors if node_y.idx != real_y.idx]
            pad_len = MAX_NB - len(cur_nei)
            cur_h_nei.extend(cur_nei)
            cur_h_nei.extend([padding] * pad_len)

            # Neighbors for stop prediction (all neighbors)
            cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
            pad_len = MAX_NB - len(cur_nei)
            cur_o_nei.extend(cur_nei)
            cur_o_nei.extend([padding] * pad_len)

            # Current clique
            cur_x.append(node_x.wid)

        # Clique embedding
        cur_x = create_var(torch.LongTensor(cur_x))
        cur_x = self.embedding(cur_x)

        # Message passing
        cur_h_nei = torch.stack(cur_h_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
        new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)

        # Node aggregate
        cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
        cur_o = cur_o_nei.sum(dim=1)

        # Gather targets
        pred_target, pred_list = [], []
        stop_target = []
        for i, m in enumerate(prop_list):
            node_x, node_y, direction = m
            x, y = node_x.idx, node_y.idx
            h[(x, y)] = new_h[i]
            node_y.neighbors.append(node_x)
            if direction == 1:
                pred_target.append(node_y.wid)
                pred_list.append(i)
            stop_target.append(direction)

        # Hidden states for stop prediction
        cur_batch = create_var(torch.LongTensor(batch_list))
        cur_mol_vec = mol_vec.index_select(0, cur_batch)
        stop_hidden = torch.cat([cur_x, cur_o, cur_mol_vec], dim=1)
        stop_hiddens.append(stop_hidden)
        stop_targets.extend(stop_target)

        # Hidden states for clique prediction
        if len(pred_list) > 0:
            batch_list = [batch_list[i] for i in pred_list]
            cur_batch = create_var(torch.LongTensor(batch_list))
            pred_mol_vecs.append(mol_vec.index_select(0, cur_batch))

            cur_pred = create_var(torch.LongTensor(pred_list))
            pred_hiddens.append(new_h.index_select(0, cur_pred))
            pred_targets.extend(pred_target)

    # Last stop at root
    cur_x, cur_o_nei = [], []
    for mol_tree in mol_batch:
        node_x = mol_tree.nodes[0]
        cur_x.append(node_x.wid)
        cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        pad_len = MAX_NB - len(cur_nei)
        cur_o_nei.extend(cur_nei)
        cur_o_nei.extend([padding] * pad_len)

    cur_x = create_var(torch.LongTensor(cur_x))
    cur_x = self.embedding(cur_x)
    cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
    cur_o = cur_o_nei.sum(dim=1)

    stop_hidden = torch.cat([cur_x, cur_o, mol_vec], dim=1)
    stop_hiddens.append(stop_hidden)
    stop_targets.extend([0] * len(mol_batch))

    # Predict next clique
    pred_hiddens = torch.cat(pred_hiddens, dim=0)
    pred_mol_vecs = torch.cat(pred_mol_vecs, dim=0)
    pred_vecs = torch.cat([pred_hiddens, pred_mol_vecs], dim=1)
    pred_vecs = nn.ReLU()(self.W(pred_vecs))
    pred_scores = self.W_o(pred_vecs)
    pred_targets = create_var(torch.LongTensor(pred_targets))

    pred_loss = self.pred_loss(pred_scores, pred_targets) / len(mol_batch)
    _, preds = torch.max(pred_scores, dim=1)
    pred_acc = torch.eq(preds, pred_targets).float()
    pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

    # Predict stop
    stop_hiddens = torch.cat(stop_hiddens, dim=0)
    stop_vecs = nn.ReLU()(self.U(stop_hiddens))
    stop_scores = self.U_s(stop_vecs).squeeze()
    stop_targets = create_var(torch.Tensor(stop_targets))

    stop_loss = self.stop_loss(stop_scores, stop_targets) / len(mol_batch)
    stops = torch.ge(stop_scores, 0).float()
    stop_acc = torch.eq(stops, stop_targets).float()
    stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

    return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()
def decode(self, mol_vec, prob_decode):
    stack = []
    init_hidden = create_var(torch.zeros(1, self.hidden_size))
    zero_pad = create_var(torch.zeros(1, 1, self.hidden_size))

    # Root prediction
    root_hidden = torch.cat([init_hidden, mol_vec], dim=1)
    root_hidden = nn.ReLU()(self.W(root_hidden))
    root_score = self.W_o(root_hidden)
    _, root_wid = torch.max(root_score, dim=1)
    root_wid = root_wid.item()

    root = MolTreeNode(self.vocab.get_smiles(root_wid))
    root.wid = root_wid
    root.idx = 0
    stack.append((root, self.vocab.get_slots(root.wid)))

    all_nodes = [root]
    h = {}
    for step in range(MAX_DECODE_LEN):
        node_x, fa_slot = stack[-1]
        cur_h_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        if len(cur_h_nei) > 0:
            cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
        else:
            cur_h_nei = zero_pad

        cur_x = create_var(torch.LongTensor([node_x.wid]))
        cur_x = self.embedding(cur_x)

        # Predict stop
        cur_h = cur_h_nei.sum(dim=1)
        stop_hidden = torch.cat([cur_x, cur_h, mol_vec], dim=1)
        stop_hidden = nn.ReLU()(self.U(stop_hidden))
        stop_score = nn.Sigmoid()(self.U_s(stop_hidden) * 20).squeeze()

        if prob_decode:
            backtrack = (torch.bernoulli(1.0 - stop_score.data).item() == 1)
        else:
            backtrack = (stop_score.item() < 0.5)

        if not backtrack:  # Forward: predict next clique
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            pred_hidden = torch.cat([new_h, mol_vec], dim=1)
            pred_hidden = nn.ReLU()(self.W(pred_hidden))
            pred_score = nn.Softmax(dim=1)(self.W_o(pred_hidden) * 20)
            if prob_decode:
                sort_wid = torch.multinomial(pred_score.data.squeeze(), 5)
            else:
                _, sort_wid = torch.sort(pred_score, dim=1, descending=True)
                sort_wid = sort_wid.data.squeeze()
            sort_wid = sort_wid.cpu().numpy()

            # Try the top candidates until one is chemically assemblable
            next_wid = None
            for wid in sort_wid[:5]:
                slots = self.vocab.get_slots(wid)
                node_y = MolTreeNode(self.vocab.get_smiles(wid))
                if have_slots(fa_slot, slots) and can_assemble(node_x, node_y):
                    next_wid = wid
                    next_slots = slots
                    break

            if next_wid is None:
                backtrack = True  # No more children can be added
            else:
                node_y = MolTreeNode(self.vocab.get_smiles(next_wid))
                node_y.wid = next_wid
                node_y.idx = step + 1
                node_y.neighbors.append(node_x)
                h[(node_x.idx, node_y.idx)] = new_h[0]
                stack.append((node_y, next_slots))
                all_nodes.append(node_y)

        if backtrack:  # Backtrack; `if` rather than `else` so a failed forward step falls through
            if len(stack) == 1:
                break  # At root, terminate

            node_fa, _ = stack[-2]
            cur_h_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors if node_y.idx != node_fa.idx]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
            else:
                cur_h_nei = zero_pad

            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            h[(node_x.idx, node_fa.idx)] = new_h[0]
            node_fa.neighbors.append(node_x)
            stack.pop()

    return root, all_nodes
def sample_prior(self, prob_decode=False):
    tree_vec = create_var(torch.randn(1, int(self.latent_size / 2)), False)
    mol_vec = create_var(torch.randn(1, int(self.latent_size / 2)), False)
    return self.decode(tree_vec, mol_vec, prob_decode)
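A small usage sketch for prior sampling, assuming a trained JTNNVAE instance; the helper name, sample count, and None-guard are illustrative assumptions.

# Hypothetical sampling helper (illustration only): the VAE-level decode
# returns a SMILES string or None, so failed decodes are skipped here.
def sample_smiles(model, n_samples=100):
    samples = []
    for _ in range(n_samples):
        smiles = model.sample_prior(prob_decode=False)
        if smiles is not None:
            samples.append(smiles)
    return samples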
def optimize(self, smiles, sim_cutoff, lr=2.0, num_iter=20):
    mol_tree = MolTree(smiles)
    mol_tree.recover()
    _, tree_vec, mol_vec = self.encode([mol_tree])

    mol = Chem.MolFromSmiles(smiles)
    fp1 = AllChem.GetMorganFingerprint(mol, 2)

    tree_mean = self.T_mean(tree_vec)
    mol_mean = self.G_mean(mol_vec)
    mean = torch.cat([tree_mean, mol_mean], dim=1)

    # Gradient ascent on the property predictor in latent space
    cur_vec = create_var(mean.data, True)
    visited = []
    for step in range(num_iter):
        prop_val = self.propNN(cur_vec).squeeze()
        grad = torch.autograd.grad(prop_val, cur_vec)[0]
        cur_vec = cur_vec.data + lr * grad.data
        cur_vec = create_var(cur_vec, True)
        visited.append(cur_vec)

    # Binary search for the furthest visited vector that still decodes to a
    # molecule within the similarity cutoff
    l, r = 0, num_iter - 1
    while l < r - 1:
        mid = (l + r) // 2
        new_vec = visited[mid]
        tree_vec, mol_vec = torch.chunk(new_vec, 2, dim=1)
        new_smiles = self.decode(tree_vec, mol_vec, prob_decode=False)
        if new_smiles is None:
            r = mid - 1
            continue
        new_mol = Chem.MolFromSmiles(new_smiles)
        fp2 = AllChem.GetMorganFingerprint(new_mol, 2)
        sim = DataStructs.TanimotoSimilarity(fp1, fp2)
        if sim < sim_cutoff:
            r = mid - 1
        else:
            l = mid

    tree_vec, mol_vec = torch.chunk(visited[l], 2, dim=1)
    new_smiles = self.decode(tree_vec, mol_vec, prob_decode=False)
    if new_smiles is None:
        return smiles, 1.0
    new_mol = Chem.MolFromSmiles(new_smiles)
    fp2 = AllChem.GetMorganFingerprint(new_mol, 2)
    sim = DataStructs.TanimotoSimilarity(fp1, fp2)
    if sim >= sim_cutoff:
        return new_smiles, sim
    else:
        return smiles, 1.0
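For context, a hedged usage example of the latent-space optimizer above; the input SMILES and cutoff are placeholders, not values from the original code.

# Hypothetical call (illustration only): returns an improved SMILES and its
# Tanimoto similarity to the input, or the input itself (with sim 1.0) when
# no candidate within the cutoff is found.
new_smiles, sim = model.optimize('CCO', sim_cutoff=0.4)
print(new_smiles, sim)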