def forward(self, root_batch):
    # orders: list of lists; each sub-list is the propagation schedule for one root
    orders = []
    for root in root_batch:
        # order: list of lists, split in two halves:
        # first a bottom-up pass, where each sub-list holds the (node, parent) pairs of one level,
        # then a top-down pass, where each sub-list holds the (node, child) pairs of one level
        order = get_prop_order(root)
        orders.append(order)

    h = {}
    max_depth = max([len(order) for order in orders])
    padding = create_var(torch.zeros(self.hidden_size), False)

    for t in range(max_depth):
        prop_list = []
        for order in orders:
            if len(order) > t:  # make sure this tree actually has a level t
                prop_list.extend(order[t])  # add level t of this tree to prop_list

        cur_x = []
        cur_h_nei = []
        for node_x, node_y in prop_list:
            x, y = node_x.idx, node_y.idx  # node indices
            cur_x.append(node_x.wid)  # node label (vocabulary) indices

            h_nei = []
            for node_z in node_x.neighbors:
                z = node_z.idx
                if z == y:
                    continue
                # h_nei: messages from all neighbors of x except y
                h_nei.append(h[(z, x)])
            # if a node has fewer than MAX_NB neighbors, pad with zero vectors
            pad_len = MAX_NB - len(h_nei)
            h_nei.extend([padding] * pad_len)
            cur_h_nei.extend(h_nei)

        cur_x = create_var(torch.LongTensor(cur_x))
        # from here on the labels become embeddings; cur_x.size() = (len(prop_list), hidden_size)
        cur_x = self.embedding(cur_x)
        # cur_h_nei.size() = (len(prop_list), MAX_NB, hidden_size)
        cur_h_nei = torch.cat(cur_h_nei, dim=0).view(-1, MAX_NB, self.hidden_size)

        new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
        for i, m in enumerate(prop_list):
            x, y = m[0].idx, m[1].idx
            h[(x, y)] = new_h[i]

    # node aggregation
    root_vecs = node_aggregate(root_batch, h, self.embedding, self.W)
    return h, root_vecs
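# For reference, a minimal sketch of the tree-GRU message update that the
# variants in this section call as GRU(x, h_nei, W_z, W_r, U_r, W_h).
# Assumptions (not shown in this section): x is (batch, hidden_size), h_nei is
# (batch, MAX_NB, hidden_size), and W_z, W_r, U_r, W_h are nn.Linear layers
# owned by the calling module; one variant below additionally threads a device
# argument through. The imports here are shared by the sketches that follow.
import torch
import torch.nn as nn
import torch.nn.functional as F

def GRU(x, h_nei, W_z, W_r, U_r, W_h):
    hidden_size = x.size(-1)
    sum_h = h_nei.sum(dim=1)                                  # aggregate incoming messages
    z = torch.sigmoid(W_z(torch.cat([x, sum_h], dim=1)))      # update gate
    r_1 = W_r(x).view(-1, 1, hidden_size)
    r_2 = U_r(h_nei)
    r = torch.sigmoid(r_1 + r_2)                              # one reset gate per neighbor
    sum_gated_h = (r * h_nei).sum(dim=1)
    pre_h = torch.tanh(W_h(torch.cat([x, sum_gated_h], dim=1)))
    return (1.0 - z) * sum_h + z * pre_h                      # new outgoing message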
def forward(self, root_batch):
    orders = []
    for root in root_batch:
        order = get_prop_order(root)
        orders.append(order)

    h = {}
    max_depth = max([len(x) for x in orders])
    padding = create_var(torch.zeros(self.hidden_size), False)

    for t in range(max_depth):
        prop_list = []
        for order in orders:
            if t < len(order):
                prop_list.extend(order[t])

        cur_x = []
        cur_h_nei = []
        for node_x, node_y in prop_list:
            x, y = node_x.idx, node_y.idx
            cur_x.append(node_x.wid)

            h_nei = []
            for node_z in node_x.neighbors:
                z = node_z.idx
                if z == y:
                    continue
                h_nei.append(h[(z, x)])
            # pad the neighbor messages up to MAX_NB
            # (a node is assumed to have at most MAX_NB neighbors)
            pad_len = MAX_NB - len(h_nei)
            h_nei.extend([padding] * pad_len)
            cur_h_nei.extend(h_nei)

        cur_x = create_var(torch.LongTensor(cur_x))
        cur_x = self.embedding(cur_x)
        cur_h_nei = torch.cat(cur_h_nei, dim=0).view(-1, MAX_NB, self.hidden_size)

        new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
        for i, m in enumerate(prop_list):
            x, y = m[0].idx, m[1].idx
            h[(x, y)] = new_h[i]

    root_vecs = node_aggregate(root_batch, h, self.embedding, self.W)
    return h, root_vecs
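# A sketch of the get_prop_order helper the encoder variants above assume. It
# runs a BFS from the root and returns, per depth level, the (sender, receiver)
# edge pairs: first leaf-to-root (bottom-up), then root-to-leaf (top-down), so
# every message h[(x, y)] is computed before it is consumed. Assumption: nodes
# carry .idx and .neighbors, as in the snippets here.
from collections import deque

def get_prop_order(root):
    queue = deque([root])
    visited = {root.idx}
    root.depth = 0
    order1, order2 = [], []           # top-down / bottom-up edge lists per level
    while queue:
        x = queue.popleft()
        for y in x.neighbors:
            if y.idx not in visited:
                queue.append(y)
                visited.add(y.idx)
                y.depth = x.depth + 1
                if y.depth > len(order1):
                    order1.append([])
                    order2.append([])
                order1[y.depth - 1].append((x, y))  # parent -> child
                order2[y.depth - 1].append((y, x))  # child -> parent
    # deepest child->parent edges first, then parent->child from the root down
    return order2[::-1] + order1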
def forward(self, root_batch):
    orders = []
    for root in root_batch:
        order = get_prop_order(root)
        orders.append(order)

    h = {}
    max_depth = max([len(x) for x in orders])
    padding = torch.zeros(self.hidden_size).to(self.device)

    for t in range(max_depth):
        prop_list = []
        for order in orders:
            if t < len(order):
                prop_list.extend(order[t])

        cur_x = []
        cur_h_nei = []
        for node_x, node_y in prop_list:
            x, y = node_x.idx, node_y.idx
            cur_x.append(node_x.wid)

            h_nei = []
            for node_z in node_x.neighbors:
                z = node_z.idx
                if z == y:
                    continue
                h_nei.append(h[(z, x)])
            pad_len = MAX_NB - len(h_nei)
            h_nei.extend([padding] * pad_len)
            cur_h_nei.extend(h_nei)

        cur_x = torch.tensor(cur_x, dtype=torch.long).to(self.device)
        cur_x = self.embedding(cur_x)
        cur_h_nei = torch.cat(cur_h_nei, dim=0).view(-1, MAX_NB, self.hidden_size)

        new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h, self.device)
        for i, m in enumerate(prop_list):
            x, y = m[0].idx, m[1].idx
            h[(x, y)] = new_h[i]

    root_vecs = node_aggregate(root_batch, h, self.embedding, self.W, self.device)
    return h, root_vecs
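# Sketches of two helpers the encoders above rely on. create_var is assumed to
# wrap a tensor and move it to the right device (in the original Python-2-era
# code it wrapped a Variable); node_aggregate pools each root's incoming
# messages and projects them through W. Both are hedged reconstructions, not
# verbatim copies of this repository's code.
def create_var(tensor, requires_grad=False):
    if torch.cuda.is_available():
        tensor = tensor.cuda()
    return tensor.requires_grad_(requires_grad)

def node_aggregate(nodes, h, embedding, W):
    x_idx, h_nei = [], []
    hidden_size = embedding.embedding_dim
    padding = create_var(torch.zeros(hidden_size), False)
    for node_x in nodes:
        x_idx.append(node_x.wid)
        nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        nei.extend([padding] * (MAX_NB - len(nei)))
        h_nei.extend(nei)
    h_nei = torch.cat(h_nei, dim=0).view(-1, MAX_NB, hidden_size)
    x_vec = embedding(create_var(torch.LongTensor(x_idx)))
    # node vector = ReLU(W([label embedding ; sum of incoming messages]))
    return nn.ReLU()(W(torch.cat([x_vec, h_nei.sum(dim=1)], dim=1)))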
def forward(self, holder, depth):
    fnode = create_var(holder[0])
    fmess = create_var(holder[1])
    node_graph = create_var(holder[2])
    mess_graph = create_var(holder[3])
    scope = holder[4]

    fnode = self.embedding(fnode)
    x = index_select_ND(fnode, 0, fmess)
    h = create_var(torch.zeros(mess_graph.size(0), self.hidden_size))

    mask = torch.ones(h.size(0), 1)
    mask[0] = 0  #first vector is padding
    mask = create_var(mask)

    for it in xrange(depth):
        h_nei = index_select_ND(h, 0, mess_graph)
        h = GRU(x, h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
        h = h * mask

    mess_nei = index_select_ND(h, 0, node_graph)
    node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
    root_vecs = [node_vecs[st] for st, le in scope]
    return torch.stack(root_vecs, dim=0)
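# A sketch of index_select_ND, which the batched variant above uses to gather
# message vectors: it index-selects along dim 0 with a possibly
# multi-dimensional index and reshapes so the trailing feature dimensions are
# preserved (e.g. a (num_nodes, MAX_NB) index over (num_mess, hidden) messages
# yields (num_nodes, MAX_NB, hidden)).
def index_select_ND(source, dim, index):
    index_size = index.size()
    suffix_dim = source.size()[1:]
    final_size = index_size + suffix_dim
    target = source.index_select(dim, index.view(-1))
    return target.view(final_size)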
def forward(self, mol_batch, x_tree_vecs):
    pred_hiddens, pred_contexts, pred_targets = [], [], []
    stop_hiddens, stop_contexts, stop_targets = [], [], []

    traces = []
    for mol_tree in mol_batch:
        s = []
        dfs(s, mol_tree.nodes[0], -1)
        traces.append(s)
        for node in mol_tree.nodes:
            node.neighbors = []

    #Predict Root
    batch_size = len(mol_batch)
    pred_hiddens.append(create_var(torch.zeros(len(mol_batch), self.hidden_size)))
    pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch])
    pred_contexts.append(create_var(torch.LongTensor(range(batch_size))))

    max_iter = max([len(tr) for tr in traces])
    padding = create_var(torch.zeros(self.hidden_size), False)
    h = {}

    for t in xrange(max_iter):
        prop_list = []
        batch_list = []
        for i, plist in enumerate(traces):
            if t < len(plist):
                prop_list.append(plist[t])
                batch_list.append(i)

        cur_x = []
        cur_h_nei, cur_o_nei = [], []

        for node_x, real_y, _ in prop_list:
            #Neighbors for message passing (target not included)
            cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors if node_y.idx != real_y.idx]
            pad_len = MAX_NB - len(cur_nei)
            cur_h_nei.extend(cur_nei)
            cur_h_nei.extend([padding] * pad_len)

            #Neighbors for stop prediction (all neighbors)
            cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
            pad_len = MAX_NB - len(cur_nei)
            cur_o_nei.extend(cur_nei)
            cur_o_nei.extend([padding] * pad_len)

            #Current clique embedding
            cur_x.append(node_x.wid)

        #Clique embedding
        cur_x = create_var(torch.LongTensor(cur_x))
        cur_x = self.embedding(cur_x)

        #Message passing
        cur_h_nei = torch.stack(cur_h_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
        new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)

        #Node Aggregate
        cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
        cur_o = cur_o_nei.sum(dim=1)

        #Gather targets
        pred_target, pred_list = [], []
        stop_target = []
        for i, m in enumerate(prop_list):
            node_x, node_y, direction = m
            x, y = node_x.idx, node_y.idx
            h[(x, y)] = new_h[i]
            node_y.neighbors.append(node_x)
            if direction == 1:
                pred_target.append(node_y.wid)
                pred_list.append(i)
            stop_target.append(direction)

        #Hidden states for stop prediction
        cur_batch = create_var(torch.LongTensor(batch_list))
        stop_hidden = torch.cat([cur_x, cur_o], dim=1)
        stop_hiddens.append(stop_hidden)
        stop_contexts.append(cur_batch)
        stop_targets.extend(stop_target)

        #Hidden states for clique prediction
        if len(pred_list) > 0:
            batch_list = [batch_list[i] for i in pred_list]
            cur_batch = create_var(torch.LongTensor(batch_list))
            pred_contexts.append(cur_batch)
            cur_pred = create_var(torch.LongTensor(pred_list))
            pred_hiddens.append(new_h.index_select(0, cur_pred))
            pred_targets.extend(pred_target)

    #Last stop at root
    cur_x, cur_o_nei = [], []
    for mol_tree in mol_batch:
        node_x = mol_tree.nodes[0]
        cur_x.append(node_x.wid)
        cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        pad_len = MAX_NB - len(cur_nei)
        cur_o_nei.extend(cur_nei)
        cur_o_nei.extend([padding] * pad_len)

    cur_x = create_var(torch.LongTensor(cur_x))
    cur_x = self.embedding(cur_x)
    cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
    cur_o = cur_o_nei.sum(dim=1)

    stop_hidden = torch.cat([cur_x, cur_o], dim=1)
    stop_hiddens.append(stop_hidden)
    stop_contexts.append(create_var(torch.LongTensor(range(batch_size))))
    stop_targets.extend([0] * len(mol_batch))

    #Predict next clique
    pred_contexts = torch.cat(pred_contexts, dim=0)
    pred_hiddens = torch.cat(pred_hiddens, dim=0)
    pred_scores = self.aggregate(pred_hiddens, pred_contexts, x_tree_vecs, 'word')
    pred_targets = create_var(torch.LongTensor(pred_targets))
    pred_loss = self.pred_loss(pred_scores, pred_targets) / len(mol_batch)
    _, preds = torch.max(pred_scores, dim=1)
    pred_acc = torch.eq(preds, pred_targets).float()
    pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

    #Predict stop
    stop_contexts = torch.cat(stop_contexts, dim=0)
    stop_hiddens = torch.cat(stop_hiddens, dim=0)
    stop_hiddens = F.relu(self.U_i(stop_hiddens))
    stop_scores = self.aggregate(stop_hiddens, stop_contexts, x_tree_vecs, 'stop')
    stop_scores = stop_scores.squeeze(-1)
    stop_targets = create_var(torch.Tensor(stop_targets))
    stop_loss = self.stop_loss(stop_scores, stop_targets) / len(mol_batch)
    stops = torch.ge(stop_scores, 0).float()
    stop_acc = torch.eq(stops, stop_targets).float()
    stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

    return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()
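# A sketch of the dfs helper used to build the teacher-forcing traces above.
# It records every edge twice: (parent, child, 1) when the decoder should
# expand a new child, and (child, parent, 0) when it should backtrack. Here
# the third argument is the parent's index (-1 for the root); the super_root
# variant further below passes the parent node itself and would compare
# against fa.idx instead.
def dfs(stack, x, fa_idx):
    for y in x.neighbors:
        if y.idx == fa_idx:
            continue
        stack.append((x, y, 1))   # forward edge: predict child y
        dfs(stack, y, x.idx)
        stack.append((y, x, 0))   # backward edge: stop / backtrack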
def decode(self, x_tree_vecs, prob_decode):
    assert x_tree_vecs.size(0) == 1

    stack = []
    init_hiddens = create_var(torch.zeros(1, self.hidden_size))
    zero_pad = create_var(torch.zeros(1, 1, self.hidden_size))
    contexts = create_var(torch.LongTensor(1).zero_())

    #Root Prediction
    root_score = self.aggregate(init_hiddens, contexts, x_tree_vecs, 'word')
    _, root_wid = torch.max(root_score, dim=1)
    root_wid = root_wid.item()

    root = MolTreeNode(self.vocab.get_smiles(root_wid))
    root.wid = root_wid
    root.idx = 0
    stack.append((root, self.vocab.get_slots(root.wid)))

    all_nodes = [root]
    h = {}
    for step in xrange(MAX_DECODE_LEN):
        node_x, fa_slot = stack[-1]
        cur_h_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        if len(cur_h_nei) > 0:
            cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
        else:
            cur_h_nei = zero_pad

        cur_x = create_var(torch.LongTensor([node_x.wid]))
        cur_x = self.embedding(cur_x)

        #Predict stop
        cur_h = cur_h_nei.sum(dim=1)
        stop_hiddens = torch.cat([cur_x, cur_h], dim=1)
        stop_hiddens = F.relu(self.U_i(stop_hiddens))
        stop_score = self.aggregate(stop_hiddens, contexts, x_tree_vecs, 'stop')

        if prob_decode:
            backtrack = (torch.bernoulli(torch.sigmoid(stop_score)).item() == 0)
        else:
            backtrack = (stop_score.item() < 0)

        if not backtrack:  #Forward: Predict next clique
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            pred_score = self.aggregate(new_h, contexts, x_tree_vecs, 'word')

            if prob_decode:
                sort_wid = torch.multinomial(F.softmax(pred_score, dim=1).squeeze(), 5)
            else:
                _, sort_wid = torch.sort(pred_score, dim=1, descending=True)
                sort_wid = sort_wid.data.squeeze()

            next_wid = None
            for wid in sort_wid[:5]:
                slots = self.vocab.get_slots(wid)
                node_y = MolTreeNode(self.vocab.get_smiles(wid))
                if have_slots(fa_slot, slots) and can_assemble(node_x, node_y):
                    next_wid = wid
                    next_slots = slots
                    break

            if next_wid is None:
                backtrack = True  #No more children can be added
            else:
                node_y = MolTreeNode(self.vocab.get_smiles(next_wid))
                node_y.wid = next_wid
                node_y.idx = len(all_nodes)
                node_y.neighbors.append(node_x)
                h[(node_x.idx, node_y.idx)] = new_h[0]
                stack.append((node_y, next_slots))
                all_nodes.append(node_y)

        if backtrack:  #Backtrack, use if instead of else
            if len(stack) == 1:
                break  #At root, terminate

            node_fa, _ = stack[-2]
            cur_h_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors if node_y.idx != node_fa.idx]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
            else:
                cur_h_nei = zero_pad

            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            h[(node_x.idx, node_fa.idx)] = new_h[0]
            node_fa.neighbors.append(node_x)
            stack.pop()

    return root, all_nodes
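# A sketch of the aggregate helper the 'word'/'stop' heads above assume: it
# looks up one tree context vector per hidden state, concatenates, and applies
# a two-layer head. The layer pairings (W/W_o for 'word', U/U_o for 'stop')
# are assumptions inferred from the attribute names used elsewhere in this
# section, not a verbatim copy.
def aggregate(self, hiddens, contexts, x_tree_vecs, mode):
    if mode == 'word':
        V, V_o = self.W, self.W_o
    elif mode == 'stop':
        V, V_o = self.U, self.U_o
    else:
        raise ValueError('aggregate mode is wrong')
    tree_contexts = x_tree_vecs.index_select(0, contexts)
    input_vec = torch.cat([hiddens, tree_contexts], dim=-1)
    return V_o(F.relu(V(input_vec)))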
def decode(self, mol_vec, prob_decode):
    stack, trace = [], []
    init_hidden = create_var(torch.zeros(1, self.hidden_size))
    zero_pad = create_var(torch.zeros(1, 1, self.hidden_size))

    #Root Prediction
    root_hidden = torch.cat([init_hidden, mol_vec], dim=1)
    root_hidden = nn.ReLU()(self.W(root_hidden))
    root_score = self.W_o(root_hidden)
    _, root_wid = torch.max(root_score, dim=1)
    root_wid = root_wid.item()

    root = MolTreeNode(self.vocab.get_smiles(root_wid))
    root.wid = root_wid
    root.idx = 0
    stack.append((root, self.vocab.get_slots(root.wid)))

    all_nodes = [root]
    h = {}
    for step in xrange(MAX_DECODE_LEN):
        node_x, fa_slot = stack[-1]
        cur_h_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        if len(cur_h_nei) > 0:
            cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
        else:
            cur_h_nei = zero_pad

        cur_x = create_var(torch.LongTensor([node_x.wid]))
        cur_x = self.embedding(cur_x)

        #Predict stop
        cur_h = cur_h_nei.sum(dim=1)
        stop_hidden = torch.cat([cur_x, cur_h, mol_vec], dim=1)
        stop_hidden = nn.ReLU()(self.U(stop_hidden))
        stop_score = nn.Sigmoid()(self.U_s(stop_hidden) * 20).squeeze()

        if prob_decode:
            backtrack = (torch.bernoulli(1.0 - stop_score.data)[0] == 1)
        else:
            backtrack = (stop_score.item() < 0.5)

        if not backtrack:  #Forward: Predict next clique
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            pred_hidden = torch.cat([new_h, mol_vec], dim=1)
            pred_hidden = nn.ReLU()(self.W(pred_hidden))
            pred_score = nn.Softmax(dim=1)(self.W_o(pred_hidden) * 20)
            if prob_decode:
                sort_wid = torch.multinomial(pred_score.data.squeeze(), 5)
            else:
                _, sort_wid = torch.sort(pred_score, dim=1, descending=True)
                sort_wid = sort_wid.data.squeeze()

            next_wid = None
            for wid in sort_wid[:5]:
                slots = self.vocab.get_slots(wid)
                node_y = MolTreeNode(self.vocab.get_smiles(wid))
                if have_slots(fa_slot, slots) and can_assemble(node_x, node_y):
                    next_wid = wid
                    next_slots = slots
                    break

            if next_wid is None:
                backtrack = True  #No more children can be added
            else:
                node_y = MolTreeNode(self.vocab.get_smiles(next_wid))
                node_y.wid = next_wid
                node_y.idx = step + 1
                node_y.neighbors.append(node_x)
                h[(node_x.idx, node_y.idx)] = new_h[0]
                stack.append((node_y, next_slots))
                all_nodes.append(node_y)

        if backtrack:  #Backtrack, use if instead of else
            if len(stack) == 1:
                break  #At root, terminate

            node_fa, _ = stack[-2]
            cur_h_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors if node_y.idx != node_fa.idx]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
            else:
                cur_h_nei = zero_pad

            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            h[(node_x.idx, node_fa.idx)] = new_h[0]
            node_fa.neighbors.append(node_x)
            stack.pop()

    return root, all_nodes
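# A minimal sketch of the node type the decoders in this section construct
# (MolTreeNode / MolJuncTreeNode / TreeNode in the different variants); only
# the fields actually used above are shown, so this is an assumption about
# shape, not the full class.
class MolTreeNode(object):
    def __init__(self, smiles):
        self.smiles = smiles
        self.neighbors = []   # adjacent tree nodes
        self.wid = None       # cluster-vocabulary index, set by the caller
        self.idx = None       # unique node index within the tree, set by the caller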
def forward(self, mol_batch, mol_vec):
    super_root = MolTreeNode('')
    super_root.idx = -1

    # initialization
    pred_hiddens, pred_mol_vecs, pred_targets = [], [], []
    stop_hiddens, stop_targets = [], []

    traces = []
    for mol_tree in mol_batch:
        s = []
        dfs(s, mol_tree.nodes[0], super_root)
        traces.append(s)
        for node in mol_tree.nodes:
            node.neighbors = []

    pred_hiddens.append(create_var(torch.zeros(len(mol_batch), self.hidden_size)))
    pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch])
    pred_mol_vecs.append(mol_vec)

    max_iter = max([len(tr) for tr in traces])
    padding = create_var(torch.zeros(self.hidden_size), False)
    h = {}

    for t in range(max_iter):
        prop_list = []
        batch_list = []
        for i, plist in enumerate(traces):
            if len(plist) > t:
                prop_list.append(plist[t])
                batch_list.append(i)

        cur_x = []
        cur_h_nei, cur_o_nei = [], []

        for node_x, real_y, _ in prop_list:
            # messages from all neighbors except the current target
            cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors if node_y.idx != real_y.idx]
            pad_len = MAX_NB - len(cur_nei)
            cur_h_nei.extend(cur_nei)
            cur_h_nei.extend([padding] * pad_len)

            # messages from all neighbors, for stop prediction
            cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
            pad_len = MAX_NB - len(cur_nei)
            cur_o_nei.extend(cur_nei)
            cur_o_nei.extend([padding] * pad_len)

            cur_x.append(node_x.wid)

        cur_x = create_var(torch.LongTensor(cur_x))
        cur_x = self.embedding(cur_x)

        cur_h_nei = torch.stack(cur_h_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
        new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)

        cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
        cur_o = cur_o_nei.sum(dim=1)

        pred_target, pred_list = [], []
        stop_target = []
        for i, m in enumerate(prop_list):
            node_x, node_y, direction = m
            x, y = node_x.idx, node_y.idx
            h[(x, y)] = new_h[i]
            node_y.neighbors.append(node_x)
            if direction == 1:
                pred_target.append(node_y.wid)
                pred_list.append(i)
            stop_target.append(direction)

        cur_batch = create_var(torch.LongTensor(batch_list))
        cur_mol_vec = mol_vec.index_select(0, cur_batch)
        stop_hidden = torch.cat([cur_x, cur_o, cur_mol_vec], dim=1)
        stop_hiddens.append(stop_hidden)
        stop_targets.extend(stop_target)

        if len(pred_list) > 0:
            batch_list = [batch_list[i] for i in pred_list]
            cur_batch = create_var(torch.LongTensor(batch_list))
            pred_mol_vecs.append(mol_vec.index_select(0, cur_batch))
            cur_pred = create_var(torch.LongTensor(pred_list))
            pred_hiddens.append(new_h.index_select(0, cur_pred))
            pred_targets.extend(pred_target)

    cur_x, cur_o_nei = [], []
    for mol_tree in mol_batch:
        node_x = mol_tree.nodes[0]
        cur_x.append(node_x.wid)
        cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        pad_len = MAX_NB - len(cur_nei)
        cur_o_nei.extend(cur_nei)
        cur_o_nei.extend([padding] * pad_len)

    cur_x = create_var(torch.LongTensor(cur_x))
    cur_x = self.embedding(cur_x)
    cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
    cur_o = cur_o_nei.sum(dim=1)

    stop_hidden = torch.cat([cur_x, cur_o, mol_vec], dim=1)
    stop_hiddens.append(stop_hidden)
    stop_targets.extend([0] * len(mol_batch))

    pred_hiddens = torch.cat(pred_hiddens, dim=0)
    pred_mol_vecs = torch.cat(pred_mol_vecs, dim=0)
    pred_vecs = torch.cat([pred_hiddens, pred_mol_vecs], dim=1)
    pred_vecs = nn.ReLU()(self.W(pred_vecs))
    pred_scores = self.W_o(pred_vecs)
    pred_targets = create_var(torch.LongTensor(pred_targets))
    pred_loss = self.pred_loss(pred_scores, pred_targets) / len(mol_batch)
    _, preds = torch.max(pred_scores, dim=1)
    pred_acc = torch.eq(preds, pred_targets).float()
    pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

    stop_hiddens = torch.cat(stop_hiddens, dim=0)
    stop_vecs = nn.ReLU()(self.U(stop_hiddens))
    stop_scores = self.U_s(stop_vecs).squeeze()
    stop_targets = create_var(torch.Tensor(stop_targets))
    stop_loss = self.stop_loss(stop_scores, stop_targets) / len(mol_batch)
    stops = torch.ge(stop_scores, 0).float()
    stop_acc = torch.eq(stops, stop_targets).float()
    stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

    return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()
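# A sketch of the loss modules these training forwards assume, as they would
# be created in the module's __init__. Both sum over targets (hence the
# explicit division by len(mol_batch) above), and stop_loss takes raw logits,
# matching the torch.ge(stop_scores, 0) decision rule. The helper name is
# hypothetical; treat this as an assumption, not the repository's exact code.
def _init_losses(self):
    self.pred_loss = nn.CrossEntropyLoss(reduction='sum')   # next-clique label
    self.stop_loss = nn.BCEWithLogitsLoss(reduction='sum')  # expand vs. backtrack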
def decode(self, x_tree_vecs, prob_decode):
    assert x_tree_vecs.size(0) == 1

    stack = []
    init_hiddens = create_var(torch.zeros(1, self.hidden_size))
    zero_pad = create_var(torch.zeros(1, 1, self.hidden_size))
    contexts = create_var(torch.LongTensor(1).zero_())

    #Root Prediction
    root_feature = self.aggregate(init_hiddens, contexts, x_tree_vecs, 'word')
    # _, root_wid = torch.max(root_score, dim=1)
    # root_wid = root_wid.item()

    root = TreeNode(root_feature)
    root.idx = 0
    root.graphid = 0
    stack.append((root, root_feature))

    all_nodes = [root]
    h = {}
    for step in xrange(MAX_DECODE_LEN):
        node_x, fa_slot = stack[-1]
        cur_h_nei = [h[(node_y.graphid, node_x.graphid)] for node_y in node_x.neighbors]
        if len(cur_h_nei) > 0:
            cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
        else:
            cur_h_nei = zero_pad

        # todo: node_x.feature is a tensor, not a vocabulary index;
        # this lookup still needs rework
        cur_x = create_var(torch.LongTensor([node_x.feature]))
        cur_x = self.embedding(cur_x)

        #Predict stop
        cur_h = cur_h_nei.sum(dim=1)
        stop_hiddens = torch.cat([cur_x, cur_h], dim=1)
        stop_hiddens = F.relu(self.U_i(stop_hiddens))
        stop_score = self.aggregate(stop_hiddens, contexts, x_tree_vecs, 'stop')

        if prob_decode:
            backtrack = (torch.bernoulli(torch.sigmoid(stop_score)).item() == 0)
        else:
            backtrack = (stop_score.item() < 0)

        if not backtrack:  #Forward: Predict next clique
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            pred_feature = self.aggregate(new_h, contexts, x_tree_vecs, 'word')

            node_y = TreeNode(pred_feature)
            node_y.idx = len(all_nodes)
            node_y.graphid = len(all_nodes)
            node_y.neighbors.append(node_x)
            h[(node_x.graphid, node_y.graphid)] = new_h[0]
            stack.append((node_y, pred_feature))
            all_nodes.append(node_y)

        if backtrack:  #Backtrack, use if instead of else
            if len(stack) == 1:
                break  #At root, terminate

            node_fa, _ = stack[-2]
            cur_h_nei = [h[(node_y.graphid, node_x.graphid)] for node_y in node_x.neighbors if node_y.idx != node_fa.idx]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
            else:
                cur_h_nei = zero_pad

            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            h[(node_x.graphid, node_fa.graphid)] = new_h[0]
            node_fa.neighbors.append(node_x)
            stack.pop()

    return root, all_nodes
def decode(self, tree_vecs):
    """
    Description: Given the tree vector, predict the corresponding junction-tree.

    Args:
        tree_vecs: torch.tensor (shape: 1 x hidden_size)

    Returns:
        root: MolJuncTreeNode
            The root node of the decoded junction-tree.

        all_nodes: List[MolJuncTreeNode]
            The list of all the nodes in the decoded junction-tree.
    """
    assert tree_vecs.size(0) == 1

    stack = []
    init_hiddens = create_var(torch.zeros(1, self.hidden_size))
    zero_pad = create_var(torch.zeros(1, 1, self.hidden_size))
    contexts = create_var(torch.LongTensor(1).zero_())

    # root prediction
    root_score = self.aggregate(init_hiddens, contexts, tree_vecs, 'word')
    _, root_wid = torch.max(root_score, dim=1)
    root_wid = root_wid.item()

    root = MolJuncTreeNode(self.vocab.get_smiles(root_wid))
    root.wid = root_wid
    root.idx = 0
    stack.append((root, self.vocab.get_slots(root.wid)))

    all_nodes = [root]
    h = {}
    for step in range(MAX_DECODE_LEN):
        node_x, parent_slot = stack[-1]
        cur_h_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        if len(cur_h_nei) > 0:
            cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
        else:
            cur_h_nei = zero_pad

        cur_x = create_var(torch.LongTensor([node_x.wid]))
        cur_x = self.vocab_embedding(cur_x)

        # predict stop
        cur_h = cur_h_nei.sum(dim=1)
        stop_hidden = torch.cat([cur_x, cur_h], dim=1)
        stop_hiddens = F.relu(self.U_i(stop_hidden))
        stop_score = self.aggregate(stop_hiddens, contexts, tree_vecs, 'stop')

        backtrack = (stop_score.item() < 0)

        # go down / forward: predict next cluster
        if not backtrack:
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            pred_score = self.aggregate(new_h, contexts, tree_vecs, 'word')

            _, sort_wid = torch.sort(pred_score, dim=1, descending=True)
            sort_wid = sort_wid.data.squeeze()

            next_wid = None
            for wid in sort_wid[:5]:
                slots = self.vocab.get_slots(wid)
                node_y = MolJuncTreeNode(self.vocab.get_smiles(wid))
                if self.have_slots(parent_slot, slots) and self.can_assemble(node_x, node_y):
                    next_wid = wid
                    next_slots = slots
                    break

            # no more children can be added
            if next_wid is None:
                backtrack = True
            else:
                node_y = MolJuncTreeNode(self.vocab.get_smiles(next_wid))
                node_y.wid = next_wid
                node_y.idx = len(all_nodes)
                node_y.neighbors.append(node_x)
                h[(node_x.idx, node_y.idx)] = new_h[0]
                stack.append((node_y, next_slots))
                all_nodes.append(node_y)

        # backtrack, use if instead of else
        if backtrack:
            if len(stack) == 1:
                # back to root, terminate
                break

            parent_node, _ = stack[-2]
            cur_h_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors if node_y.idx != parent_node.idx]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
            else:
                cur_h_nei = zero_pad

            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            h[(node_x.idx, parent_node.idx)] = new_h[0]
            parent_node.neighbors.append(node_x)
            stack.pop()

    return root, all_nodes
def forward(self, junc_tree_batch, tree_vecs):
    """
    Args:
        junc_tree_batch: List[MolJuncTree]
            The list of junction-trees for all the molecules, across the entire dataset.

        tree_vecs: torch.tensor (shape: batch_size x hidden_size)
            The vector representations of the junction-trees, for all the molecules,
            across the entire dataset.
    """
    # initialize
    pred_hiddens, pred_contexts, pred_targets = [], [], []
    stop_hiddens, stop_contexts, stop_targets = [], [], []

    # list to store dfs traversals for the molecular trees of all molecules
    traces = []
    for junc_tree in junc_tree_batch:
        stack = []
        # the root node has no parent node,
        # so we use a virtual node with idx = -1
        self.dfs(stack, junc_tree.nodes[0], -1)
        traces.append(stack)
        for node in junc_tree.nodes:
            node.neighbors = []

    # predict root
    batch_size = len(junc_tree_batch)
    pred_hiddens.append(create_var(torch.zeros(batch_size, self.hidden_size)))
    # indices of cluster-vocabulary items for the root node
    # of every junction-tree in the batch
    pred_targets.extend([junc_tree.nodes[0].wid for junc_tree in junc_tree_batch])
    pred_contexts.append(create_var(torch.LongTensor(range(batch_size))))

    # number of iterations needed to complete the dfs traversal
    # of the largest / deepest junction-tree
    max_iter = max([len(tr) for tr in traces])

    # padding vector used in place of messages from non-existent neighbors
    padding = create_var(torch.zeros(self.hidden_size), False)

    # dictionary to store hidden edge message vectors
    h = {}

    for iter in range(max_iter):
        # list of edge tuples considered in this iteration
        edge_tuple_list = []

        # batch ids of the junc_trees / tree_vecs whose edge tuples
        # are being considered in this timestep
        batch_list = []

        for idx, dfs_traversal in enumerate(traces):
            # keep consuming edge tuples from each traversal
            # until that traversal is exhausted
            if iter < len(dfs_traversal):
                edge_tuple_list.append(dfs_traversal[iter])
                batch_list.append(idx)

        cur_x = []
        cur_h_nei, cur_o_nei = [], []

        for node_x, real_y, _ in edge_tuple_list:
            # neighbors for message passing (target not included):
            # hidden edge message vectors from predecessor neighbor nodes
            cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors if node_y.idx != real_y.idx]

            # a node can have at most MAX_NUM_NEIGHBORS(=15) neighbors; if it
            # has fewer, append zero vectors as messages from non-existent neighbors
            pad_len = MAX_NUM_NEIGHBORS - len(cur_nei)
            cur_h_nei.extend(cur_nei)
            cur_h_nei.extend([padding] * pad_len)

            # neighbors for stop (topological) prediction (all neighbors):
            # hidden edge messages from all neighbor nodes
            cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
            pad_len = MAX_NUM_NEIGHBORS - len(cur_nei)
            cur_o_nei.extend(cur_nei)
            cur_o_nei.extend([padding] * pad_len)

            # current cluster embedding
            cur_x.append(node_x.wid)

        # cluster embedding
        cur_x = create_var(torch.LongTensor(cur_x))
        cur_x = self.vocab_embedding(cur_x)

        # implement message passing
        cur_h_nei = torch.stack(cur_h_nei, dim=0).view(-1, MAX_NUM_NEIGHBORS, self.hidden_size)
        new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)

        # node aggregate
        cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NUM_NEIGHBORS, self.hidden_size)
        cur_o = cur_o_nei.sum(dim=1)

        # gather targets
        pred_target, pred_list = [], []
        stop_target = []

        # teacher forcing
        for idx, edge_tuple in enumerate(edge_tuple_list):
            node_x, node_y, direction = edge_tuple
            x, y = node_x.idx, node_y.idx
            h[(x, y)] = new_h[idx]
            node_y.neighbors.append(node_x)
            if direction == 1:
                pred_target.append(node_y.wid)
                pred_list.append(idx)
            stop_target.append(direction)

        # hidden states for stop (topological) prediction
        cur_batch = create_var(torch.LongTensor(batch_list))
        stop_hidden = torch.cat([cur_x, cur_o], dim=1)
        stop_hiddens.append(stop_hidden)
        stop_contexts.append(cur_batch)
        stop_targets.extend(stop_target)

        # hidden states for cluster prediction
        if len(pred_list) > 0:
            batch_list = [batch_list[idx] for idx in pred_list]
            cur_batch = create_var(torch.LongTensor(batch_list))
            pred_contexts.append(cur_batch)
            cur_pred = create_var(torch.LongTensor(pred_list))
            pred_hiddens.append(new_h.index_select(0, cur_pred))
            pred_targets.extend(pred_target)

    # last stop at root
    cur_x, cur_o_nei = [], []
    for junc_tree in junc_tree_batch:
        node_x = junc_tree.nodes[0]
        cur_x.append(node_x.wid)
        cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        pad_len = MAX_NUM_NEIGHBORS - len(cur_nei)
        cur_o_nei.extend(cur_nei)
        cur_o_nei.extend([padding] * pad_len)

    cur_x = create_var(torch.LongTensor(cur_x))
    cur_x = self.vocab_embedding(cur_x)
    cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NUM_NEIGHBORS, self.hidden_size)
    cur_o = cur_o_nei.sum(dim=1)

    stop_hidden = torch.cat([cur_x, cur_o], dim=1)
    stop_hiddens.append(stop_hidden)
    stop_contexts.append(create_var(torch.LongTensor(range(batch_size))))
    stop_targets.extend([0] * len(junc_tree_batch))

    # predict next cluster
    pred_contexts = torch.cat(pred_contexts, dim=0)
    pred_hiddens = torch.cat(pred_hiddens, dim=0)
    pred_scores = self.aggregate(pred_hiddens, pred_contexts, tree_vecs, 'word')
    pred_targets = create_var(torch.LongTensor(pred_targets))
    pred_loss = self.pred_loss(pred_scores, pred_targets) / len(junc_tree_batch)
    _, preds = torch.max(pred_scores, dim=1)
    pred_acc = torch.eq(preds, pred_targets).float()
    pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

    # predict stop
    stop_contexts = torch.cat(stop_contexts, dim=0)
    stop_hiddens = torch.cat(stop_hiddens, dim=0)
    stop_hiddens = F.relu(self.U_i(stop_hiddens))
    stop_scores = self.aggregate(stop_hiddens, stop_contexts, tree_vecs, 'stop')
    stop_scores = stop_scores.squeeze(-1)
    stop_targets = create_var(torch.Tensor(stop_targets))
    stop_loss = self.stop_loss(stop_scores, stop_targets) / len(junc_tree_batch)
    stops = torch.ge(stop_scores, 0).float()
    stop_acc = torch.eq(stops, stop_targets).float()
    stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

    return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()
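# Constants assumed throughout this section. MAX_NUM_NEIGHBORS (=15, per the
# comment in the documented variant above) and MAX_NB name the same per-node
# neighbor cap in different variants; MAX_DECODE_LEN bounds the decoding loops.
# The values below are illustrative assumptions except where the comments
# above state them.
MAX_NUM_NEIGHBORS = 15
MAX_NB = 15
MAX_DECODE_LEN = 100
MAX_SOFT_DECODE_LEN = 50   # used by soft_decode below; assumed value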
def soft_decode(self, x_tree_vecs, x_mol_vecs, gumbel, slope, temp):
    assert x_tree_vecs.size(0) == 1
    soft_embedding = lambda x: x.matmul(self.embedding.weight)
    if gumbel:
        sample_softmax = lambda x: F.gumbel_softmax(x, tau=temp)
    else:
        sample_softmax = lambda x: F.softmax(x / temp, dim=1)

    stack = []
    init_hiddens = create_var(torch.zeros(1, self.hidden_size))
    zero_pad = create_var(torch.zeros(1, 1, self.hidden_size))
    contexts = create_var(torch.LongTensor(1).zero_())

    #Root Prediction
    root_score = self.attention(init_hiddens, contexts, x_tree_vecs, x_mol_vecs, 'word')
    root_prob = sample_softmax(root_score)

    root = MolTreeNode("")
    root.embedding = soft_embedding(root_prob)
    root.prob = root_prob
    root.idx = 0
    stack.append(root)

    all_nodes = [root]
    all_hiddens = []
    h = {}
    for step in xrange(MAX_SOFT_DECODE_LEN):
        node_x = stack[-1]
        cur_h_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        if len(cur_h_nei) > 0:
            cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
        else:
            cur_h_nei = zero_pad

        #Predict stop
        cur_x = node_x.embedding
        cur_h = cur_h_nei.sum(dim=1)
        stop_hiddens = torch.cat([cur_x, cur_h], dim=1)
        stop_hiddens = F.relu(self.U_i(stop_hiddens))
        stop_score = self.attention(stop_hiddens, contexts, x_tree_vecs, x_mol_vecs, 'stop')
        all_hiddens.append(stop_hiddens)

        forward = 0 if stop_score.item() < 0 else 1
        stop_prob = F.hardtanh(slope * stop_score + 0.5, min_val=0, max_val=1).unsqueeze(1)
        stop_val_ste = forward + stop_prob - stop_prob.detach()  #straight-through estimator

        if forward == 1:  #Forward: Predict next clique
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            pred_score = self.attention(new_h, contexts, x_tree_vecs, x_mol_vecs, 'word')
            pred_prob = sample_softmax(pred_score)

            node_y = MolTreeNode("")
            node_y.embedding = soft_embedding(pred_prob)
            node_y.prob = pred_prob
            node_y.idx = len(all_nodes)

            node_y.neighbors.append(node_x)
            h[(node_x.idx, node_y.idx)] = new_h[0] * stop_val_ste
            stack.append(node_y)
            all_nodes.append(node_y)
        else:
            if len(stack) == 1:  #At root, terminate
                return torch.cat([cur_x, cur_h], dim=1), all_nodes

            node_fa = stack[-2]
            cur_h_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors if node_y.idx != node_fa.idx]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
            else:
                cur_h_nei = zero_pad

            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            h[(node_x.idx, node_fa.idx)] = new_h[0] * (1.0 - stop_val_ste)
            node_fa.neighbors.append(node_x)
            stack.pop()

    #Failure mode: decoding unfinished
    cur_h_nei = [h[(node_y.idx, root.idx)] for node_y in root.neighbors]
    if len(cur_h_nei) > 0:
        cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
    else:
        cur_h_nei = zero_pad
    cur_h = cur_h_nei.sum(dim=1)

    stop_hiddens = torch.cat([root.embedding, cur_h], dim=1)
    stop_hiddens = F.relu(self.U_i(stop_hiddens))
    all_hiddens.append(stop_hiddens)

    return torch.cat([root.embedding, cur_h], dim=1), all_nodes
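# The soft decoder above gates messages with a straight-through estimator:
# stop_val_ste = forward + stop_prob - stop_prob.detach() evaluates to the
# hard 0/1 decision `forward` in the forward pass, while its gradient is that
# of the smooth hardtanh surrogate stop_prob. A minimal standalone
# illustration (values here are made up for the example):
score = torch.tensor([0.3], requires_grad=True)
hard = 1.0 if score.item() >= 0 else 0.0              # non-differentiable decision
soft = F.hardtanh(score + 0.5, min_val=0, max_val=1)  # smooth surrogate
ste = hard + soft - soft.detach()                     # value: hard; gradient: d(soft)
ste.backward()
print(score.grad)                                     # tensor([1.]), from the surrogate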
def forward(self, mol_batch, x_tree_vecs, x_mol_vecs, origin_word):
    """
    mol_batch: Y; where (X, Y), X is input, Y is target.
    """
    pred_hiddens, pred_contexts, pred_targets = [], [], []
    stop_hiddens, stop_contexts, stop_targets = [], [], []

    traces = []
    for mol_tree in mol_batch:
        s = []
        dfs(s, mol_tree.nodes[0], -1)
        traces.append(s)
        for node in mol_tree.nodes:
            node.neighbors = []

    #Predict Root
    batch_size = len(mol_batch)
    pred_hiddens.append(create_var(torch.zeros(len(mol_batch), self.hidden_size)))
    pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch])
    pred_contexts.append(create_var(torch.LongTensor(range(batch_size))))

    max_iter = max([len(tr) for tr in traces])
    padding = create_var(torch.zeros(self.hidden_size), False)
    h = {}

    for t in xrange(max_iter):
        prop_list = []
        batch_list = []
        for i, plist in enumerate(traces):
            if t < len(plist):
                prop_list.append(plist[t])
                batch_list.append(i)

        cur_x = []
        cur_h_nei, cur_o_nei = [], []

        for node_x, real_y, _ in prop_list:
            #Neighbors for message passing (target not included)
            cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors if node_y.idx != real_y.idx]
            pad_len = MAX_NB - len(cur_nei)
            cur_h_nei.extend(cur_nei)
            cur_h_nei.extend([padding] * pad_len)

            #Neighbors for stop prediction (all neighbors)
            cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
            pad_len = MAX_NB - len(cur_nei)
            cur_o_nei.extend(cur_nei)
            cur_o_nei.extend([padding] * pad_len)

            #Current clique embedding
            cur_x.append(node_x.wid)

        #Clique embedding
        cur_x = create_var(torch.LongTensor(cur_x))
        cur_x = self.embedding(cur_x)

        #Message passing
        cur_h_nei = torch.stack(cur_h_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
        new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)

        #Node Aggregate
        cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
        cur_o = cur_o_nei.sum(dim=1)

        #Gather targets
        pred_target, pred_list = [], []
        stop_target = []
        for i, m in enumerate(prop_list):
            node_x, node_y, direction = m
            x, y = node_x.idx, node_y.idx
            h[(x, y)] = new_h[i]
            node_y.neighbors.append(node_x)
            if direction == 1:
                pred_target.append(node_y.wid)
                pred_list.append(i)
            stop_target.append(direction)

        #Hidden states for stop prediction
        cur_batch = create_var(torch.LongTensor(batch_list))
        stop_hidden = torch.cat([cur_x, cur_o], dim=1)
        stop_hiddens.append(stop_hidden)
        stop_contexts.append(cur_batch)
        stop_targets.extend(stop_target)

        #Hidden states for clique prediction
        if len(pred_list) > 0:
            batch_list = [batch_list[i] for i in pred_list]
            cur_batch = create_var(torch.LongTensor(batch_list))
            pred_contexts.append(cur_batch)
            cur_pred = create_var(torch.LongTensor(pred_list))
            pred_hiddens.append(new_h.index_select(0, cur_pred))
            pred_targets.extend(pred_target)

    #Last stop at root
    cur_x, cur_o_nei = [], []
    for mol_tree in mol_batch:
        node_x = mol_tree.nodes[0]
        cur_x.append(node_x.wid)
        cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        pad_len = MAX_NB - len(cur_nei)
        cur_o_nei.extend(cur_nei)
        cur_o_nei.extend([padding] * pad_len)

    cur_x = create_var(torch.LongTensor(cur_x))
    cur_x = self.embedding(cur_x)
    cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
    cur_o = cur_o_nei.sum(dim=1)

    stop_hidden = torch.cat([cur_x, cur_o], dim=1)
    stop_hiddens.append(stop_hidden)
    stop_contexts.append(create_var(torch.LongTensor(range(batch_size))))
    stop_targets.extend([0] * len(mol_batch))

    #Predict next clique
    pred_contexts = torch.cat(pred_contexts, dim=0)
    pred_hiddens = torch.cat(pred_hiddens, dim=0)
    #pred_scores = self.attention(pred_hiddens, pred_contexts, x_tree_vecs, x_mol_vecs, 'word')
    pred_scores = self.mixture_attention(pred_hiddens, pred_contexts, x_tree_vecs, x_mol_vecs, origin_word)
    if torch.isnan(pred_scores).any():
        print "forward nan"
    # Shapes, for one example batch (batch_size=32, hidden_size=300, 420 targets):
    #   pred_hiddens: (420, 300); pred_contexts: (420,) int
    #   x_tree_vecs: (32, 21, 300); x_mol_vecs: (32, 28, 300)
    #   pred_scores: (420, 780); pred_targets: list of vocab ids, len 420
    #   origin_word: the input words

    pred_targets = create_var(torch.LongTensor(pred_targets))
    #pred_loss = self.pred_loss(pred_scores, pred_targets) / len(mol_batch)
    # mixture_attention returns probabilities, so clamp before the log to keep the NLL finite
    pred_loss = self.nllloss(
        torch.log(torch.max(pred_scores, torch.ones_like(pred_scores) * minimum_threshold_before_log)),
        pred_targets) / len(mol_batch)
    _, preds = torch.max(pred_scores, dim=1)
    pred_acc = torch.eq(preds, pred_targets).float()
    pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

    #Predict stop
    stop_contexts = torch.cat(stop_contexts, dim=0)
    stop_hiddens = torch.cat(stop_hiddens, dim=0)
    stop_hiddens = F.relu(self.U_i(stop_hiddens))
    stop_scores = self.attention(stop_hiddens, stop_contexts, x_tree_vecs, x_mol_vecs, 'stop')
    stop_scores = stop_scores.squeeze(-1)
    stop_targets = create_var(torch.Tensor(stop_targets))
    stop_loss = self.stop_loss(stop_scores, stop_targets) / len(mol_batch)
    stops = torch.ge(stop_scores, 0).float()
    stop_acc = torch.eq(stops, stop_targets).float()
    stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

    return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()
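# The clamped-log NLL above assumes a probability floor and a summed NLL loss
# set up elsewhere. Names are taken from the call sites; the floor value is an
# illustrative assumption, not the repository's:
minimum_threshold_before_log = 1e-12              # floor applied before torch.log
# in the module's __init__:
#     self.nllloss = nn.NLLLoss(reduction='sum')  # expects log-probabilities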