def decode(self, x_tree_vecs, prob_decode):
    assert x_tree_vecs.size(0) == 1

    stack = []
    init_hiddens = create_var(torch.zeros(1, self.hidden_size))
    zero_pad = create_var(torch.zeros(1, 1, self.hidden_size))
    contexts = create_var(torch.LongTensor(1).zero_())

    # Root Prediction
    root_score = self.aggregate(init_hiddens, contexts, x_tree_vecs, 'word')
    _, root_wid = torch.max(root_score, dim=1)
    root_wid = root_wid.item()

    root = MolTreeNode(self.vocab.get_smiles(root_wid))
    root.wid = root_wid
    root.idx = 0
    stack.append((root, self.vocab.get_slots(root.wid)))

    all_nodes = [root]
    h = {}
    for step in range(MAX_DECODE_LEN):
        node_x, fa_slot = stack[-1]
        cur_h_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        if len(cur_h_nei) > 0:
            cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
        else:
            cur_h_nei = zero_pad

        cur_x = create_var(torch.LongTensor([node_x.wid]))
        cur_x = self.embedding(cur_x)

        # Predict stop
        cur_h = cur_h_nei.sum(dim=1)
        stop_hiddens = torch.cat([cur_x, cur_h], dim=1)
        stop_hiddens = F.relu(self.U_i(stop_hiddens))
        stop_score = self.aggregate(stop_hiddens, contexts, x_tree_vecs, 'stop')

        if prob_decode:
            backtrack = (torch.bernoulli(torch.sigmoid(stop_score)).item() == 0)
        else:
            backtrack = (stop_score.item() < 0)

        if not backtrack:  # Forward: predict next clique
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            pred_score = self.aggregate(new_h, contexts, x_tree_vecs, 'word')

            if prob_decode:
                sort_wid = torch.multinomial(F.softmax(pred_score, dim=1).squeeze(), 5)
            else:
                _, sort_wid = torch.sort(pred_score, dim=1, descending=True)
                sort_wid = sort_wid.data.squeeze()

            next_wid = None
            for wid in sort_wid[:5]:
                slots = self.vocab.get_slots(wid)
                node_y = MolTreeNode(self.vocab.get_smiles(wid))
                if have_slots(fa_slot, slots) and can_assemble(node_x, node_y):
                    next_wid = wid
                    next_slots = slots
                    break

            if next_wid is None:
                backtrack = True  # No more children can be added
            else:
                node_y = MolTreeNode(self.vocab.get_smiles(next_wid))
                node_y.wid = next_wid
                node_y.idx = len(all_nodes)
                node_y.neighbors.append(node_x)
                h[(node_x.idx, node_y.idx)] = new_h[0]
                stack.append((node_y, next_slots))
                all_nodes.append(node_y)

        if backtrack:  # Backtrack: `if` rather than `else`, so a failed forward step falls through here
            if len(stack) == 1:
                break  # At root, terminate

            node_fa, _ = stack[-2]
            cur_h_nei = [
                h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                if node_y.idx != node_fa.idx
            ]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
            else:
                cur_h_nei = zero_pad

            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            h[(node_x.idx, node_fa.idx)] = new_h[0]
            node_fa.neighbors.append(node_x)
            stack.pop()

    return root, all_nodes
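# Note: the GRU(...) used above is the repo's tree-GRU message update, not
# torch.nn.GRU. Below is a minimal sketch consistent with how it is called
# here (x: [B, H] node embedding; h_nei: [B, max_nei, H] incoming messages;
# W_z/W_r/U_r/W_h: nn.Linear layers); the exact implementation lives in the
# project's nn utilities, so treat this as illustrative, not authoritative.
import torch

def tree_gru(x, h_nei, W_z, W_r, U_r, W_h):
    hidden_size = x.size(-1)
    sum_h = h_nei.sum(dim=1)                                  # aggregate incoming messages
    z = torch.sigmoid(W_z(torch.cat([x, sum_h], dim=1)))      # update gate
    r_1 = W_r(x).view(-1, 1, hidden_size)
    r_2 = U_r(h_nei)
    r = torch.sigmoid(r_1 + r_2)                              # one reset gate per neighbor
    sum_gated_h = (r * h_nei).sum(dim=1)
    pre_h = torch.tanh(W_h(torch.cat([x, sum_gated_h], dim=1)))
    return (1.0 - z) * sum_h + z * pre_h                      # new outgoing message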
def get_subtree(tree, edge, x_node_vecs, x_mess_dict):
    subtree_list = {}    # subtree root -> Subtree
    node_tree_idx = {}   # original node -> root of the subtree it belongs to
    node_list = {}       # clique wid -> list of (subtree root, node) pairs

    # ========================= Get Subtree List ===============================
    tree.nodes[0].keep_neighbors = []
    for node in tree.nodes[1:]:
        fa_node = node.fa_node
        node_idx = node.idx
        idx = x_mess_dict[node.fa_node.idx, node.idx]
        if not edge[idx]:  # edge mask is 0: group node with its parent in the same subtree
            if fa_node in node_tree_idx:  # parent already has a subtree: append a copy of node
                new_node = MolTreeNode(node.smiles)
                new_node.wid = node.wid
                new_node.neighbors = [node.fa_node.cnode]
                new_node.idx = node_idx
                node.cnode = new_node
                node.fa_node.cnode.neighbors.append(new_node)
                node_tree_idx[node] = node_tree_idx[node.fa_node]
                tree_node = node_tree_idx[node.fa_node]
                subtree_list[tree_node].add_node(new_node)
            else:  # parent has no subtree yet: start a new one rooted at a copy of the parent
                new_fa_node = MolTreeNode(node.fa_node.smiles)
                new_fa_node.wid = fa_node.wid
                new_fa_node.idx = fa_node.idx
                new_node = MolTreeNode(node.smiles)
                new_node.wid = node.wid
                new_node.idx = node_idx
                new_fa_node.neighbors = [new_node]
                new_node.neighbors = [new_fa_node]
                node.cnode = new_node
                node.fa_node.cnode = new_fa_node
                subtree_list[new_fa_node] = Subtree(new_fa_node)
                subtree_list[new_fa_node].add_node(new_node)
                node_tree_idx[fa_node] = new_fa_node
                node_tree_idx[node] = new_fa_node
                if node.fa_node.wid in node_list:
                    node_list[node.fa_node.wid].append((new_fa_node, new_fa_node))
                else:
                    node_list[node.fa_node.wid] = [(new_fa_node, new_fa_node)]
            fa_node = node_tree_idx[node]
            if node.wid in node_list:
                node_list[node.wid].append((fa_node, node))
            else:
                node_list[node.wid] = [(fa_node, node)]

    # ========================= Subtree Embedding ==============================
    # Pick the largest subtree (assumes at least one subtree was created above).
    max_idx, max_num = 0, 0
    if len(subtree_list) > 1:
        for idx in subtree_list:
            if len(subtree_list[idx].nodes) > max_num:
                max_num = len(subtree_list[idx].nodes)
                max_idx = idx
        max_subtree = subtree_list[max_idx]
    else:
        max_subtree = subtree_list[list(subtree_list.keys())[0]]

    for i, node in enumerate(max_subtree.nodes):
        node.idx = i
        node.nid = i

    return subtree_list, max_subtree, node_tree_idx, node_list
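# get_subtree assumes a small Subtree container (a root plus its member
# nodes). That class is not shown in this section; a minimal sketch
# consistent with the calls above (Subtree(root), .add_node(node), .nodes)
# might look like this -- a hypothetical stand-in, not the repo's definition.
class Subtree(object):
    def __init__(self, root):
        self.root = root
        self.nodes = [root]   # the root counts as a member of its own subtree

    def add_node(self, node):
        self.nodes.append(node)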
def forward(self, mol_batch, mol_vec):
    super_root = MolTreeNode("")
    super_root.idx = -1

    # Initialize
    pred_hiddens, pred_mol_vecs, pred_targets = [], [], []
    stop_hiddens, stop_targets = [], []
    traces = []
    for mol_tree in mol_batch:
        s = []
        dfs(s, mol_tree.nodes[0], super_root)
        traces.append(s)
        for node in mol_tree.nodes:
            node.neighbors = []

    # Predict Root
    pred_hiddens.append(create_var(torch.zeros(len(mol_batch), self.hidden_size)))
    pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch])
    pred_mol_vecs.append(mol_vec)

    max_iter = max([len(tr) for tr in traces])
    padding = create_var(torch.zeros(self.hidden_size), False)
    h = {}

    for t in range(max_iter):
        prop_list = []
        batch_list = []
        for i, plist in enumerate(traces):
            if t < len(plist):
                prop_list.append(plist[t])
                batch_list.append(i)

        cur_x = []
        cur_h_nei, cur_o_nei = [], []

        for node_x, real_y, _ in prop_list:
            # Neighbors for message passing (target not included)
            cur_nei = [
                h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                if node_y.idx != real_y.idx
            ]
            pad_len = MAX_NB - len(cur_nei)
            cur_h_nei.extend(cur_nei)
            cur_h_nei.extend([padding] * pad_len)

            # Neighbors for stop prediction (all neighbors)
            cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
            pad_len = MAX_NB - len(cur_nei)
            cur_o_nei.extend(cur_nei)
            cur_o_nei.extend([padding] * pad_len)

            # Current clique
            cur_x.append(node_x.wid)

        # Clique embedding
        cur_x = create_var(torch.LongTensor(cur_x))
        cur_x = self.embedding(cur_x)

        # Message passing
        cur_h_nei = torch.stack(cur_h_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
        new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)

        # Node aggregation
        cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
        cur_o = cur_o_nei.sum(dim=1)

        # Gather targets
        pred_target, pred_list = [], []
        stop_target = []
        for i, m in enumerate(prop_list):
            node_x, node_y, direction = m
            x, y = node_x.idx, node_y.idx
            h[(x, y)] = new_h[i]
            node_y.neighbors.append(node_x)
            if direction == 1:
                pred_target.append(node_y.wid)
                pred_list.append(i)
            stop_target.append(direction)

        # Hidden states for stop prediction
        cur_batch = create_var(torch.LongTensor(batch_list))
        cur_mol_vec = mol_vec.index_select(0, cur_batch)
        stop_hidden = torch.cat([cur_x, cur_o, cur_mol_vec], dim=1)
        stop_hiddens.append(stop_hidden)
        stop_targets.extend(stop_target)

        # Hidden states for clique prediction
        if len(pred_list) > 0:
            batch_list = [batch_list[i] for i in pred_list]
            cur_batch = create_var(torch.LongTensor(batch_list))
            pred_mol_vecs.append(mol_vec.index_select(0, cur_batch))
            cur_pred = create_var(torch.LongTensor(pred_list))
            pred_hiddens.append(new_h.index_select(0, cur_pred))
            pred_targets.extend(pred_target)

    # Last stop at root
    cur_x, cur_o_nei = [], []
    for mol_tree in mol_batch:
        node_x = mol_tree.nodes[0]
        cur_x.append(node_x.wid)
        cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        pad_len = MAX_NB - len(cur_nei)
        cur_o_nei.extend(cur_nei)
        cur_o_nei.extend([padding] * pad_len)

    cur_x = create_var(torch.LongTensor(cur_x))
    cur_x = self.embedding(cur_x)
    cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
    cur_o = cur_o_nei.sum(dim=1)

    stop_hidden = torch.cat([cur_x, cur_o, mol_vec], dim=1)
    stop_hiddens.append(stop_hidden)
    stop_targets.extend([0] * len(mol_batch))

    # Predict next clique
    pred_hiddens = torch.cat(pred_hiddens, dim=0)
    pred_mol_vecs = torch.cat(pred_mol_vecs, dim=0)
    pred_vecs = torch.cat([pred_hiddens, pred_mol_vecs], dim=1)
    pred_vecs = nn.ReLU()(self.W(pred_vecs))
    pred_scores = self.W_o(pred_vecs)
    pred_targets = create_var(torch.LongTensor(pred_targets))

    pred_loss = self.pred_loss(pred_scores, pred_targets) / len(mol_batch)
    _, preds = torch.max(pred_scores, dim=1)
    pred_acc = torch.eq(preds, pred_targets).float()
    pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

    # Predict stop
    stop_hiddens = torch.cat(stop_hiddens, dim=0)
    stop_vecs = nn.ReLU()(self.U(stop_hiddens))
    stop_scores = self.U_s(stop_vecs).squeeze()
    stop_targets = create_var(torch.Tensor(stop_targets))

    stop_loss = self.stop_loss(stop_scores, stop_targets) / len(mol_batch)
    stops = torch.ge(stop_scores, 0).float()
    stop_acc = torch.eq(stops, stop_targets).float()
    stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

    return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()
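# A minimal sketch of how forward's outputs might be used in a training step.
# The names `decoder`, `optimizer`, `mol_batch`, `mol_vec` are placeholders,
# not part of this section: forward returns two losses (clique-label and
# stop-signal) plus two accuracies, and the losses are simply summed before
# backpropagation.
def train_step(decoder, optimizer, mol_batch, mol_vec):
    optimizer.zero_grad()
    pred_loss, stop_loss, pred_acc, stop_acc = decoder(mol_batch, mol_vec)
    loss = pred_loss + stop_loss   # clique prediction loss + stop prediction loss
    loss.backward()
    optimizer.step()
    return loss.item(), pred_acc, stop_acc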
def get_trace(self, node):
    super_root = MolTreeNode("")
    super_root.idx = -1
    trace = []
    dfs(trace, node, super_root)
    return [(x.smiles, y.smiles, z) for x, y, z in trace]
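# Both forward and get_trace rely on a dfs helper that records each tree edge
# twice: once on the way down (direction 1, "expand to a child") and once on
# the way back up (direction 0, "stop/backtrack"). A minimal sketch consistent
# with that usage; the actual helper is defined elsewhere in the repo.
def dfs(stack, x, fa):
    for y in x.neighbors:
        if y.idx == fa.idx:
            continue
        stack.append((x, y, 1))   # forward edge: a child is generated
        dfs(stack, y, x)
        stack.append((y, x, 0))   # backward edge: backtrack to the parent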
def decode(self, mol_vec, prob_decode):
    stack, trace = [], []
    init_hidden = create_var(torch.zeros(1, self.hidden_size))
    zero_pad = create_var(torch.zeros(1, 1, self.hidden_size))

    # Root Prediction
    root_hidden = torch.cat([init_hidden, mol_vec], dim=1)
    root_hidden = nn.ReLU()(self.W(root_hidden))
    root_score = self.W_o(root_hidden)
    _, root_wid = torch.max(root_score, dim=1)
    root_wid = root_wid.item()

    root = MolTreeNode(self.vocab.get_smiles(root_wid))
    root.wid = root_wid
    root.idx = 0
    stack.append((root, self.vocab.get_slots(root.wid)))

    all_nodes = [root]
    h = {}
    for step in range(MAX_DECODE_LEN):
        node_x, fa_slot = stack[-1]
        cur_h_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        if len(cur_h_nei) > 0:
            cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
        else:
            cur_h_nei = zero_pad

        cur_x = create_var(torch.LongTensor([node_x.wid]))
        cur_x = self.embedding(cur_x)

        # Predict stop
        cur_h = cur_h_nei.sum(dim=1)
        stop_hidden = torch.cat([cur_x, cur_h, mol_vec], dim=1)
        stop_hidden = nn.ReLU()(self.U(stop_hidden))
        stop_score = nn.Sigmoid()(self.U_s(stop_hidden) * 20).squeeze()

        if prob_decode:
            backtrack = (torch.bernoulli(1.0 - stop_score.data).item() == 1)
        else:
            backtrack = (stop_score.item() < 0.5)

        if not backtrack:  # Forward: predict next clique
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            pred_hidden = torch.cat([new_h, mol_vec], dim=1)
            pred_hidden = nn.ReLU()(self.W(pred_hidden))
            pred_score = nn.Softmax(dim=1)(self.W_o(pred_hidden) * 20)
            if prob_decode:
                sort_wid = torch.multinomial(pred_score.data.squeeze(), 5)
            else:
                _, sort_wid = torch.sort(pred_score, dim=1, descending=True)
                sort_wid = sort_wid.data.squeeze()

            next_wid = None
            for wid in sort_wid[:5]:
                slots = self.vocab.get_slots(wid)
                node_y = MolTreeNode(self.vocab.get_smiles(wid))
                if have_slots(fa_slot, slots) and can_assemble(node_x, node_y):
                    next_wid = wid
                    next_slots = slots
                    break

            if next_wid is None:
                backtrack = True  # No more children can be added
            else:
                node_y = MolTreeNode(self.vocab.get_smiles(next_wid))
                node_y.wid = next_wid
                node_y.idx = step + 1
                node_y.neighbors.append(node_x)
                h[(node_x.idx, node_y.idx)] = new_h[0]
                stack.append((node_y, next_slots))
                all_nodes.append(node_y)

        if backtrack:  # Backtrack: `if` rather than `else`, so a failed forward step falls through here
            if len(stack) == 1:
                break  # At root, terminate

            node_fa, _ = stack[-2]
            cur_h_nei = [
                h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                if node_y.idx != node_fa.idx
            ]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
            else:
                cur_h_nei = zero_pad

            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            h[(node_x.idx, node_fa.idx)] = new_h[0]
            node_fa.neighbors.append(node_x)
            stack.pop()

    return root, all_nodes
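# Hypothetical usage sketch (`decoder` and `mol_vec`, a [1, latent_size]
# molecule latent, are placeholders). With prob_decode=False every stop and
# clique decision is taken greedily; with prob_decode=True they are sampled.
# The `* 20` scaling above presumably sharpens the sigmoid/softmax so sampled
# decisions stay close to the greedy ones.
root, all_nodes = decoder.decode(mol_vec, prob_decode=False)   # deterministic decode
root, all_nodes = decoder.decode(mol_vec, prob_decode=True)    # stochastic decode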
def soft_decode(self, x_tree_vecs, x_mol_vecs, gumbel, slope, temp):
    assert x_tree_vecs.size(0) == 1
    soft_embedding = lambda x: x.matmul(self.embedding.weight)
    if gumbel:
        sample_softmax = lambda x: F.gumbel_softmax(x, tau=temp)
    else:
        sample_softmax = lambda x: F.softmax(x / temp, dim=1)

    stack = []
    init_hiddens = create_var(torch.zeros(1, self.hidden_size))
    zero_pad = create_var(torch.zeros(1, 1, self.hidden_size))
    contexts = create_var(torch.LongTensor(1).zero_())

    # Root Prediction
    root_score = self.attention(init_hiddens, contexts, x_tree_vecs, x_mol_vecs, 'word')
    root_prob = sample_softmax(root_score)

    root = MolTreeNode("")
    root.embedding = soft_embedding(root_prob)
    root.prob = root_prob
    root.idx = 0
    stack.append(root)

    all_nodes = [root]
    all_hiddens = []
    h = {}
    for step in range(MAX_SOFT_DECODE_LEN):
        node_x = stack[-1]
        cur_h_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        if len(cur_h_nei) > 0:
            cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
        else:
            cur_h_nei = zero_pad

        # Predict stop
        cur_x = node_x.embedding
        cur_h = cur_h_nei.sum(dim=1)
        stop_hiddens = torch.cat([cur_x, cur_h], dim=1)
        stop_hiddens = F.relu(self.U_i(stop_hiddens))
        stop_score = self.attention(stop_hiddens, contexts, x_tree_vecs, x_mol_vecs, 'stop')
        all_hiddens.append(stop_hiddens)

        # Straight-through estimator: hard 0/1 decision in the forward pass,
        # soft hardtanh surrogate for gradients in the backward pass
        forward = 0 if stop_score.item() < 0 else 1
        stop_prob = F.hardtanh(slope * stop_score + 0.5, min_val=0, max_val=1).unsqueeze(1)
        stop_val_ste = forward + stop_prob - stop_prob.detach()

        if forward == 1:  # Forward: predict next clique
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            pred_score = self.attention(new_h, contexts, x_tree_vecs, x_mol_vecs, 'word')
            pred_prob = sample_softmax(pred_score)

            node_y = MolTreeNode("")
            node_y.embedding = soft_embedding(pred_prob)
            node_y.prob = pred_prob
            node_y.idx = len(all_nodes)

            node_y.neighbors.append(node_x)
            h[(node_x.idx, node_y.idx)] = new_h[0] * stop_val_ste
            stack.append(node_y)
            all_nodes.append(node_y)
        else:  # Backtrack
            if len(stack) == 1:  # At root, terminate
                return torch.cat([cur_x, cur_h], dim=1), all_nodes

            node_fa = stack[-2]
            cur_h_nei = [
                h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                if node_y.idx != node_fa.idx
            ]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
            else:
                cur_h_nei = zero_pad

            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            h[(node_x.idx, node_fa.idx)] = new_h[0] * (1.0 - stop_val_ste)
            node_fa.neighbors.append(node_x)
            stack.pop()

    # Failure mode: decoding unfinished after MAX_SOFT_DECODE_LEN steps
    cur_h_nei = [h[(node_y.idx, root.idx)] for node_y in root.neighbors]
    if len(cur_h_nei) > 0:
        cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
    else:
        cur_h_nei = zero_pad

    cur_h = cur_h_nei.sum(dim=1)
    stop_hiddens = torch.cat([root.embedding, cur_h], dim=1)
    stop_hiddens = F.relu(self.U_i(stop_hiddens))
    all_hiddens.append(stop_hiddens)

    return torch.cat([root.embedding, cur_h], dim=1), all_nodes
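# Standalone sketch of the straight-through trick used in soft_decode: the
# forward value is the hard decision, while gradients flow through the soft
# hardtanh surrogate. Here `score` and the slope of 0.5 are made-up inputs.
import torch
import torch.nn.functional as F

score = torch.tensor([0.3], requires_grad=True)
hard = 0 if score.item() < 0 else 1                        # non-differentiable decision
soft = F.hardtanh(0.5 * score + 0.5, min_val=0, max_val=1) # differentiable surrogate
ste = hard + soft - soft.detach()                          # forward: `hard`; backward: d(soft)/d(score)
ste.backward()
print(ste.item(), score.grad.item())                       # 1.0, 0.5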