def __init__(self, vocab, hidden_size, latent_size, depth):
    super(JTPropVAE, self).__init__()
    self.vocab = vocab
    self.hidden_size = hidden_size
    self.latent_size = latent_size
    self.depth = depth

    self.embedding = nn.Embedding(vocab.size(), hidden_size)
    self.jtnn = JTNNEncoder(vocab, hidden_size, self.embedding)
    self.jtmpn = JTMPN(hidden_size, depth)
    self.mpn = MPN(hidden_size, depth)
    self.decoder = JTNNDecoder(vocab, hidden_size, latent_size // 2, self.embedding)

    # Half of the latent code describes the junction tree, the other half the molecular graph
    self.T_mean = nn.Linear(hidden_size, latent_size // 2)
    self.T_var = nn.Linear(hidden_size, latent_size // 2)
    self.G_mean = nn.Linear(hidden_size, latent_size // 2)
    self.G_var = nn.Linear(hidden_size, latent_size // 2)

    # Property predictor on top of the full (concatenated) latent vector
    self.propNN = nn.Sequential(
        nn.Linear(self.latent_size, self.hidden_size),
        nn.Tanh(),
        nn.Linear(self.hidden_size, 1))
    self.prop_loss = nn.MSELoss()
    self.assm_loss = nn.CrossEntropyLoss(reduction='sum')
    self.stereo_loss = nn.CrossEntropyLoss(reduction='sum')
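# Illustrative sketch, not part of the source: how the two latent halves produced by
# T_mean/G_mean feed the propNN property head. All tensors and sizes below are
# hypothetical placeholders chosen only for demonstration.
import torch
import torch.nn as nn

latent_size, hidden_size = 8, 16
propNN = nn.Sequential(nn.Linear(latent_size, hidden_size), nn.Tanh(), nn.Linear(hidden_size, 1))
tree_vec = torch.randn(3, latent_size // 2)   # stands in for T_mean(tree_hidden)
mol_vec = torch.randn(3, latent_size // 2)    # stands in for G_mean(mol_hidden)
prop_pred = propNN(torch.cat([tree_vec, mol_vec], dim=1)).squeeze(-1)
prop_loss = nn.MSELoss()(prop_pred, torch.randn(3))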
def __init__(self, vocab, hidden_size, latent_size, depthT, depthG):
    super(JTNNVAEMLP, self).__init__()
    self.vocab = vocab
    self.hidden_size = hidden_size
    self.latent_size = latent_size = latent_size // 2  # tree and graph each get half of the latent vector

    self.jtnn = JTNNEncoder(hidden_size, depthT, nn.Embedding(vocab.size(), hidden_size))
    self.decoder = JTNNDecoder(vocab, hidden_size, latent_size, nn.Embedding(vocab.size(), hidden_size))

    self.jtmpn = JTMPN(hidden_size, depthG)
    self.mpn = MPN(hidden_size, depthG)

    self.A_assm = nn.Linear(latent_size, hidden_size, bias=False)
    self.assm_loss = nn.CrossEntropyLoss(reduction='sum')

    self.T_mean = nn.Linear(hidden_size, latent_size)
    self.T_var = nn.Linear(hidden_size, latent_size)
    self.G_mean = nn.Linear(hidden_size, latent_size)
    self.G_var = nn.Linear(hidden_size, latent_size)
    self.T_hat_mean = nn.Linear(latent_size, latent_size)
    self.T_hat_var = nn.Linear(latent_size, latent_size)

    # New MLPs: project the 978-gene expression profile and fuse it with the tree latent
    self.gene_exp_size = 978
    self.gene_mlp = nn.Linear(self.gene_exp_size, latent_size)
    self.tree_mlp = nn.Linear(latent_size + latent_size, latent_size)
def __init__(self, vocab, args):
    super(DiffVAE, self).__init__()
    self.vocab = vocab
    self.hidden_size = hidden_size = args.hidden_size
    self.rand_size = rand_size = args.rand_size

    self.jtmpn = JTMPN(hidden_size, args.depthG)
    self.mpn = MPN(hidden_size, args.depthG)

    if args.share_embedding:
        self.embedding = nn.Embedding(vocab.size(), hidden_size)
        self.jtnn = JTNNEncoder(hidden_size, args.depthT, self.embedding)
        self.decoder = JTNNDecoder(vocab, hidden_size, self.embedding, args.use_molatt)
    else:
        self.jtnn = JTNNEncoder(hidden_size, args.depthT, nn.Embedding(vocab.size(), hidden_size))
        self.decoder = JTNNDecoder(vocab, hidden_size, nn.Embedding(vocab.size(), hidden_size), args.use_molatt)

    self.A_assm = nn.Linear(hidden_size, hidden_size, bias=False)
    self.assm_loss = nn.CrossEntropyLoss(reduction='sum')

    self.T_mean = nn.Linear(hidden_size, rand_size // 2)
    self.T_var = nn.Linear(hidden_size, rand_size // 2)
    self.G_mean = nn.Linear(hidden_size, rand_size // 2)
    self.G_var = nn.Linear(hidden_size, rand_size // 2)

    self.B_t = nn.Sequential(nn.Linear(hidden_size + rand_size // 2, hidden_size), nn.ReLU())
    self.B_g = nn.Sequential(nn.Linear(hidden_size + rand_size // 2, hidden_size), nn.ReLU())
def __init__(self, vocab, hidden_size, latent_size, depthT, depthG, loss_type='cos'):
    super(JTNNMJ, self).__init__()
    self.vocab = vocab
    self.hidden_size = hidden_size
    self.latent_size = latent_size = latent_size // 2  # tree and graph each get half of the latent vector

    self.jtnn = JTNNEncoder(hidden_size, depthT, nn.Embedding(vocab.size(), hidden_size))
    self.decoder = JTNNDecoder(vocab, hidden_size, latent_size, nn.Embedding(vocab.size(), hidden_size))

    self.jtmpn = JTMPN(hidden_size, depthG)
    self.mpn = MPN(hidden_size, depthG)

    self.A_assm = nn.Linear(latent_size, hidden_size, bias=False)
    self.assm_loss = nn.CrossEntropyLoss(reduction='sum')

    self.T_mean = nn.Linear(hidden_size, latent_size)
    self.T_var = nn.Linear(hidden_size, latent_size)
    self.G_mean = nn.Linear(hidden_size, latent_size)
    self.G_var = nn.Linear(hidden_size, latent_size)

    # For MJ: project the 978-gene expression profile into the model's hidden space
    self.gene_exp_size = 978
    self.gene_mlp = nn.Linear(self.gene_exp_size, 2 * hidden_size)
    self.gene_mlp2 = nn.Linear(2 * hidden_size, 2 * hidden_size)
    self.cos = nn.CosineSimilarity()

    self.loss_type = loss_type
    if loss_type == 'L1':
        self.cos_loss = torch.nn.L1Loss(reduction='mean')
    elif loss_type == 'L2':
        self.cos_loss = torch.nn.MSELoss(reduction='mean')
    elif loss_type == 'cos':
        self.cos_loss = torch.nn.CosineEmbeddingLoss()
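# Illustrative sketch, not part of the source: the three loss_type options above compared on a
# projected gene-expression vector versus a latent vector. Shapes are hypothetical placeholders.
import torch
import torch.nn as nn

gene_vec = torch.randn(4, 16)    # stands in for gene_mlp2(gene_mlp(gene_exp))
latent_vec = torch.randn(4, 16)  # stands in for the concatenated tree/graph latent
l1 = nn.L1Loss(reduction='mean')(gene_vec, latent_vec)
l2 = nn.MSELoss(reduction='mean')(gene_vec, latent_vec)
# CosineEmbeddingLoss needs a target of +1 (similar) or -1 (dissimilar) for each pair
cos = nn.CosineEmbeddingLoss()(gene_vec, latent_vec, torch.ones(4))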
def __init__(self, vocab, hidden_size, latent_size, depth):
    super(JTNNVAE, self).__init__()
    self.vocab = vocab
    self.hidden_size = int(hidden_size)
    self.latent_size = int(latent_size)
    self.depth = depth

    self.embedding = nn.Embedding(vocab.size(), hidden_size)
    self.jtnn = JTNNEncoder(vocab, hidden_size, self.embedding)
    self.jtmpn = JTMPN(hidden_size, depth)
    self.mpn = MPN(hidden_size, depth)
    self.decoder = JTNNDecoder(vocab, hidden_size, int(latent_size / 2), self.embedding)

    self.T_mean = nn.Linear(hidden_size, int(latent_size / 2))
    self.T_var = nn.Linear(hidden_size, int(latent_size / 2))
    self.G_mean = nn.Linear(hidden_size, int(latent_size / 2))
    self.G_var = nn.Linear(hidden_size, int(latent_size / 2))

    self.assm_loss = nn.CrossEntropyLoss(reduction='sum')
    self.stereo_loss = nn.CrossEntropyLoss(reduction='sum')
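# Illustrative sketch, not part of the source: the reparameterization and KL term that the VAE
# classes below apply to each latent half, written out as a standalone function. Sizes are
# hypothetical, chosen only for demonstration.
import torch
import torch.nn as nn

def reparameterize(h, w_mean, w_var):
    """Sample z ~ N(mu, sigma^2) with mu = w_mean(h) and log sigma^2 = -|w_var(h)|."""
    z_mean = w_mean(h)
    z_log_var = -torch.abs(w_var(h))  # following Mueller et al.: keeps the variance <= 1
    kl = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / h.size(0)
    eps = torch.randn_like(z_mean)
    return z_mean + torch.exp(z_log_var / 2) * eps, kl

hidden_size, half_latent = 8, 4
h = torch.randn(2, hidden_size)
z, kl = reparameterize(h, nn.Linear(hidden_size, half_latent), nn.Linear(hidden_size, half_latent))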
class JTNNVAE(nn.Module):

    def __init__(self, vocab, hidden_size, latent_size, depth, stereo=True):
        super(JTNNVAE, self).__init__()
        self.vocab = vocab
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.depth = depth

        self.embedding = nn.Embedding(vocab.size(), hidden_size)
        self.jtnn = JTNNEncoder(vocab, hidden_size, self.embedding)
        self.jtmpn = JTMPN(hidden_size, depth)
        self.mpn = MPN(hidden_size, depth)
        self.decoder = JTNNDecoder(vocab, hidden_size, latent_size // 2, self.embedding)

        self.T_mean = nn.Linear(hidden_size, latent_size // 2)
        self.T_var = nn.Linear(hidden_size, latent_size // 2)
        self.G_mean = nn.Linear(hidden_size, latent_size // 2)
        self.G_var = nn.Linear(hidden_size, latent_size // 2)

        self.assm_loss = nn.CrossEntropyLoss(reduction='sum')
        self.use_stereo = stereo
        if stereo:
            self.stereo_loss = nn.CrossEntropyLoss(reduction='sum')

    def encode(self, mol_batch):
        set_batch_nodeID(mol_batch, self.vocab)
        root_batch = [mol_tree.nodes[0] for mol_tree in mol_batch]
        tree_mess, tree_vec = self.jtnn(root_batch)

        smiles_batch = [mol_tree.smiles for mol_tree in mol_batch]
        mol_vec = self.mpn(mol2graph(smiles_batch))
        return tree_mess, tree_vec, mol_vec

    def encode_latent_mean(self, smiles_list):
        mol_batch = [MolTree(s) for s in smiles_list]
        for mol_tree in mol_batch:
            mol_tree.recover()

        _, tree_vec, mol_vec = self.encode(mol_batch)
        tree_mean = self.T_mean(tree_vec)
        mol_mean = self.G_mean(mol_vec)
        return torch.cat([tree_mean, mol_mean], dim=1)

    def forward(self, mol_batch, beta=0):
        batch_size = len(mol_batch)
        tree_mess, tree_vec, mol_vec = self.encode(mol_batch)

        tree_mean = self.T_mean(tree_vec)
        tree_log_var = -torch.abs(self.T_var(tree_vec))  # Following Mueller et al.
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(self.G_var(mol_vec))  # Following Mueller et al.

        z_mean = torch.cat([tree_mean, mol_mean], dim=1)
        z_log_var = torch.cat([tree_log_var, mol_log_var], dim=1)
        kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size

        epsilon = create_var(torch.randn(batch_size, self.latent_size // 2), False)
        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
        epsilon = create_var(torch.randn(batch_size, self.latent_size // 2), False)
        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon

        word_loss, topo_loss, word_acc, topo_acc = self.decoder(mol_batch, tree_vec)
        assm_loss, assm_acc = self.assm(mol_batch, mol_vec, tree_mess)
        if self.use_stereo:
            stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)
        else:
            stereo_loss, stereo_acc = 0, 0

        all_vec = torch.cat([tree_vec, mol_vec], dim=1)
        loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss
        return loss, kl_loss.item(), word_acc, topo_acc, assm_acc, stereo_acc

    def assm(self, mol_batch, mol_vec, tree_mess):
        cands = []
        batch_idx = []
        for i, mol_tree in enumerate(mol_batch):
            for node in mol_tree.nodes:
                # A leaf node's attachment is determined by its neighboring node's attachment
                if node.is_leaf or len(node.cands) == 1:
                    continue
                cands.extend([(cand, mol_tree.nodes, node) for cand in node.cand_mols])
                batch_idx.extend([i] * len(node.cands))

        cand_vec = self.jtmpn(cands, tree_mess)
        cand_vec = self.G_mean(cand_vec)

        batch_idx = create_var(torch.LongTensor(batch_idx))
        mol_vec = mol_vec.index_select(0, batch_idx)

        mol_vec = mol_vec.view(-1, 1, self.latent_size // 2)
        cand_vec = cand_vec.view(-1, self.latent_size // 2, 1)
        scores = torch.bmm(mol_vec, cand_vec).squeeze()

        cnt, tot, acc = 0, 0, 0
        all_loss = []
        for i, mol_tree in enumerate(mol_batch):
            comp_nodes = [node for node in mol_tree.nodes if len(node.cands) > 1 and not node.is_leaf]
            cnt += len(comp_nodes)
            for node in comp_nodes:
                label = node.cands.index(node.label)
                ncand = len(node.cands)
                cur_score = scores.narrow(0, tot, ncand)
                tot += ncand

                if cur_score[label].item() >= cur_score.max().item():
                    acc += 1
                label = create_var(torch.LongTensor([label]))
                all_loss.append(self.assm_loss(cur_score.view(1, -1), label))

        #all_loss = torch.stack(all_loss).sum() / len(mol_batch)
        all_loss = sum(all_loss) / len(mol_batch)
        return all_loss, acc * 1.0 / cnt

    def stereo(self, mol_batch, mol_vec):
        stereo_cands, batch_idx = [], []
        labels = []
        for i, mol_tree in enumerate(mol_batch):
            cands = mol_tree.stereo_cands
            if len(cands) == 1:
                continue
            if mol_tree.smiles3D not in cands:
                cands.append(mol_tree.smiles3D)
            stereo_cands.extend(cands)
            batch_idx.extend([i] * len(cands))
            labels.append((cands.index(mol_tree.smiles3D), len(cands)))

        if len(labels) == 0:
            return create_var(torch.zeros(1)), 1.0

        batch_idx = create_var(torch.LongTensor(batch_idx))
        stereo_cands = self.mpn(mol2graph(stereo_cands))
        stereo_cands = self.G_mean(stereo_cands)
        stereo_labels = mol_vec.index_select(0, batch_idx)
        scores = torch.nn.CosineSimilarity()(stereo_cands, stereo_labels)

        st, acc = 0, 0
        all_loss = []
        for label, le in labels:
            cur_scores = scores.narrow(0, st, le)
            if cur_scores[label].item() >= cur_scores.max().item():
                acc += 1
            label = create_var(torch.LongTensor([label]))
            all_loss.append(self.stereo_loss(cur_scores.view(1, -1), label))
            st += le

        #all_loss = torch.cat(all_loss).sum() / len(labels)
        all_loss = sum(all_loss) / len(labels)
        return all_loss, acc * 1.0 / len(labels)

    def reconstruct(self, smiles, prob_decode=False):
        mol_tree = MolTree(smiles)
        mol_tree.recover()
        _, tree_vec, mol_vec = self.encode([mol_tree])

        tree_mean = self.T_mean(tree_vec)
        tree_log_var = -torch.abs(self.T_var(tree_vec))  # Following Mueller et al.
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(self.G_var(mol_vec))  # Following Mueller et al.

        epsilon = create_var(torch.randn(1, self.latent_size // 2), False)
        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
        epsilon = create_var(torch.randn(1, self.latent_size // 2), False)
        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
        return self.decode(tree_vec, mol_vec, prob_decode)

    def recon_eval(self, smiles):
        mol_tree = MolTree(smiles)
        mol_tree.recover()
        _, tree_vec, mol_vec = self.encode([mol_tree])

        tree_mean = self.T_mean(tree_vec)
        tree_log_var = -torch.abs(self.T_var(tree_vec))  # Following Mueller et al.
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(self.G_var(mol_vec))  # Following Mueller et al.

        all_smiles = []
        for i in range(10):
            epsilon = create_var(torch.randn(1, self.latent_size // 2), False)
            tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
            epsilon = create_var(torch.randn(1, self.latent_size // 2), False)
            mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
            for j in range(10):
                new_smiles = self.decode(tree_vec, mol_vec, prob_decode=True)
                all_smiles.append(new_smiles)
        return all_smiles

    def sample_prior(self, prob_decode=False):
        tree_vec = create_var(torch.randn(1, self.latent_size // 2), False)
        mol_vec = create_var(torch.randn(1, self.latent_size // 2), False)
        return self.decode(tree_vec, mol_vec, prob_decode)

    def sample_eval(self):
        tree_vec = create_var(torch.randn(1, self.latent_size // 2), False)
        mol_vec = create_var(torch.randn(1, self.latent_size // 2), False)
        all_smiles = []
        for i in range(100):
            s = self.decode(tree_vec, mol_vec, prob_decode=True)
            all_smiles.append(s)
        return all_smiles

    def decode(self, tree_vec, mol_vec, prob_decode):
        pred_root, pred_nodes = self.decoder.decode(tree_vec, prob_decode)

        # Mark nid & is_leaf & atommap
        for i, node in enumerate(pred_nodes):
            node.nid = i + 1
            node.is_leaf = (len(node.neighbors) == 1)
            if len(node.neighbors) > 1:
                set_atommap(node.mol, node.nid)

        tree_mess = self.jtnn([pred_root])[0]

        cur_mol = copy_edit_mol(pred_root.mol)
        global_amap = [{}] + [{} for node in pred_nodes]
        global_amap[1] = {atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()}

        cur_mol = self.dfs_assemble(tree_mess, mol_vec, pred_nodes, cur_mol, global_amap, [], pred_root, None, prob_decode)
        if cur_mol is None:
            return None

        cur_mol = cur_mol.GetMol()
        set_atommap(cur_mol)
        cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
        if cur_mol is None:
            return None
        if not self.use_stereo:
            return Chem.MolToSmiles(cur_mol)

        smiles2D = Chem.MolToSmiles(cur_mol)
        stereo_cands = decode_stereo(smiles2D)
        if len(stereo_cands) == 1:
            return stereo_cands[0]
        stereo_vecs = self.mpn(mol2graph(stereo_cands))
        stereo_vecs = self.G_mean(stereo_vecs)
        scores = nn.CosineSimilarity()(stereo_vecs, mol_vec)
        _, max_id = scores.max(dim=0)
        return stereo_cands[max_id.item()]

    def dfs_assemble(self, tree_mess, mol_vec, all_nodes, cur_mol, global_amap, fa_amap, cur_node, fa_node, prob_decode):
        fa_nid = fa_node.nid if fa_node is not None else -1
        prev_nodes = [fa_node] if fa_node is not None else []

        children = [nei for nei in cur_node.neighbors if nei.nid != fa_nid]
        neighbors = [nei for nei in children if nei.mol.GetNumAtoms() > 1]
        neighbors = sorted(neighbors, key=lambda x: x.mol.GetNumAtoms(), reverse=True)
        singletons = [nei for nei in children if nei.mol.GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap if nid == cur_node.nid]
        cands = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap)
        if len(cands) == 0:
            return None
        cand_smiles, cand_mols, cand_amap = zip(*cands)

        cands = [(candmol, all_nodes, cur_node) for candmol in cand_mols]

        cand_vecs = self.jtmpn(cands, tree_mess)
        cand_vecs = self.G_mean(cand_vecs)
        mol_vec = mol_vec.squeeze()
        scores = torch.mv(cand_vecs, mol_vec) * 20

        if prob_decode:
            probs = nn.Softmax(dim=1)(scores.view(1, -1)).squeeze() + 1e-5  # prevent prob = 0
            cand_idx = torch.multinomial(probs, probs.numel())
        else:
            _, cand_idx = torch.sort(scores, descending=True)

        backup_mol = Chem.RWMol(cur_mol)
        for i in range(cand_idx.numel()):
            cur_mol = Chem.RWMol(backup_mol)
            pred_amap = cand_amap[cand_idx[i].item()]
            new_global_amap = copy.deepcopy(global_amap)

            for nei_id, ctr_atom, nei_atom in pred_amap:
                if nei_id == fa_nid:
                    continue
                new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node.nid][ctr_atom]

            cur_mol = attach_mols(cur_mol, children, [], new_global_amap)  # father is already attached
            new_mol = cur_mol.GetMol()
            new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))

            if new_mol is None:
                continue

            result = True
            for nei_node in children:
                if nei_node.is_leaf:
                    continue
                cur_mol = self.dfs_assemble(tree_mess, mol_vec, all_nodes, cur_mol, new_global_amap, pred_amap, nei_node, cur_node, prob_decode)
                if cur_mol is None:
                    result = False
                    break
            if result:
                return cur_mol

        return None
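# Illustrative usage sketch, not part of the source. It assumes the surrounding jtnn package
# exposes a Vocab class and that a cluster-vocabulary file and checkpoint exist; the file names
# and hyperparameters below are placeholders, not the authors' settings.
#
#   vocab = Vocab([line.strip() for line in open('vocab.txt')])
#   model = JTNNVAE(vocab, hidden_size=450, latent_size=56, depth=3).cuda()
#   model.load_state_dict(torch.load('model.ckpt'))
#   print(model.reconstruct('CCO'))   # encode one SMILES and decode it back
#   print(model.sample_prior())       # decode a molecule from a N(0, I) latent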
class JTNNVAE(nn.Module):

    def __init__(self, vocab, hidden_size, latent_size, depthT, depthG):
        super(JTNNVAE, self).__init__()
        self.vocab = vocab
        self.hidden_size = hidden_size
        self.latent_size = latent_size = latent_size // 2  # tree and graph each get half of the latent vector

        self.jtnn = JTNNEncoder(hidden_size, depthT, nn.Embedding(vocab.size(), hidden_size))
        self.decoder = JTNNDecoder(vocab, hidden_size, latent_size, nn.Embedding(vocab.size(), hidden_size))

        self.jtmpn = JTMPN(hidden_size, depthG)
        self.mpn = MPN(hidden_size, depthG)

        self.A_assm = nn.Linear(latent_size, hidden_size, bias=False)
        self.assm_loss = nn.CrossEntropyLoss(reduction='sum')

        self.T_mean = nn.Linear(hidden_size, latent_size)
        self.T_var = nn.Linear(hidden_size, latent_size)
        self.G_mean = nn.Linear(hidden_size, latent_size)
        self.G_var = nn.Linear(hidden_size, latent_size)

    def encode(self, jtenc_holder, mpn_holder):
        tree_vecs, tree_mess = self.jtnn(*jtenc_holder)
        mol_vecs = self.mpn(*mpn_holder)
        return tree_vecs, tree_mess, mol_vecs

    def encode_from_smiles(self, smiles_list):
        tree_batch = [MolTree(s) for s in smiles_list]
        _, jtenc_holder, mpn_holder = tensorize(tree_batch, self.vocab, assm=False)
        tree_vecs, _, mol_vecs = self.encode(jtenc_holder, mpn_holder)
        return torch.cat([tree_vecs, mol_vecs], dim=-1)

    def encode_latent(self, jtenc_holder, mpn_holder):
        tree_vecs, _ = self.jtnn(*jtenc_holder)
        mol_vecs = self.mpn(*mpn_holder)
        tree_mean = self.T_mean(tree_vecs)
        mol_mean = self.G_mean(mol_vecs)
        tree_var = -torch.abs(self.T_var(tree_vecs))
        mol_var = -torch.abs(self.G_var(mol_vecs))
        return torch.cat([tree_mean, mol_mean], dim=1), torch.cat([tree_var, mol_var], dim=1)

    def rsample(self, z_vecs, W_mean, W_var):
        batch_size = z_vecs.size(0)
        z_mean = W_mean(z_vecs)
        z_log_var = -torch.abs(W_var(z_vecs))  # Following Mueller et al.
        kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size
        epsilon = create_var(torch.randn_like(z_mean))
        z_vecs = z_mean + torch.exp(z_log_var / 2) * epsilon
        return z_vecs, kl_loss

    def sample_prior(self, prob_decode=False):
        z_tree = torch.randn(1, self.latent_size).cuda()
        z_mol = torch.randn(1, self.latent_size).cuda()
        return self.decode(z_tree, z_mol, prob_decode)

    def forward(self, x_batch, beta):
        x_batch, x_jtenc_holder, x_mpn_holder, x_jtmpn_holder = x_batch
        x_tree_vecs, x_tree_mess, x_mol_vecs = self.encode(x_jtenc_holder, x_mpn_holder)
        z_tree_vecs, tree_kl = self.rsample(x_tree_vecs, self.T_mean, self.T_var)
        z_mol_vecs, mol_kl = self.rsample(x_mol_vecs, self.G_mean, self.G_var)

        kl_div = tree_kl + mol_kl
        word_loss, topo_loss, word_acc, topo_acc = self.decoder(x_batch, z_tree_vecs)
        assm_loss, assm_acc = self.assm(x_batch, x_jtmpn_holder, z_mol_vecs, x_tree_mess)

        return word_loss + topo_loss + assm_loss + beta * kl_div, kl_div.item(), word_acc, topo_acc, assm_acc

    def assm(self, mol_batch, jtmpn_holder, x_mol_vecs, x_tree_mess):
        jtmpn_holder, batch_idx = jtmpn_holder
        fatoms, fbonds, agraph, bgraph, scope = jtmpn_holder
        batch_idx = create_var(batch_idx)

        cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, x_tree_mess)

        x_mol_vecs = x_mol_vecs.index_select(0, batch_idx)
        x_mol_vecs = self.A_assm(x_mol_vecs)  # bilinear
        scores = torch.bmm(x_mol_vecs.unsqueeze(1), cand_vecs.unsqueeze(-1)).squeeze()

        cnt, tot, acc = 0, 0, 0
        all_loss = []
        for i, mol_tree in enumerate(mol_batch):
            comp_nodes = [node for node in mol_tree.nodes if len(node.cands) > 1 and not node.is_leaf]
            cnt += len(comp_nodes)
            for node in comp_nodes:
                label = node.cands.index(node.label)
                ncand = len(node.cands)
                cur_score = scores.narrow(0, tot, ncand)
                tot += ncand

                if cur_score[label].item() >= cur_score.max().item():
                    acc += 1
                label = create_var(torch.LongTensor([label]))
                all_loss.append(self.assm_loss(cur_score.view(1, -1), label))

        all_loss = sum(all_loss) / len(mol_batch)
        return all_loss, acc * 1.0 / cnt

    def decode(self, x_tree_vecs, x_mol_vecs, prob_decode):
        # batch decoding is currently not supported
        assert x_tree_vecs.size(0) == 1 and x_mol_vecs.size(0) == 1

        pred_root, pred_nodes = self.decoder.decode(x_tree_vecs, prob_decode)
        if len(pred_nodes) == 0:
            return None
        elif len(pred_nodes) == 1:
            return pred_root.smiles

        # Mark nid & is_leaf & atommap
        for i, node in enumerate(pred_nodes):
            node.nid = i + 1
            node.is_leaf = (len(node.neighbors) == 1)
            if len(node.neighbors) > 1:
                set_atommap(node.mol, node.nid)

        scope = [(0, len(pred_nodes))]
        jtenc_holder, mess_dict = JTNNEncoder.tensorize_nodes(pred_nodes, scope)
        _, tree_mess = self.jtnn(*jtenc_holder)
        tree_mess = (tree_mess, mess_dict)  # Important: tree_mess is a matrix, mess_dict is a python dict

        x_mol_vecs = self.A_assm(x_mol_vecs).squeeze()  # bilinear

        cur_mol = copy_edit_mol(pred_root.mol)
        global_amap = [{}] + [{} for node in pred_nodes]
        global_amap[1] = {atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()}

        cur_mol, _ = self.dfs_assemble(tree_mess, x_mol_vecs, pred_nodes, cur_mol, global_amap, [], pred_root, None, prob_decode, check_aroma=True)
        if cur_mol is None:
            cur_mol = copy_edit_mol(pred_root.mol)
            global_amap = [{}] + [{} for node in pred_nodes]
            global_amap[1] = {atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()}
            cur_mol, pre_mol = self.dfs_assemble(tree_mess, x_mol_vecs, pred_nodes, cur_mol, global_amap, [], pred_root, None, prob_decode, check_aroma=False)
            if cur_mol is None:
                cur_mol = pre_mol

        if cur_mol is None:
            return None

        cur_mol = cur_mol.GetMol()
        set_atommap(cur_mol)
        cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
        return Chem.MolToSmiles(cur_mol) if cur_mol is not None else None

    def dfs_assemble(self, y_tree_mess, x_mol_vecs, all_nodes, cur_mol, global_amap, fa_amap, cur_node, fa_node, prob_decode, check_aroma):
        fa_nid = fa_node.nid if fa_node is not None else -1
        prev_nodes = [fa_node] if fa_node is not None else []

        children = [nei for nei in cur_node.neighbors if nei.nid != fa_nid]
        neighbors = [nei for nei in children if nei.mol.GetNumAtoms() > 1]
        neighbors = sorted(neighbors, key=lambda x: x.mol.GetNumAtoms(), reverse=True)
        singletons = [nei for nei in children if nei.mol.GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap if nid == cur_node.nid]
        cands, aroma_score = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap)
        if len(cands) == 0 or (sum(aroma_score) < 0 and check_aroma):
            return None, cur_mol

        cand_smiles, cand_amap = zip(*cands)
        aroma_score = torch.Tensor(aroma_score).cuda()
        cands = [(smiles, all_nodes, cur_node) for smiles in cand_smiles]

        if len(cands) > 1:
            jtmpn_holder = JTMPN.tensorize(cands, y_tree_mess[1])
            fatoms, fbonds, agraph, bgraph, scope = jtmpn_holder
            cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, y_tree_mess[0])
            scores = torch.mv(cand_vecs, x_mol_vecs) + aroma_score
        else:
            scores = torch.Tensor([1.0])

        if prob_decode:
            probs = F.softmax(scores.view(1, -1), dim=1).squeeze() + 1e-7  # prevent prob = 0
            cand_idx = torch.multinomial(probs, probs.numel())
        else:
            _, cand_idx = torch.sort(scores, descending=True)

        backup_mol = Chem.RWMol(cur_mol)
        pre_mol = cur_mol
        for i in range(cand_idx.numel()):
            cur_mol = Chem.RWMol(backup_mol)
            pred_amap = cand_amap[cand_idx[i].item()]
            new_global_amap = copy.deepcopy(global_amap)

            for nei_id, ctr_atom, nei_atom in pred_amap:
                if nei_id == fa_nid:
                    continue
                new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node.nid][ctr_atom]

            cur_mol = attach_mols(cur_mol, children, [], new_global_amap)  # father is already attached
            new_mol = cur_mol.GetMol()
            new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))

            if new_mol is None:
                continue

            has_error = False
            for nei_node in children:
                if nei_node.is_leaf:
                    continue
                tmp_mol, tmp_mol2 = self.dfs_assemble(y_tree_mess, x_mol_vecs, all_nodes, cur_mol, new_global_amap, pred_amap, nei_node, cur_node, prob_decode, check_aroma)
                if tmp_mol is None:
                    has_error = True
                    if i == 0:
                        pre_mol = tmp_mol2
                    break
                cur_mol = tmp_mol

            if not has_error:
                return cur_mol, cur_mol

        return None, pre_mol
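# Illustrative training-step sketch, not part of the source: one gradient step against the
# forward(x_batch, beta) signature above. The data loader is assumed to yield batches already
# tensorized into (batch, jtenc_holder, mpn_holder, jtmpn_holder); the optimizer, learning rate,
# and beta value are placeholders.
#
#   model = JTNNVAE(vocab, hidden_size=450, latent_size=56, depthT=20, depthG=3).cuda()
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   for batch in loader:
#       model.zero_grad()
#       loss, kl_div, wacc, tacc, aacc = model(batch, beta=0.005)
#       loss.backward()
#       optimizer.step()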
class DiffVAE(nn.Module):

    def __init__(self, vocab, args):
        super(DiffVAE, self).__init__()
        self.vocab = vocab
        self.hidden_size = hidden_size = args.hidden_size
        self.rand_size = rand_size = args.rand_size

        self.jtmpn = JTMPN(hidden_size, args.depthG)
        self.mpn = MPN(hidden_size, args.depthG)

        if args.share_embedding:
            self.embedding = nn.Embedding(vocab.size(), hidden_size)
            self.jtnn = JTNNEncoder(hidden_size, args.depthT, self.embedding)
            self.decoder = JTNNDecoder(vocab, hidden_size, self.embedding, args.use_molatt)
        else:
            self.jtnn = JTNNEncoder(hidden_size, args.depthT, nn.Embedding(vocab.size(), hidden_size))
            self.decoder = JTNNDecoder(vocab, hidden_size, nn.Embedding(vocab.size(), hidden_size), args.use_molatt)

        self.A_assm = nn.Linear(hidden_size, hidden_size, bias=False)
        self.assm_loss = nn.CrossEntropyLoss(reduction='sum')

        self.T_mean = nn.Linear(hidden_size, rand_size // 2)
        self.T_var = nn.Linear(hidden_size, rand_size // 2)
        self.G_mean = nn.Linear(hidden_size, rand_size // 2)
        self.G_var = nn.Linear(hidden_size, rand_size // 2)

        self.B_t = nn.Sequential(nn.Linear(hidden_size + rand_size // 2, hidden_size), nn.ReLU())
        self.B_g = nn.Sequential(nn.Linear(hidden_size + rand_size // 2, hidden_size), nn.ReLU())

    def encode(self, jtenc_holder, mpn_holder):
        tree_vecs, tree_mess = self.jtnn(*jtenc_holder)
        mol_vecs = self.mpn(*mpn_holder)
        return tree_vecs, tree_mess, mol_vecs

    def fuse_noise(self, tree_vecs, mol_vecs):
        tree_eps = create_var(torch.randn(tree_vecs.size(0), 1, self.rand_size // 2))
        tree_eps = tree_eps.expand(-1, tree_vecs.size(1), -1)
        mol_eps = create_var(torch.randn(mol_vecs.size(0), 1, self.rand_size // 2))
        mol_eps = mol_eps.expand(-1, mol_vecs.size(1), -1)

        tree_vecs = torch.cat([tree_vecs, tree_eps], dim=-1)
        mol_vecs = torch.cat([mol_vecs, mol_eps], dim=-1)
        return self.B_t(tree_vecs), self.B_g(mol_vecs)

    def fuse_pair(self, x_tree_vecs, x_mol_vecs, y_tree_vecs, y_mol_vecs, jtenc_scope, mpn_scope):
        diff_tree_vecs = y_tree_vecs.sum(dim=1) - x_tree_vecs.sum(dim=1)
        size = create_var(torch.Tensor([le for _, le in jtenc_scope]))
        diff_tree_vecs = diff_tree_vecs / size.unsqueeze(-1)

        diff_mol_vecs = y_mol_vecs.sum(dim=1) - x_mol_vecs.sum(dim=1)
        size = create_var(torch.Tensor([le for _, le in mpn_scope]))
        diff_mol_vecs = diff_mol_vecs / size.unsqueeze(-1)

        diff_tree_vecs, tree_kl = self.rsample(diff_tree_vecs, self.T_mean, self.T_var)
        diff_mol_vecs, mol_kl = self.rsample(diff_mol_vecs, self.G_mean, self.G_var)

        diff_tree_vecs = diff_tree_vecs.unsqueeze(1).expand(-1, x_tree_vecs.size(1), -1)
        diff_mol_vecs = diff_mol_vecs.unsqueeze(1).expand(-1, x_mol_vecs.size(1), -1)
        x_tree_vecs = torch.cat([x_tree_vecs, diff_tree_vecs], dim=-1)
        x_mol_vecs = torch.cat([x_mol_vecs, diff_mol_vecs], dim=-1)

        return self.B_t(x_tree_vecs), self.B_g(x_mol_vecs), tree_kl + mol_kl

    def rsample(self, z_vecs, W_mean, W_var):
        z_mean = W_mean(z_vecs)
        z_log_var = -torch.abs(W_var(z_vecs))  # Following Mueller et al.
        kl_loss = -0.5 * torch.mean(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var))
        epsilon = create_var(torch.randn_like(z_mean))
        z_vecs = z_mean + torch.exp(z_log_var / 2) * epsilon
        return z_vecs, kl_loss

    def forward(self, x_batch, y_batch, beta):
        x_batch, x_jtenc_holder, x_mpn_holder = x_batch
        y_batch, y_jtenc_holder, y_mpn_holder, y_jtmpn_holder = y_batch
        x_tree_vecs, _, x_mol_vecs = self.encode(x_jtenc_holder, x_mpn_holder)
        y_tree_vecs, y_tree_mess, y_mol_vecs = self.encode(y_jtenc_holder, y_mpn_holder)

        x_tree_vecs, x_mol_vecs, kl_div = self.fuse_pair(x_tree_vecs, x_mol_vecs, y_tree_vecs, y_mol_vecs, y_jtenc_holder[-1], y_mpn_holder[-1])

        word_loss, topo_loss, word_acc, topo_acc = self.decoder(y_batch, x_tree_vecs, x_mol_vecs)
        assm_loss, assm_acc = self.assm(y_batch, y_jtmpn_holder, x_mol_vecs, y_tree_mess)

        return word_loss + topo_loss + assm_loss + beta * kl_div, kl_div.item(), word_acc, topo_acc, assm_acc

    def assm(self, mol_batch, jtmpn_holder, x_mol_vecs, y_tree_mess):
        jtmpn_holder, batch_idx = jtmpn_holder
        fatoms, fbonds, agraph, bgraph, scope = jtmpn_holder
        batch_idx = create_var(batch_idx)

        cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, y_tree_mess)

        x_mol_vecs = x_mol_vecs.sum(dim=1)  # average pooling?
        x_mol_vecs = x_mol_vecs.index_select(0, batch_idx)
        x_mol_vecs = self.A_assm(x_mol_vecs)  # bilinear
        scores = torch.bmm(x_mol_vecs.unsqueeze(1), cand_vecs.unsqueeze(-1)).squeeze()

        cnt, tot, acc = 0, 0, 0
        all_loss = []
        for i, mol_tree in enumerate(mol_batch):
            comp_nodes = [node for node in mol_tree.nodes if len(node.cands) > 1 and not node.is_leaf]
            cnt += len(comp_nodes)
            for node in comp_nodes:
                label = node.cands.index(node.label)
                ncand = len(node.cands)
                cur_score = scores.narrow(0, tot, ncand)
                tot += ncand

                if cur_score[label].item() >= cur_score.max().item():
                    acc += 1
                label = create_var(torch.LongTensor([label]))
                all_loss.append(self.assm_loss(cur_score.view(1, -1), label))

        all_loss = sum(all_loss) / len(mol_batch)
        return all_loss, acc * 1.0 / cnt

    def decode(self, x_tree_vecs, x_mol_vecs):
        # batch decoding is currently not supported
        assert x_tree_vecs.size(0) == 1 and x_mol_vecs.size(0) == 1

        pred_root, pred_nodes = self.decoder.decode(x_tree_vecs, x_mol_vecs)
        if len(pred_nodes) == 0:
            return None
        elif len(pred_nodes) == 1:
            return pred_root.smiles

        # Mark nid & is_leaf & atommap
        for i, node in enumerate(pred_nodes):
            node.nid = i + 1
            node.is_leaf = (len(node.neighbors) == 1)
            if len(node.neighbors) > 1:
                set_atommap(node.mol, node.nid)

        scope = [(0, len(pred_nodes))]
        jtenc_holder, mess_dict = JTNNEncoder.tensorize_nodes(pred_nodes, scope)
        _, tree_mess = self.jtnn(*jtenc_holder)
        tree_mess = (tree_mess, mess_dict)  # Important: tree_mess is a matrix, mess_dict is a python dict

        x_mol_vec_pooled = x_mol_vecs.sum(dim=1)  # average pooling?
        x_mol_vec_pooled = self.A_assm(x_mol_vec_pooled).squeeze()  # bilinear

        cur_mol = copy_edit_mol(pred_root.mol)
        global_amap = [{}] + [{} for node in pred_nodes]
        global_amap[1] = {atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()}

        cur_mol = self.dfs_assemble(tree_mess, x_mol_vec_pooled, pred_nodes, cur_mol, global_amap, [], pred_root, None)
        if cur_mol is None:
            return None

        cur_mol = cur_mol.GetMol()
        set_atommap(cur_mol)
        cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
        return Chem.MolToSmiles(cur_mol) if cur_mol is not None else None

    def dfs_assemble(self, y_tree_mess, x_mol_vec_pooled, all_nodes, cur_mol, global_amap, fa_amap, cur_node, fa_node):
        fa_nid = fa_node.nid if fa_node is not None else -1
        prev_nodes = [fa_node] if fa_node is not None else []

        children = [nei for nei in cur_node.neighbors if nei.nid != fa_nid]
        neighbors = [nei for nei in children if nei.mol.GetNumAtoms() > 1]
        neighbors = sorted(neighbors, key=lambda x: x.mol.GetNumAtoms(), reverse=True)
        singletons = [nei for nei in children if nei.mol.GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap if nid == cur_node.nid]
        cands = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap)
        if len(cands) == 0:
            return None
        cand_smiles, cand_amap = zip(*cands)

        cands = [(smiles, all_nodes, cur_node) for smiles in cand_smiles]

        jtmpn_holder = JTMPN.tensorize(cands, y_tree_mess[1])
        fatoms, fbonds, agraph, bgraph, scope = jtmpn_holder
        cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, y_tree_mess[0])

        scores = torch.mv(cand_vecs, x_mol_vec_pooled)
        _, cand_idx = torch.sort(scores, descending=True)

        backup_mol = Chem.RWMol(cur_mol)
        #for i in range(cand_idx.numel()):
        for i in range(min(cand_idx.numel(), 5)):
            cur_mol = Chem.RWMol(backup_mol)
            pred_amap = cand_amap[cand_idx[i].item()]
            new_global_amap = copy.deepcopy(global_amap)

            for nei_id, ctr_atom, nei_atom in pred_amap:
                if nei_id == fa_nid:
                    continue
                new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node.nid][ctr_atom]

            cur_mol = attach_mols(cur_mol, children, [], new_global_amap)  # father is already attached
            new_mol = cur_mol.GetMol()
            new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))

            if new_mol is None:
                continue

            result = True
            for nei_node in children:
                if nei_node.is_leaf:
                    continue
                cur_mol = self.dfs_assemble(y_tree_mess, x_mol_vec_pooled, all_nodes, cur_mol, new_global_amap, pred_amap, nei_node, cur_node)
                if cur_mol is None:
                    result = False
                    break
            if result:
                return cur_mol

        return None
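# Illustrative translation sketch, not part of the source: encode a source molecule, replace the
# learned difference vector with random noise via fuse_noise, and decode a modified molecule.
# The tensorize helper and its exact return signature are assumptions that vary between repo
# versions; model, vocab, and smiles are placeholders.
#
#   x_batch, x_jtenc_holder, x_mpn_holder = tensorize([MolTree(smiles)], vocab, assm=False)
#   x_tree_vecs, _, x_mol_vecs = model.encode(x_jtenc_holder, x_mpn_holder)
#   z_tree_vecs, z_mol_vecs = model.fuse_noise(x_tree_vecs, x_mol_vecs)
#   new_smiles = model.decode(z_tree_vecs, z_mol_vecs)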