def forward(self, fatoms, fbonds, agraph, bgraph, scope, tree_message):
    # tree_message[0] == vec(0)
    fatoms = create_var(fatoms)
    fbonds = create_var(fbonds)
    agraph = create_var(agraph)
    bgraph = create_var(bgraph)

    binput = self.W_i(fbonds)
    graph_message = F.relu(binput)

    for i in range(self.depth - 1):
        message = torch.cat([tree_message, graph_message], dim=0)
        nei_message = index_select_ND(message, 0, bgraph)
        nei_message = nei_message.sum(dim=1)  # assuming tree_message[0] == vec(0)
        nei_message = self.W_h(nei_message)
        graph_message = F.relu(binput + nei_message)

    message = torch.cat([tree_message, graph_message], dim=0)
    nei_message = index_select_ND(message, 0, agraph)
    nei_message = nei_message.sum(dim=1)
    ainput = torch.cat([fatoms, nei_message], dim=1)
    atom_hiddens = F.relu(self.W_o(ainput))

    mol_vecs = []
    for st, le in scope:
        mol_vec = atom_hiddens.narrow(0, st, le).sum(dim=0) / le
        mol_vecs.append(mol_vec)

    mol_vecs = torch.stack(mol_vecs, dim=0)
    return mol_vecs
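# The forwards in this file lean on two helpers from the repo's nnutils:
# create_var (assumed to wrap a tensor as a, possibly CUDA, Variable) and
# index_select_ND. A minimal sketch of the behavior index_select_ND is
# assumed to have -- gather rows of `source` along `dim` with an index
# tensor of any shape, keeping the index shape as the leading dims:
def index_select_ND_sketch(source, dim, index):
    index_size = index.size()
    suffix_dim = source.size()[1:]
    final_size = index_size + suffix_dim
    # Flatten the index, gather, then restore the index shape in front.
    target = source.index_select(dim, index.view(-1))
    return target.view(final_size)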
def forward(self, fnode, fmess, node_graph, mess_graph, scope):
    fnode = create_var(fnode)
    fmess = create_var(fmess)
    node_graph = create_var(node_graph)
    mess_graph = create_var(mess_graph)
    messages = create_var(torch.zeros(mess_graph.size(0), self.hidden_size))

    fnode = self.embedding(fnode)
    fmess = index_select_ND(fnode, 0, fmess)
    messages = self.GRU(messages, fmess, mess_graph)

    mess_nei = index_select_ND(messages, 0, node_graph)
    node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
    node_vecs = self.outputNN(node_vecs)

    max_len = max([x for _, x in scope])
    batch_vecs = []
    for st, le in scope:
        cur_vecs = node_vecs[st:st + le]
        cur_vecs = F.pad(cur_vecs, (0, 0, 0, max_len - le))
        batch_vecs.append(cur_vecs)

    tree_vecs = torch.stack(batch_vecs, dim=0)
    return tree_vecs, messages
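# The F.pad call above uses the 4-tuple convention (left, right, top, bottom
# over the last two dims): the feature dim is untouched and max_len - le zero
# rows are appended, so every tree stacks to the same (max_len, hidden) shape.
# A standalone check with illustrative sizes:
import torch
import torch.nn.functional as F

t = torch.randn(3, 8)                # a tree with 3 nodes, hidden size 8
padded = F.pad(t, (0, 0, 0, 5 - 3))  # pad up to max_len = 5 nodes
assert padded.shape == (5, 8)
assert padded[3:].abs().sum() == 0   # the appended rows are zeros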
def fuse_noise(self, tree_vecs, mol_vecs):
    tree_eps = create_var(torch.randn(tree_vecs.size(0), 1, self.rand_size_half))
    tree_eps = tree_eps.expand(-1, tree_vecs.size(1), -1)
    mol_eps = create_var(torch.randn(mol_vecs.size(0), 1, self.rand_size_half))
    mol_eps = mol_eps.expand(-1, mol_vecs.size(1), -1)

    tree_vecs = torch.cat([tree_vecs, tree_eps], dim=-1)
    mol_vecs = torch.cat([mol_vecs, mol_eps], dim=-1)
    return self.B_t(tree_vecs), self.B_g(mol_vecs)
def rsample(self, z_vecs, W_mean, W_var):
    z_mean = W_mean(z_vecs)
    z_log_var = -torch.abs(W_var(z_vecs))  # Following Mueller et al.
    kl_loss = -0.5 * torch.mean(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var))
    epsilon = create_var(torch.randn_like(z_mean))
    z_vecs = z_mean + torch.exp(z_log_var / 2) * epsilon
    return z_vecs, kl_loss
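# rsample is the standard reparameterization trick: z = mu + sigma * eps with
# sigma = exp(z_log_var / 2), plus the closed-form KL to a standard normal,
# KL(N(mu, sigma^2) || N(0, I)) = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2).
# Note the code averages this over all elements (torch.mean) rather than
# summing per sample. A quick numeric sanity check:
import torch

mu = torch.zeros(2)
log_var = torch.zeros(2)
kl = -0.5 * torch.mean(1.0 + log_var - mu * mu - torch.exp(log_var))
assert kl.item() == 0.0  # KL(N(0, I) || N(0, I)) = 0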
def fuse_pair(self, x_tree_vecs, x_mol_vecs, y_tree_vecs, y_mol_vecs, jtenc_scope, mpn_scope):
    diff_tree_vecs = y_tree_vecs.sum(dim=1) - x_tree_vecs.sum(dim=1)
    size = create_var(torch.Tensor([le for _, le in jtenc_scope]))
    diff_tree_vecs = diff_tree_vecs / size.unsqueeze(-1)

    diff_mol_vecs = y_mol_vecs.sum(dim=1) - x_mol_vecs.sum(dim=1)
    size = create_var(torch.Tensor([le for _, le in mpn_scope]))
    diff_mol_vecs = diff_mol_vecs / size.unsqueeze(-1)

    diff_tree_vecs, tree_kl = self.rsample(diff_tree_vecs, self.T_mean, self.T_var)
    diff_mol_vecs, mol_kl = self.rsample(diff_mol_vecs, self.G_mean, self.G_var)

    diff_tree_vecs = diff_tree_vecs.unsqueeze(1).expand(-1, x_tree_vecs.size(1), -1)
    diff_mol_vecs = diff_mol_vecs.unsqueeze(1).expand(-1, x_mol_vecs.size(1), -1)
    x_tree_vecs = torch.cat([x_tree_vecs, diff_tree_vecs], dim=-1)
    x_mol_vecs = torch.cat([x_mol_vecs, diff_mol_vecs], dim=-1)

    return self.B_t(x_tree_vecs), self.B_g(x_mol_vecs), tree_kl + mol_kl
def assm(self, mol_batch, jtmpn_holder, x_mol_vecs, y_tree_mess):
    jtmpn_holder, batch_idx = jtmpn_holder
    fatoms, fbonds, agraph, bgraph, scope = jtmpn_holder
    batch_idx = create_var(batch_idx)

    cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, y_tree_mess)

    x_mol_vecs = x_mol_vecs.sum(dim=1)  # average pooling?
    x_mol_vecs = x_mol_vecs.index_select(0, batch_idx)
    x_mol_vecs = self.A_assm(x_mol_vecs)  # bilinear
    scores = torch.bmm(
        x_mol_vecs.unsqueeze(1),
        cand_vecs.unsqueeze(-1)
    ).squeeze()

    cnt, tot, acc = 0, 0, 0
    all_loss = []
    for i, mol_tree in enumerate(mol_batch):
        comp_nodes = [
            node for node in mol_tree.nodes
            if len(node.cands) > 1 and not node.is_leaf
        ]
        cnt += len(comp_nodes)
        for node in comp_nodes:
            label = node.cands.index(node.label)
            ncand = len(node.cands)
            cur_score = scores.narrow(0, tot, ncand)
            tot += ncand

            if cur_score.data[label] >= cur_score.max().item():
                acc += 1

            label = create_var(torch.LongTensor([label]))
            all_loss.append(self.assm_loss(cur_score.view(1, -1), label))

    all_loss = sum(all_loss) / len(mol_batch)
    return all_loss, acc * 1.0 / cnt
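# The bmm above is a batched dot product: with x_mol_vecs of shape (N, d) and
# cand_vecs of shape (N, d), bmm((N, 1, d), (N, d, 1)) gives (N, 1, 1), which
# squeezes to one score per candidate attachment. An illustrative shape check:
import torch

x = torch.randn(4, 16)
cand = torch.randn(4, 16)
scores = torch.bmm(x.unsqueeze(1), cand.unsqueeze(-1)).squeeze()
assert torch.allclose(scores, (x * cand).sum(dim=1))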
def gradient_penalty(self, real_vecs, fake_vecs):
    eps = create_var(torch.rand(real_vecs.size(0), 1))
    inter_data = eps * real_vecs + (1 - eps) * fake_vecs
    inter_data = autograd.Variable(inter_data, requires_grad=True)
    inter_score = self.netD(inter_data).squeeze(-1)

    inter_grad = autograd.grad(
        inter_score,
        inter_data,
        grad_outputs=torch.ones(inter_score.size()).cuda(),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    inter_norm = inter_grad.norm(2, dim=1)
    inter_gp = ((inter_norm - 1) ** 2).mean() * self.beta
    # inter_norm = (inter_grad ** 2).sum(dim=1)
    # inter_gp = torch.max(inter_norm - 1, self.zero).mean() * self.beta

    return inter_gp, inter_norm.mean().item()
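# gradient_penalty matches the WGAN-GP penalty of Gulrajani et al. (2017):
# score gradients at random interpolates between real and fake inputs are
# pushed toward unit norm, scaled by self.beta. A hedged sketch of a critic
# step that might consume it -- `optD`, `real_vecs`, and `fake_vecs` are
# assumptions for illustration, not repo API:
def train_D_step(model, optD, real_vecs, fake_vecs):
    optD.zero_grad()
    real_score = model.netD(real_vecs).mean()
    fake_score = model.netD(fake_vecs.detach()).mean()
    gp, grad_norm = model.gradient_penalty(real_vecs, fake_vecs.detach())
    # WGAN critic objective (assuming netD scores real data higher).
    loss = fake_score - real_score + gp
    loss.backward()
    optD.step()
    return loss.item(), grad_norm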
def forward(self, h, x, mess_graph):
    mask = torch.ones(h.size(0), 1)
    mask[0] = 0  # first vector is padding
    mask = create_var(mask)

    for it in range(self.depth):
        h_nei = index_select_ND(h, 0, mess_graph)
        sum_h = h_nei.sum(dim=1)
        z_input = torch.cat([x, sum_h], dim=1)
        z = torch.sigmoid(self.W_z(z_input))

        r_1 = self.W_r(x).view(-1, 1, self.hidden_size)
        r_2 = self.U_r(h_nei)
        r = torch.sigmoid(r_1 + r_2)

        gated_h = r * h_nei
        sum_gated_h = gated_h.sum(dim=1)
        h_input = torch.cat([x, sum_gated_h], dim=1)
        pre_h = torch.tanh(self.W_h(h_input))

        h = (1.0 - z) * sum_h + z * pre_h
        h = h * mask

    return h
def forward(self, holder, depth):
    fnode = create_var(holder[0])
    fmess = create_var(holder[1])
    node_graph = create_var(holder[2])
    mess_graph = create_var(holder[3])
    scope = holder[4]

    fnode = self.embedding(fnode)
    x = index_select_ND(fnode, 0, fmess)
    h = create_var(torch.zeros(mess_graph.size(0), self.hidden_size))

    mask = torch.ones(h.size(0), 1)
    mask[0] = 0  # first vector is padding
    mask = create_var(mask)

    for it in range(depth):
        h_nei = index_select_ND(h, 0, mess_graph)
        h = GRU(x, h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
        h = h * mask

    mess_nei = index_select_ND(h, 0, node_graph)
    node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
    root_vecs = [node_vecs[st] for st, le in scope]
    return torch.stack(root_vecs, dim=0)
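# This forward (and the decoders below) call a free function
# GRU(x, h_nei, W_z, W_r, U_r, W_h). It is assumed to apply the same gated
# update as the module forward above, once per call and without the padding
# mask. A minimal sketch under that assumption:
import torch

def GRU_sketch(x, h_nei, W_z, W_r, U_r, W_h):
    hidden_size = x.size(-1)
    sum_h = h_nei.sum(dim=1)
    z = torch.sigmoid(W_z(torch.cat([x, sum_h], dim=1)))             # update gate
    r = torch.sigmoid(W_r(x).view(-1, 1, hidden_size) + U_r(h_nei))  # per-neighbor reset gate
    sum_gated_h = (r * h_nei).sum(dim=1)
    pre_h = torch.tanh(W_h(torch.cat([x, sum_gated_h], dim=1)))
    return (1.0 - z) * sum_h + z * pre_h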
def forward(self, mol_batch, x_tree_vecs, x_mol_vecs):
    pred_hiddens, pred_contexts, pred_targets = [], [], []
    stop_hiddens, stop_contexts, stop_targets = [], [], []
    traces = []
    for mol_tree in mol_batch:
        s = []
        dfs(s, mol_tree.nodes[0], -1)
        traces.append(s)
        for node in mol_tree.nodes:
            node.neighbors = []

    # Predict Root
    batch_size = len(mol_batch)
    pred_hiddens.append(create_var(torch.zeros(len(mol_batch), self.hidden_size)))
    pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch])
    pred_contexts.append(create_var(torch.LongTensor(range(batch_size))))

    max_iter = max([len(tr) for tr in traces])
    padding = create_var(torch.zeros(self.hidden_size), False)
    h = {}

    for t in range(max_iter):
        prop_list = []
        batch_list = []
        for i, plist in enumerate(traces):
            if t < len(plist):
                prop_list.append(plist[t])
                batch_list.append(i)

        cur_x = []
        cur_h_nei, cur_o_nei = [], []

        for node_x, real_y, _ in prop_list:
            # Neighbors for message passing (target not included)
            cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors if node_y.idx != real_y.idx]
            pad_len = MAX_NB - len(cur_nei)
            cur_h_nei.extend(cur_nei)
            cur_h_nei.extend([padding] * pad_len)

            # Neighbors for stop prediction (all neighbors)
            cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
            pad_len = MAX_NB - len(cur_nei)
            cur_o_nei.extend(cur_nei)
            cur_o_nei.extend([padding] * pad_len)

            # Current clique embedding
            cur_x.append(node_x.wid)

        # Clique embedding
        cur_x = create_var(torch.LongTensor(cur_x))
        cur_x = self.embedding(cur_x)

        # Message passing
        cur_h_nei = torch.stack(cur_h_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
        new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)

        # Node Aggregate
        cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
        cur_o = cur_o_nei.sum(dim=1)

        # Gather targets
        pred_target, pred_list = [], []
        stop_target = []
        for i, m in enumerate(prop_list):
            node_x, node_y, direction = m
            x, y = node_x.idx, node_y.idx
            h[(x, y)] = new_h[i]
            node_y.neighbors.append(node_x)
            if direction == 1:
                pred_target.append(node_y.wid)
                pred_list.append(i)
            stop_target.append(direction)

        # Hidden states for stop prediction
        cur_batch = create_var(torch.LongTensor(batch_list))
        stop_hidden = torch.cat([cur_x, cur_o], dim=1)
        stop_hiddens.append(stop_hidden)
        stop_contexts.append(cur_batch)
        stop_targets.extend(stop_target)

        # Hidden states for clique prediction
        if len(pred_list) > 0:
            batch_list = [batch_list[i] for i in pred_list]
            cur_batch = create_var(torch.LongTensor(batch_list))
            pred_contexts.append(cur_batch)

            cur_pred = create_var(torch.LongTensor(pred_list))
            pred_hiddens.append(new_h.index_select(0, cur_pred))
            pred_targets.extend(pred_target)

    # Last stop at root
    cur_x, cur_o_nei = [], []
    for mol_tree in mol_batch:
        node_x = mol_tree.nodes[0]
        cur_x.append(node_x.wid)
        cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        pad_len = MAX_NB - len(cur_nei)
        cur_o_nei.extend(cur_nei)
        cur_o_nei.extend([padding] * pad_len)

    cur_x = create_var(torch.LongTensor(cur_x))
    cur_x = self.embedding(cur_x)
    cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
    cur_o = cur_o_nei.sum(dim=1)

    stop_hidden = torch.cat([cur_x, cur_o], dim=1)
    stop_hiddens.append(stop_hidden)
    stop_contexts.append(create_var(torch.LongTensor(range(batch_size))))
    stop_targets.extend([0] * len(mol_batch))

    # Predict next clique
    pred_contexts = torch.cat(pred_contexts, dim=0)
    pred_hiddens = torch.cat(pred_hiddens, dim=0)
    pred_scores = self.attention(pred_hiddens, pred_contexts, x_tree_vecs, x_mol_vecs, 'word')
    pred_targets = create_var(torch.LongTensor(pred_targets))

    pred_loss = self.pred_loss(pred_scores, pred_targets) / len(mol_batch)
    _, preds = torch.max(pred_scores, dim=1)
    pred_acc = torch.eq(preds, pred_targets).float()
    pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

    # Predict stop
    stop_contexts = torch.cat(stop_contexts, dim=0)
    stop_hiddens = torch.cat(stop_hiddens, dim=0)
    stop_hiddens = F.relu(self.U_i(stop_hiddens))
    stop_scores = self.attention(stop_hiddens, stop_contexts, x_tree_vecs, x_mol_vecs, 'stop')
    stop_scores = stop_scores.squeeze(-1)
    stop_targets = create_var(torch.Tensor(stop_targets))

    stop_loss = self.stop_loss(stop_scores, stop_targets) / len(mol_batch)
    stops = torch.ge(stop_scores, 0).float()
    stop_acc = torch.eq(stops, stop_targets).float()
    stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

    return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()
def soft_decode(self, x_tree_vecs, x_mol_vecs, gumbel, slope, temp):
    assert x_tree_vecs.size(0) == 1

    soft_embedding = lambda x: x.matmul(self.embedding.weight)
    if gumbel:
        sample_softmax = lambda x: F.gumbel_softmax(x, tau=temp)
    else:
        sample_softmax = lambda x: F.softmax(x / temp, dim=1)

    stack = []
    init_hiddens = create_var(torch.zeros(1, self.hidden_size))
    zero_pad = create_var(torch.zeros(1, 1, self.hidden_size))
    contexts = create_var(torch.LongTensor(1).zero_())

    # Root Prediction
    root_score = self.attention(init_hiddens, contexts, x_tree_vecs, x_mol_vecs, 'word')
    root_prob = sample_softmax(root_score)

    root = MolTreeNode("")
    root.embedding = soft_embedding(root_prob)
    root.prob = root_prob
    root.idx = 0
    stack.append(root)

    all_nodes = [root]
    all_hiddens = []
    h = {}
    for step in range(MAX_SOFT_DECODE_LEN):
        node_x = stack[-1]
        cur_h_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        if len(cur_h_nei) > 0:
            cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
        else:
            cur_h_nei = zero_pad

        # Predict stop
        cur_x = node_x.embedding
        cur_h = cur_h_nei.sum(dim=1)
        stop_hiddens = torch.cat([cur_x, cur_h], dim=1)
        stop_hiddens = F.relu(self.U_i(stop_hiddens))
        stop_score = self.attention(stop_hiddens, contexts, x_tree_vecs, x_mol_vecs, 'stop')
        all_hiddens.append(stop_hiddens)

        # Straight-through estimator: hard 0/1 decision in the forward pass,
        # gradients through the hardtanh surrogate in the backward pass.
        forward = 0 if stop_score.item() < 0 else 1
        stop_prob = F.hardtanh(slope * stop_score + 0.5, min_val=0, max_val=1).unsqueeze(1)
        stop_val_ste = forward + stop_prob - stop_prob.detach()

        if forward == 1:  # Forward: Predict next clique
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            pred_score = self.attention(new_h, contexts, x_tree_vecs, x_mol_vecs, 'word')
            pred_prob = sample_softmax(pred_score)

            node_y = MolTreeNode("")
            node_y.embedding = soft_embedding(pred_prob)
            node_y.prob = pred_prob
            node_y.idx = len(all_nodes)

            node_y.neighbors.append(node_x)
            h[(node_x.idx, node_y.idx)] = new_h[0] * stop_val_ste
            stack.append(node_y)
            all_nodes.append(node_y)
        else:  # Backtrack
            if len(stack) == 1:  # At root, terminate
                return torch.cat([cur_x, cur_h], dim=1), all_nodes

            node_fa = stack[-2]
            cur_h_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors if node_y.idx != node_fa.idx]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
            else:
                cur_h_nei = zero_pad

            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            h[(node_x.idx, node_fa.idx)] = new_h[0] * (1.0 - stop_val_ste)
            node_fa.neighbors.append(node_x)
            stack.pop()

    # Failure mode: decoding unfinished
    cur_h_nei = [h[(node_y.idx, root.idx)] for node_y in root.neighbors]
    if len(cur_h_nei) > 0:
        cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
    else:
        cur_h_nei = zero_pad
    cur_h = cur_h_nei.sum(dim=1)

    stop_hiddens = torch.cat([root.embedding, cur_h], dim=1)
    stop_hiddens = F.relu(self.U_i(stop_hiddens))
    all_hiddens.append(stop_hiddens)

    return torch.cat([root.embedding, cur_h], dim=1), all_nodes
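# The stop_val_ste line above is a straight-through estimator: the forward
# pass uses the hard 0/1 decision, while gradients flow through the soft
# surrogate. A tiny self-contained illustration of the trick:
import torch

score = torch.tensor([0.3], requires_grad=True)
hard = 1.0 if score.item() >= 0 else 0.0
soft = torch.clamp(score + 0.5, 0.0, 1.0)  # differentiable surrogate
ste = hard + soft - soft.detach()
assert ste.item() == hard                  # forward pass sees the hard value
ste.backward()
assert score.grad.item() == 1.0            # backward uses d(soft)/d(score)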
def decode(self, x_tree_vecs, x_mol_vecs):
    assert x_tree_vecs.size(0) == 1

    stack = []
    init_hiddens = create_var(torch.zeros(1, self.hidden_size))
    zero_pad = create_var(torch.zeros(1, 1, self.hidden_size))
    contexts = create_var(torch.LongTensor(1).zero_())

    # Root Prediction
    root_score = self.attention(init_hiddens, contexts, x_tree_vecs, x_mol_vecs, 'word')
    _, root_wid = torch.max(root_score, dim=1)
    root_wid = root_wid.item()

    root = MolTreeNode(self.vocab.get_smiles(root_wid))
    root.wid = root_wid
    root.idx = 0
    stack.append((root, self.vocab.get_slots(root.wid)))

    all_nodes = [root]
    h = {}
    for step in range(MAX_DECODE_LEN):
        node_x, fa_slot = stack[-1]
        cur_h_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        if len(cur_h_nei) > 0:
            cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
        else:
            cur_h_nei = zero_pad

        cur_x = create_var(torch.LongTensor([node_x.wid]))
        cur_x = self.embedding(cur_x)

        # Predict stop
        cur_h = cur_h_nei.sum(dim=1)
        stop_hiddens = torch.cat([cur_x, cur_h], dim=1)
        stop_hiddens = F.relu(self.U_i(stop_hiddens))
        stop_score = self.attention(stop_hiddens, contexts, x_tree_vecs, x_mol_vecs, 'stop')

        backtrack = (stop_score.item() < 0)
        if not backtrack:  # Forward: Predict next clique
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            pred_score = self.attention(new_h, contexts, x_tree_vecs, x_mol_vecs, 'word')

            _, sort_wid = torch.sort(pred_score, dim=1, descending=True)
            sort_wid = sort_wid.data.squeeze()

            next_wid = None
            for wid in sort_wid[:5]:
                slots = self.vocab.get_slots(wid)
                node_y = MolTreeNode(self.vocab.get_smiles(wid))
                if have_slots(fa_slot, slots) and can_assemble(node_x, node_y):
                    next_wid = wid
                    next_slots = slots
                    break

            if next_wid is None:
                backtrack = True  # No more children can be added
            else:
                node_y = MolTreeNode(self.vocab.get_smiles(next_wid))
                node_y.wid = next_wid
                node_y.idx = len(all_nodes)
                node_y.neighbors.append(node_x)
                h[(node_x.idx, node_y.idx)] = new_h[0]
                stack.append((node_y, next_slots))
                all_nodes.append(node_y)

        if backtrack:  # Backtrack, use if instead of else
            if len(stack) == 1:
                break  # At root, terminate

            node_fa, _ = stack[-2]
            cur_h_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors if node_y.idx != node_fa.idx]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
            else:
                cur_h_nei = zero_pad

            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            h[(node_x.idx, node_fa.idx)] = new_h[0]
            node_fa.neighbors.append(node_x)
            stack.pop()

    return root, all_nodes