def forward(self, fnode, fmess, node_graph, mess_graph, scope):
    fnode = create_var(fnode)
    fmess = create_var(fmess)
    node_graph = create_var(node_graph)
    mess_graph = create_var(mess_graph)
    messages = create_var(torch.zeros(mess_graph.size(0), self.hidden_size))

    fnode = self.embedding(fnode)
    fmess = index_select_ND(fnode, 0, fmess)
    messages = self.GRU(messages, fmess, mess_graph)

    mess_nei = index_select_ND(messages, 0, node_graph)
    node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
    node_vecs = self.outputNN(node_vecs)

    batch_vecs = []
    for st, le in scope:
        cur_vecs = node_vecs[st]  # root is the first node of each tree
        batch_vecs.append(cur_vecs)

    tree_vecs = torch.stack(batch_vecs, dim=0)
    return tree_vecs, messages
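# The helpers `create_var` and `index_select_ND` are used throughout the functions
# below but are not defined in this section. A minimal sketch follows, modeled on
# the reference JT-VAE `nnutils` module; this is an assumption about the local
# versions, not a verified copy of them, and it would normally live in a shared
# utilities file rather than inside a class body:
import torch
from torch.autograd import Variable

def create_var(tensor, requires_grad=None):
    # Wrap a tensor for autograd and move it to the GPU (legacy pre-0.4 PyTorch idiom).
    if requires_grad is None:
        return Variable(tensor).cuda()
    return Variable(tensor, requires_grad=requires_grad).cuda()

def index_select_ND(source, dim, index):
    # Gather rows of `source` along `dim` using a (possibly multi-dimensional)
    # `index` tensor; the result has shape index.size() + source.size()[1:].
    index_size = index.size()
    suffix_dim = source.size()[1:]
    final_size = index_size + suffix_dim
    target = source.index_select(dim, index.view(-1))
    return target.view(final_size)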
def stereo(self, mol_batch, mol_vec):
    stereo_cands, batch_idx = [], []
    labels = []
    for i, mol_tree in enumerate(mol_batch):
        cands = mol_tree.stereo_cands
        if len(cands) == 1:
            continue
        if mol_tree.smiles3D not in cands:
            cands.append(mol_tree.smiles3D)
        stereo_cands.extend(cands)
        batch_idx.extend([i] * len(cands))
        labels.append((cands.index(mol_tree.smiles3D), len(cands)))

    if len(labels) == 0:
        return create_var(torch.zeros(1)), 1.0

    batch_idx = create_var(torch.LongTensor(batch_idx))
    stereo_cands = self.mpn(mol2graph(stereo_cands))
    stereo_cands = self.G_mean(stereo_cands)
    stereo_labels = mol_vec.index_select(0, batch_idx)
    scores = torch.nn.CosineSimilarity()(stereo_cands, stereo_labels)

    st, acc = 0, 0
    all_loss = []
    for label, le in labels:
        cur_scores = scores.narrow(0, st, le)
        if cur_scores.data[label] >= cur_scores.max().item():
            acc += 1
        label = create_var(torch.LongTensor([label]))
        all_loss.append(self.stereo_loss(cur_scores.view(1, -1), label))
        st += le

    all_loss = sum(all_loss) / len(labels)
    return all_loss, acc * 1.0 / len(labels)
def assm(self, junc_tree_batch, x_jtmpn_holder, z_mol_vecs, x_tree_mess):
    jtmpn_holder, batch_idx = x_jtmpn_holder
    atom_feature_matrix, bond_feature_matrix, atom_adjacency_graph, bond_adjacency_graph, scope = jtmpn_holder
    batch_idx = create_var(batch_idx)

    candidate_vecs = self.jtmpn(atom_feature_matrix, bond_feature_matrix, atom_adjacency_graph,
                                bond_adjacency_graph, scope, x_tree_mess)

    z_mol_vecs = z_mol_vecs.index_select(0, batch_idx)
    z_mol_vecs = self.A_assm(z_mol_vecs)  # bilinear scoring: (B, 1, H) x (B, H, 1) -> (B,)
    scores = torch.bmm(z_mol_vecs.unsqueeze(1), candidate_vecs.unsqueeze(-1)).squeeze()

    cnt, tot, acc = 0, 0, 0
    all_loss = []
    for i, mol_tree in enumerate(junc_tree_batch):
        comp_nodes = [node for node in mol_tree.nodes if len(node.candidates) > 1 and not node.is_leaf]
        cnt += len(comp_nodes)
        for node in comp_nodes:
            label = node.candidates.index(node.label)
            num_candidates = len(node.candidates)
            cur_score = scores.narrow(0, tot, num_candidates)
            tot += num_candidates

            if cur_score.data[label] >= cur_score.max().item():
                acc += 1

            label = create_var(torch.LongTensor([label]))
            all_loss.append(self.assm_loss(cur_score.view(1, -1), label))

    all_loss = sum(all_loss) / len(junc_tree_batch)
    return all_loss, acc * 1.0 / cnt
def forward(self, fnode, fmess, node_graph, mess_graph, scope):
    fnode = create_var(fnode)
    fmess = create_var(fmess)
    node_graph = create_var(node_graph)
    mess_graph = create_var(mess_graph)
    messages = create_var(torch.zeros(mess_graph.size(0), self.hidden_size))

    fnode = self.embedding(fnode)
    fmess1 = index_select_ND(fnode, 0, fmess[:, 0])
    fmess2 = self.E_pos(fmess[:, 1])  # positional embedding of the message
    fmess = self.inputNN(torch.cat([fmess1, fmess2], dim=-1))
    messages = self.GRU(messages, fmess, mess_graph)

    mess_nei = index_select_ND(messages, 0, node_graph)
    node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
    node_vecs = self.outputNN(node_vecs)

    batch_vecs = []
    for st, le in scope:
        cur_vecs = node_vecs[st]  # root is the first node of each tree
        batch_vecs.append(cur_vecs)

    tree_vecs = torch.stack(batch_vecs, dim=0)
    return tree_vecs, messages
def forward(self, mol_batch, beta=0):
    batch_size = len(mol_batch)

    tree_mess, tree_vec, mol_vec = self.encode(mol_batch)

    tree_mean = self.T_mean(tree_vec)
    tree_log_var = -torch.abs(self.T_var(tree_vec))  # Following Mueller et al.
    mol_mean = self.G_mean(mol_vec)
    mol_log_var = -torch.abs(self.G_var(mol_vec))  # Following Mueller et al.

    z_mean = torch.cat([tree_mean, mol_mean], dim=1)
    z_log_var = torch.cat([tree_log_var, mol_log_var], dim=1)
    kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size

    epsilon = create_var(torch.randn(batch_size, self.latent_size // 2), False)
    tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
    epsilon = create_var(torch.randn(batch_size, self.latent_size // 2), False)
    mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon

    word_loss, topo_loss, word_acc, topo_acc = self.decoder(mol_batch, tree_vec)
    assm_loss, assm_acc = self.assm(mol_batch, mol_vec, tree_mess)
    if self.use_stereo:
        stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)
    else:
        stereo_loss, stereo_acc = 0, 0

    loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss
    return loss, kl_loss.item(), word_acc, topo_acc, assm_acc, stereo_acc
def assm(self, mol_batch, jtmpn_holder, x_mol_vecs, x_tree_mess):
    jtmpn_holder, batch_idx = jtmpn_holder
    fatoms, fbonds, agraph, bgraph, scope = jtmpn_holder
    batch_idx = create_var(batch_idx)

    cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, x_tree_mess)

    x_mol_vecs = x_mol_vecs.index_select(0, batch_idx)
    x_mol_vecs = self.A_assm(x_mol_vecs)  # bilinear
    scores = torch.bmm(x_mol_vecs.unsqueeze(1), cand_vecs.unsqueeze(-1)).squeeze()

    cnt, tot, acc = 0, 0, 0
    all_loss = []
    for i, mol_tree in enumerate(mol_batch):
        comp_nodes = [node for node in mol_tree.nodes if len(node.cands) > 1 and not node.is_leaf]
        cnt += len(comp_nodes)
        for node in comp_nodes:
            label = node.cands.index(node.label)
            ncand = len(node.cands)
            cur_score = scores.narrow(0, tot, ncand)
            tot += ncand

            if cur_score.data[label] >= cur_score.max().item():
                acc += 1

            label = create_var(torch.LongTensor([label]))
            all_loss.append(self.assm_loss(cur_score.view(1, -1), label))

    all_loss = sum(all_loss) / len(mol_batch)
    return all_loss, acc * 1.0 / cnt
def forward(self, fatoms, fbonds, agraph, bgraph, scope, tree_message):
    # tree_message[0] == vec(0)
    fatoms = create_var(fatoms)
    fbonds = create_var(fbonds)
    agraph = create_var(agraph)
    bgraph = create_var(bgraph)

    binput = self.W_i(fbonds)
    graph_message = F.relu(binput)

    for i in range(self.depth - 1):
        message = torch.cat([tree_message, graph_message], dim=0)
        nei_message = index_select_ND(message, 0, bgraph)
        nei_message = nei_message.sum(dim=1)  # assuming tree_message[0] == vec(0)
        nei_message = self.W_h(nei_message)
        graph_message = F.relu(binput + nei_message)

    message = torch.cat([tree_message, graph_message], dim=0)
    nei_message = index_select_ND(message, 0, agraph)
    nei_message = nei_message.sum(dim=1)
    ainput = torch.cat([fatoms, nei_message], dim=1)
    atom_hiddens = F.relu(self.W_o(ainput))

    mol_vecs = []
    for st, le in scope:
        mol_vec = atom_hiddens.narrow(0, st, le).sum(dim=0) / le  # mean over atoms
        mol_vecs.append(mol_vec)

    mol_vecs = torch.stack(mol_vecs, dim=0)
    return mol_vecs
def forward(self, fnode, fmess, node_graph, mess_graph, scope):
    fnode = create_var(fnode)
    fmess = create_var(fmess)
    node_graph = create_var(node_graph)
    mess_graph = create_var(mess_graph)
    messages = create_var(torch.zeros(mess_graph.size(0), self.hidden_size))

    fnode = self.embedding(fnode)
    fmess = index_select_ND(fnode, 0, fmess)
    messages = self.GRU(messages, fmess, mess_graph)

    mess_nei = index_select_ND(messages, 0, node_graph)
    node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
    node_vecs = self.outputNN(node_vecs)

    max_len = max([x for _, x in scope])
    batch_vecs = []
    for st, le in scope:
        cur_vecs = node_vecs[st:st + le]
        cur_vecs = F.pad(cur_vecs, (0, 0, 0, max_len - le))  # pad each tree to max_len nodes
        batch_vecs.append(cur_vecs)

    tree_vecs = torch.stack(batch_vecs, dim=0)
    return tree_vecs, messages
def forward(self, mol_batch, beta=0):
    batch_size = len(mol_batch)
    mol_batch, prop_batch = zip(*mol_batch)

    tree_mess, tree_vec, mol_vec = self.encode(mol_batch)

    tree_mean = self.T_mean(tree_vec)
    tree_log_var = -torch.abs(self.T_var(tree_vec))  # Following Mueller et al.
    mol_mean = self.G_mean(mol_vec)
    mol_log_var = -torch.abs(self.G_var(mol_vec))  # Following Mueller et al.

    z_mean = torch.cat([tree_mean, mol_mean], dim=1)
    z_log_var = torch.cat([tree_log_var, mol_log_var], dim=1)
    kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size

    epsilon = create_var(torch.randn(batch_size, self.latent_size // 2), False)
    tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
    epsilon = create_var(torch.randn(batch_size, self.latent_size // 2), False)
    mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon

    word_loss, topo_loss, word_acc, topo_acc = self.decoder(mol_batch, tree_vec)
    assm_loss, assm_acc = self.assm(mol_batch, mol_vec, tree_mess)
    stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)

    all_vec = torch.cat([tree_vec, mol_vec], dim=1)
    prop_label = create_var(torch.Tensor(prop_batch))
    prop_loss = self.prop_loss(self.propNN(all_vec).squeeze(), prop_label)

    loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss + prop_loss
    return loss, kl_loss.item(), word_acc, topo_acc, assm_acc, stereo_acc, prop_loss.item()
def reconstruct(self, smiles):
    junc_tree = MolJuncTree(smiles)
    junc_tree.recover()
    set_batch_nodeID([junc_tree], self.vocab)

    jtenc_holder, _ = JTNNEncoder.tensorize([junc_tree])
    mpn_holder = MessPassNet.tensorize([smiles])

    tree_vec, _, mol_vec = self.encode(jtenc_holder, mpn_holder)

    tree_mean = self.T_mean(tree_vec)
    tree_log_var = -torch.abs(self.T_var(tree_vec))  # Following Mueller et al.
    mol_mean = self.G_mean(mol_vec)
    mol_log_var = -torch.abs(self.G_var(mol_vec))  # Following Mueller et al.

    epsilon = create_var(torch.randn(1, self.latent_size), False)
    tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
    epsilon = create_var(torch.randn(1, self.latent_size), False)
    mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon

    return self.decode(tree_vec, mol_vec)
def forward(self, mol_graph):
    fatoms, fbonds, agraph, bgraph, scope = mol_graph
    fatoms = create_var(fatoms)
    fbonds = create_var(fbonds)
    agraph = create_var(agraph)
    bgraph = create_var(bgraph)

    binput = self.W_i(fbonds)
    message = nn.ReLU()(binput)

    for _ in range(self.depth - 1):
        nei_message = index_select_ND(message, 0, bgraph)
        nei_message = nei_message.sum(dim=1)
        nei_message = self.W_h(nei_message)
        message = nn.ReLU()(binput + nei_message)

    nei_message = index_select_ND(message, 0, agraph)
    nei_message = nei_message.sum(dim=1)
    ainput = torch.cat([fatoms, nei_message], dim=1)
    atom_hiddens = nn.ReLU()(self.W_o(ainput))

    mol_vecs = []
    for st, le in scope:
        mol_vec = atom_hiddens.narrow(0, st, le).sum(dim=0) / le  # mean over atoms
        mol_vecs.append(mol_vec)

    mol_vecs = torch.stack(mol_vecs, dim=0)
    return mol_vecs
def sample_eval(self):
    tree_vec = create_var(torch.randn(1, self.latent_size // 2), False)
    mol_vec = create_var(torch.randn(1, self.latent_size // 2), False)
    all_smiles = []
    for i in range(100):
        s = self.decode(tree_vec, mol_vec, prob_decode=True)
        all_smiles.append(s)
    return all_smiles
def fuse_noise(self, tree_vecs, mol_vecs):
    tree_eps = create_var(torch.randn(tree_vecs.size(0), 1, self.rand_size // 2))
    tree_eps = tree_eps.expand(-1, tree_vecs.size(1), -1)
    mol_eps = create_var(torch.randn(mol_vecs.size(0), 1, self.rand_size // 2))
    mol_eps = mol_eps.expand(-1, mol_vecs.size(1), -1)

    tree_vecs = torch.cat([tree_vecs, tree_eps], dim=-1)
    mol_vecs = torch.cat([mol_vecs, mol_eps], dim=-1)
    return self.B_t(tree_vecs), self.B_g(mol_vecs)
def sample_prior_eval(self, prob_decode=False, ns=1000, nd=500):
    priors = []
    for i in range(ns):
        dec = []
        tree_vec = create_var(torch.randn(1, self.latent_size // 2), False)
        mol_vec = create_var(torch.randn(1, self.latent_size // 2), False)
        for j in range(nd):
            dec.append(self.decode(tree_vec, mol_vec, prob_decode))
        priors.append(dec)
    return priors
def optimize(self, smiles, sim_cutoff, lr=2.0, num_iter=20):
    mol_tree = MolTree(smiles)
    mol_tree.recover()
    _, tree_vec, mol_vec = self.encode([mol_tree])

    mol = Chem.MolFromSmiles(smiles)
    fp1 = AllChem.GetMorganFingerprint(mol, 2)

    tree_mean = self.T_mean(tree_vec)
    tree_log_var = -torch.abs(self.T_var(tree_vec))  # Following Mueller et al.
    mol_mean = self.G_mean(mol_vec)
    mol_log_var = -torch.abs(self.G_var(mol_vec))  # Following Mueller et al.
    mean = torch.cat([tree_mean, mol_mean], dim=1)
    log_var = torch.cat([tree_log_var, mol_log_var], dim=1)

    cur_vec = create_var(mean.data, True)

    # Gradient ascent on the predicted property in latent space.
    visited = []
    for _ in range(num_iter):
        prop_val = self.propNN(cur_vec).squeeze()
        grad = torch.autograd.grad(prop_val, cur_vec)[0]
        cur_vec = cur_vec.data + lr * grad.data
        cur_vec = create_var(cur_vec, True)
        visited.append(cur_vec)

    # Binary search for the furthest step that still satisfies the similarity cutoff.
    l, r = 0, num_iter - 1
    while l < r - 1:
        mid = (l + r) // 2
        new_vec = visited[mid]
        tree_vec, mol_vec = torch.chunk(new_vec, 2, dim=1)
        new_smiles = self.decode(tree_vec, mol_vec, prob_decode=False)
        if new_smiles is None:
            r = mid - 1
            continue

        new_mol = Chem.MolFromSmiles(new_smiles)
        fp2 = AllChem.GetMorganFingerprint(new_mol, 2)
        sim = DataStructs.TanimotoSimilarity(fp1, fp2)
        if sim < sim_cutoff:
            r = mid - 1
        else:
            l = mid

    tree_vec, mol_vec = torch.chunk(visited[l], 2, dim=1)
    new_smiles = self.decode(tree_vec, mol_vec, prob_decode=False)
    if new_smiles is None:
        return smiles, 1.0

    new_mol = Chem.MolFromSmiles(new_smiles)
    fp2 = AllChem.GetMorganFingerprint(new_mol, 2)
    sim = DataStructs.TanimotoSimilarity(fp1, fp2)
    if sim >= sim_cutoff:
        return new_smiles, sim
    else:
        return smiles, 1.0
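# Hypothetical usage sketch for `optimize` (the model name and argument values
# are illustrative assumptions, given a trained property-VAE `model`): take
# gradient-ascent steps on the predicted property in latent space, then
# binary-search the visited trajectory for the furthest point whose decoded
# molecule still has Morgan-fingerprint Tanimoto similarity >= sim_cutoff
# to the seed molecule.
#
#     new_smiles, sim = model.optimize('CCO', sim_cutoff=0.4, lr=2.0, num_iter=20)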
def forward(self, root_batch):
    # orders: list of lists; each sublist is the level-order traversal of one root.
    orders = []
    for root in root_batch:
        # order: list of lists, split in two parts: a bottom-up pass, where each
        # sublist holds that level's (node, parent) pairs, and a top-down pass,
        # where each sublist holds that level's (node, child) pairs.
        order = get_prop_order(root)
        orders.append(order)

    h = {}
    max_depth = max([len(order) for order in orders])
    padding = create_var(torch.zeros(self.hidden_size), False)

    for t in range(max_depth):
        prop_list = []
        for order in orders:
            if len(order) > t:  # make sure this tree has a level t
                prop_list.extend(order[t])  # add level t's pairs to prop_list

        cur_x = []
        cur_h_nei = []
        for node_x, node_y in prop_list:
            x, y = node_x.idx, node_y.idx  # node indices
            cur_x.append(node_x.wid)  # node label (vocabulary) index

            h_nei = []
            for node_z in node_x.neighbors:
                z = node_z.idx
                if z == y:
                    continue
                # h_nei: messages from x's neighbors other than y
                h_nei.append(h[(z, x)])

            # if there are fewer neighbors than the maximum, pad with zero vectors
            pad_len = MAX_NB - len(h_nei)
            h_nei.extend([padding] * pad_len)
            cur_h_nei.extend(h_nei)

        cur_x = create_var(torch.LongTensor(cur_x))
        # From here on, labels become vectors: cur_x.size() = (len(prop_list), hidden_size)
        cur_x = self.embedding(cur_x)
        # cur_h_nei.size() = (len(prop_list), MAX_NB, hidden_size)
        cur_h_nei = torch.cat(cur_h_nei, dim=0).view(-1, MAX_NB, self.hidden_size)

        new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
        for i, m in enumerate(prop_list):
            x, y = m[0].idx, m[1].idx
            h[(x, y)] = new_h[i]

    # node aggregation
    root_vecs = node_aggregate(root_batch, h, self.embedding, self.W)
    return h, root_vecs
def forward(self, root_batch):
    orders = []
    for root in root_batch:
        order = get_prop_order(root)
        orders.append(order)

    h = {}
    max_depth = max([len(x) for x in orders])
    padding = create_var(torch.zeros(self.hidden_size), False)

    for t in range(max_depth):
        prop_list = []
        for order in orders:
            if t < len(order):
                prop_list.extend(order[t])

        cur_x = []
        cur_h_nei = []
        for node_x, node_y in prop_list:
            x, y = node_x.idx, node_y.idx
            cur_x.append(node_x.wid)

            h_nei = []
            for node_z in node_x.neighbors:
                z = node_z.idx
                if z == y:
                    continue
                h_nei.append(h[(z, x)])

            # a node with more than MAX_NB neighbors would overflow the padding scheme
            assert len(h_nei) <= MAX_NB
            pad_len = MAX_NB - len(h_nei)
            h_nei.extend([padding] * pad_len)
            cur_h_nei.extend(h_nei)

        cur_x = create_var(torch.LongTensor(cur_x))
        cur_x = self.embedding(cur_x)
        cur_h_nei = torch.cat(cur_h_nei, dim=0).view(-1, MAX_NB, self.hidden_size)

        new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
        for i, m in enumerate(prop_list):
            x, y = m[0].idx, m[1].idx
            h[(x, y)] = new_h[i]

    root_vecs = node_aggregate(root_batch, h, self.embedding, self.W)
    return h, root_vecs
def forward(self, atom_feature_matrix, bond_feature_matrix, atom_adjacency_graph,
            atom_bond_adjacency_graph, bond_atom_adjacency_graph, scope):
    """
    Args:
        atom_feature_matrix: torch.tensor (shape: num_atoms x atom_feature_dim)
            The matrix containing feature vectors for all the atoms across the entire batch.
            * atom_feature_dim = len(ELEM_LIST) + 6 + 5 + 4 + 1

        bond_feature_matrix: torch.tensor (shape: num_bonds x bond_feature_dim)
            The matrix containing feature vectors for all the bonds across the entire batch.
            * bond_feature_dim = 5 + 6

        atom_adjacency_graph: torch.tensor (shape: num_atoms x MAX_NUM_NEIGHBORS(=6))
            For each atom across the entire batch, the idxs of neighboring atoms.

        atom_bond_adjacency_graph: torch.tensor (shape: num_atoms x MAX_NUM_NEIGHBORS(=6))
            For each atom across the entire batch, the idxs of all the bonds in which it is the initial atom.

        bond_atom_adjacency_graph: torch.tensor (shape: num_bonds x 2)
            For each bond across the entire batch, the idxs of the 2 atoms of which the bond is composed.

        scope: List[Tuple(int, int)]
            The list of (start_idx, num_bonds) tuples that map the bond feature vectors
            back to the molecule they belong to.

    Returns:
        mol_vecs: torch.tensor (shape: batch_size x hidden_size)
            The hidden vector representation of each molecular graph across the entire batch.
    """
    # create PyTorch variables
    atom_feature_matrix = create_var(atom_feature_matrix)
    bond_feature_matrix = create_var(bond_feature_matrix)
    atom_adjacency_graph = create_var(atom_adjacency_graph)
    atom_bond_adjacency_graph = create_var(atom_bond_adjacency_graph)
    bond_atom_adjacency_graph = create_var(bond_atom_adjacency_graph)

    # implement convolution
    atom_layer_input = atom_feature_matrix
    bond_layer_input = bond_feature_matrix
    for conv_layer in self.conv_layers:
        # forward pass through this convolutional layer
        atom_layer_output, bond_layer_output = conv_layer(
            atom_layer_input, bond_layer_input, atom_adjacency_graph,
            atom_bond_adjacency_graph, bond_atom_adjacency_graph)

        # set the input features for the next convolutional layer
        atom_layer_input, bond_layer_input = atom_layer_output, bond_layer_output

    # for each molecular graph, pool all the bond feature vectors
    mol_vecs = self.pool_bond_features_for_mols(atom_layer_output, bond_layer_output,
                                                bond_atom_adjacency_graph, scope)
    return mol_vecs
def reconstruct(self, smiles, prob_decode=False):
    mol_tree = MolTree(smiles)
    mol_tree.recover()
    _, tree_vec, mol_vec = self.encode([mol_tree])

    tree_mean = self.T_mean(tree_vec)
    tree_log_var = -torch.abs(self.T_var(tree_vec))  # Following Mueller et al.
    mol_mean = self.G_mean(mol_vec)
    mol_log_var = -torch.abs(self.G_var(mol_vec))  # Following Mueller et al.

    epsilon = create_var(torch.randn(1, self.latent_size // 2), False)
    tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
    epsilon = create_var(torch.randn(1, self.latent_size // 2), False)
    mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon

    return self.decode(tree_vec, mol_vec, prob_decode)
def rsample(self, z_vecs, W_mean, W_var):
    z_mean = W_mean(z_vecs)
    z_log_var = -torch.abs(W_var(z_vecs))  # Following Mueller et al.
    kl_loss = -0.5 * torch.mean(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var))
    epsilon = create_var(torch.randn_like(z_mean))
    z_vecs = z_mean + torch.exp(z_log_var / 2) * epsilon
    return z_vecs, kl_loss
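# The closed-form KL term in `rsample` is the diagonal-Gaussian formula from
# Kingma & Welling (2014), here averaged over dimensions rather than summed:
#     KL(N(mu, sigma^2) || N(0, I)) = -1/2 * sum_d (1 + log sigma_d^2 - mu_d^2 - sigma_d^2)
# The last two lines are the reparameterization trick, z = mu + sigma * eps with
# eps ~ N(0, I) and sigma = exp(z_log_var / 2), which keeps the sampling step
# differentiable with respect to mu and sigma.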
def assm(self, mol_batch, mol_vec, tree_mess):
    cands = []
    batch_idx = []
    for i, mol_tree in enumerate(mol_batch):
        for node in mol_tree.nodes:
            # A leaf node's attachment is determined by its neighboring node's attachment.
            if node.is_leaf() or len(node.cands) == 1:
                continue
            cands.extend([(cand, mol_tree.nodes, node) for cand in node.cand_mols])
            batch_idx.extend([i] * len(node.cands))

    cand_vec = self.jtmpn(cands, tree_mess)
    cand_vec = self.G_mean(cand_vec)

    batch_idx = create_var(torch.LongTensor(batch_idx))
    mol_vec = mol_vec.index_select(0, batch_idx)

    mol_vec = mol_vec.view(-1, 1, self.latent_size // 2)
    cand_vec = cand_vec.view(-1, self.latent_size // 2, 1)
    scores = torch.bmm(mol_vec, cand_vec).squeeze()

    cnt, tot, acc = 0, 0, 0
    all_loss = []
    for i, mol_tree in enumerate(mol_batch):
        comp_nodes = [node for node in mol_tree.nodes
                      if len(node.cands) > 1 and not node.is_leaf()]
        cnt += len(comp_nodes)
        for node in comp_nodes:
            label = node.cands.index(node.label)
            ncand = len(node.cands)
            cur_score = scores.narrow(0, tot, ncand)
            tot += ncand

            if cur_score.data[label] >= cur_score.max().item():
                acc += 1

            label = create_var(torch.LongTensor([label]))
            all_loss.append(self.assm_loss(cur_score.view(1, -1), label))

    all_loss = sum(all_loss) / len(mol_batch)
    return all_loss, acc * 1.0 / cnt
def reconstruct1(self, smiles, prob_decode=False):
    mol_tree = MolTree(smiles)
    mol_tree.recover()
    _, tree_vec, mol_vec = self.encode([mol_tree])

    tree_mean = self.T_mean(tree_vec)
    tree_log_var = -torch.abs(self.T_var(tree_vec))  # Following Mueller et al.
    mol_mean = self.G_mean(mol_vec)
    mol_log_var = -torch.abs(self.G_var(mol_vec))  # Following Mueller et al.

    epsilon = create_var(torch.randn(1, self.latent_size // 2), False)
    tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
    epsilon = create_var(torch.randn(1, self.latent_size // 2), False)
    mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon

    return tree_vec, mol_vec, prob_decode
def fuse_pair(self, x_tree_vecs, x_mol_vecs, y_tree_vecs, y_mol_vecs, jtenc_scope, mpn_scope):
    diff_tree_vecs = y_tree_vecs.sum(dim=1) - x_tree_vecs.sum(dim=1)
    size = create_var(torch.Tensor([le for _, le in jtenc_scope]))
    diff_tree_vecs = diff_tree_vecs / size.unsqueeze(-1)  # mean over tree nodes

    diff_mol_vecs = y_mol_vecs.sum(dim=1) - x_mol_vecs.sum(dim=1)
    size = create_var(torch.Tensor([le for _, le in mpn_scope]))
    diff_mol_vecs = diff_mol_vecs / size.unsqueeze(-1)  # mean over atoms

    diff_tree_vecs, tree_kl = self.rsample(diff_tree_vecs, self.T_mean, self.T_var)
    diff_mol_vecs, mol_kl = self.rsample(diff_mol_vecs, self.G_mean, self.G_var)

    diff_tree_vecs = diff_tree_vecs.unsqueeze(1).expand(-1, x_tree_vecs.size(1), -1)
    diff_mol_vecs = diff_mol_vecs.unsqueeze(1).expand(-1, x_mol_vecs.size(1), -1)
    x_tree_vecs = torch.cat([x_tree_vecs, diff_tree_vecs], dim=-1)
    x_mol_vecs = torch.cat([x_mol_vecs, diff_mol_vecs], dim=-1)

    return self.B_t(x_tree_vecs), self.B_g(x_mol_vecs), tree_kl + mol_kl
def node_aggregate(nodes, h, embedding, W):
    x_idx = []
    h_nei = []
    hidden_size = embedding.embedding_dim
    padding = create_var(torch.zeros(hidden_size), False)

    for node_x in nodes:
        x_idx.append(node_x.wid)
        nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        pad_len = MAX_NB - len(nei)
        nei.extend([padding] * pad_len)
        h_nei.extend(nei)

    h_nei = torch.cat(h_nei, dim=0).view(-1, MAX_NB, hidden_size)
    sum_h_nei = h_nei.sum(dim=1)
    x_vec = create_var(torch.LongTensor(x_idx))
    x_vec = embedding(x_vec)
    node_vec = torch.cat([x_vec, sum_h_nei], dim=1)
    return nn.ReLU()(W(node_vec))
def recon_eval(self, smiles):
    mol_tree = MolTree(smiles)
    mol_tree.recover()
    _, tree_vec, mol_vec = self.encode([mol_tree])

    tree_mean = self.T_mean(tree_vec)
    tree_log_var = -torch.abs(self.T_var(tree_vec))  # Following Mueller et al.
    mol_mean = self.G_mean(mol_vec)
    mol_log_var = -torch.abs(self.G_var(mol_vec))  # Following Mueller et al.

    all_smiles = []
    for i in range(10):
        epsilon = create_var(torch.randn(1, self.latent_size // 2), False)
        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
        epsilon = create_var(torch.randn(1, self.latent_size // 2), False)
        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
        for j in range(10):
            new_smiles = self.decode(tree_vec, mol_vec, prob_decode=True)
            all_smiles.append(new_smiles)
    return all_smiles
def reconstruct(self, smiles, prob_decode=False, DataFrame=None):
    mol_tree = MolTree(smiles)
    mol_tree.recover()
    _, tree_vec, mol_vec = self.encode([mol_tree])

    tree_mean = self.T_mean(tree_vec)
    tree_log_var = -torch.abs(self.T_var(tree_vec))  # Following Mueller et al.
    mol_mean = self.G_mean(mol_vec)
    mol_log_var = -torch.abs(self.G_var(mol_vec))  # Following Mueller et al.

    epsilon = create_var(torch.randn(1, self.latent_size // 2), False)
    tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
    epsilon = create_var(torch.randn(1, self.latent_size // 2), False)
    mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon

    # store the full latent vector for this molecule, indexed by its SMILES
    latent_vec = torch.cat((tree_vec, mol_vec), 1)
    DataFrame.loc[smiles] = latent_vec.to('cpu').data.numpy()[0]

    return self.decode(tree_vec, mol_vec, prob_decode)
def rsample(self, z_vecs, W_mean, W_var):
    batch_size = z_vecs.size(0)
    z_mean = W_mean(z_vecs)
    z_log_var = -torch.abs(W_var(z_vecs))  # Following Mueller et al.
    # closed-form KL, as per Kingma & Welling
    kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size
    # reparameterization trick
    epsilon = create_var(torch.randn_like(z_mean))
    z_vecs = z_mean + torch.exp(z_log_var / 2) * epsilon
    return z_vecs, kl_loss
def forward(self, node_wid_list, node_child_adjacency_graph, node_edge_adjacency_graph,
            edge_node_adjacency_graph, scope, root_scope):
    # list to store embedding vectors for junction-tree nodes
    node_feature_vecs = []

    # padding vector for node features
    node_feature_padding = create_var(torch.zeros(self.hidden_size))
    node_feature_vecs.append(node_feature_padding)

    # put this tensor on the GPU
    node_wid_list = create_var(node_wid_list)

    # obtain embedding vectors for all the junction-tree nodes
    node_embeddings = self.embedding(node_wid_list)
    node_feature_vecs.extend(list(node_embeddings))
    node_feature_matrix = torch.stack(node_feature_vecs, dim=0)

    total_num_edges = edge_node_adjacency_graph.shape[0]
    edge_feature_matrix = torch.zeros(total_num_edges, self.hidden_size)

    # create PyTorch variables
    node_feature_matrix = create_var(node_feature_matrix)
    edge_feature_matrix = create_var(edge_feature_matrix)
    node_child_adjacency_graph = create_var(node_child_adjacency_graph)
    node_edge_adjacency_graph = create_var(node_edge_adjacency_graph)
    edge_node_adjacency_graph = create_var(edge_node_adjacency_graph)

    # implement convolution
    node_layer_input = node_feature_matrix
    edge_layer_input = edge_feature_matrix
    for conv_layer in self.conv_layers:
        # forward pass through this convolutional layer
        node_layer_output, edge_layer_output = conv_layer(
            node_layer_input, edge_layer_input, node_child_adjacency_graph,
            node_edge_adjacency_graph, edge_node_adjacency_graph)

        # set the input features for the next convolutional layer
        node_layer_input, edge_layer_input = node_layer_output, edge_layer_output

    # use each root node's hidden vector as the representation of its junction tree
    # (instead of pooling all edge feature vectors per tree)
    tree_vecs = node_layer_output[root_scope]
    return tree_vecs
def gradient_penalty(self, real_vecs, fake_vecs):
    eps = create_var(torch.rand(real_vecs.size(0), 1))
    inter_data = eps * real_vecs + (1 - eps) * fake_vecs
    inter_data = autograd.Variable(inter_data, requires_grad=True)
    inter_score = self.netD(inter_data).squeeze(-1)

    inter_grad = autograd.grad(inter_score, inter_data,
                               grad_outputs=torch.ones(inter_score.size()).cuda(),
                               create_graph=True, retain_graph=True, only_inputs=True)[0]

    inter_norm = inter_grad.norm(2, dim=1)
    inter_gp = ((inter_norm - 1) ** 2).mean() * self.beta

    return inter_gp, inter_norm.mean().item()
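# The penalty above is the two-sided WGAN-GP term of Gulrajani et al. (2017),
# computed on random interpolates x_hat = eps * x_real + (1 - eps) * x_fake:
#     GP = beta * E[ (|| grad_{x_hat} D(x_hat) ||_2 - 1)^2 ]
# It drives the critic's gradient norm toward 1, enforcing an approximate
# 1-Lipschitz constraint on the discriminator.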
def forward(self, h, x, mess_graph):
    mask = torch.ones(h.size(0), 1)
    mask[0] = 0  # first vector is padding
    mask = create_var(mask)

    for it in range(self.depth):
        h_nei = index_select_ND(h, 0, mess_graph)
        sum_h = h_nei.sum(dim=1)
        z_input = torch.cat([x, sum_h], dim=1)
        z = torch.sigmoid(self.W_z(z_input))

        r_1 = self.W_r(x).view(-1, 1, self.hidden_size)
        r_2 = self.U_r(h_nei)
        r = torch.sigmoid(r_1 + r_2)

        gated_h = r * h_nei
        sum_gated_h = gated_h.sum(dim=1)
        h_input = torch.cat([x, sum_gated_h], dim=1)
        pre_h = torch.tanh(self.W_h(h_input))

        h = (1.0 - z) * sum_h + z * pre_h
        h = h * mask

    return h
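# The loop above implements the tree-GRU message update used by JT-VAE
# (Jin et al., 2018). With incoming messages h_k gathered via mess_graph:
#     s   = sum_k h_k
#     z   = sigmoid(W_z [x, s])
#     r_k = sigmoid(W_r x + U_r h_k)            (one reset gate per message)
#     h~  = tanh(W_h [x, sum_k r_k * h_k])
#     h   = (1 - z) * s + z * h~
# The mask zeroes out row 0, which is reserved as the padding message.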