コード例 #1
0
ファイル: jtnn_dec.py プロジェクト: alexjli/addcov19
    def decode(self, x_tree_vecs, prob_decode):
        """Decode a junction tree from a single latent tree vector.

        The tree is grown depth-first with an explicit stack. The root clique
        is the arg-max of the 'word' score, even when ``prob_decode`` is true.
        At each step, the node on top of the stack does one of two things:

        * Forward: it gets a new child clique.
        * Backtrack: when the stop score says so, or no candidate fits, it
          sends a message back to its parent and is popped.

        Decoding ends when the root would be popped, or after
        ``MAX_DECODE_LEN`` steps.

        Args:
            x_tree_vecs: latent tree encoding whose first dimension must be 1.
                It is passed unchanged to ``self.aggregate`` (presumably shape
                (1, latent_size) -- TODO confirm).
            prob_decode: if true, sample the stop decision (Bernoulli) and
                five candidate child cliques (multinomial). Otherwise, expand
                while the stop score is >= 0 and try candidates in descending
                score order.

        Returns:
            ``(root, all_nodes)``: the root ``MolTreeNode`` and every created
            node in creation order. Each node's ``idx`` is its position in
            ``all_nodes``, and its ``wid`` is a plain Python int.

        Raises:
            AssertionError: if ``x_tree_vecs.size(0) != 1``.
        """
        assert x_tree_vecs.size(0) == 1

        stack = []
        init_hiddens = create_var(torch.zeros(1, self.hidden_size))
        zero_pad = create_var(torch.zeros(1, 1, self.hidden_size))
        contexts = create_var(torch.LongTensor(1).zero_())

        # Root prediction: always greedy, regardless of prob_decode.
        root_score = self.aggregate(init_hiddens, contexts, x_tree_vecs,
                                    'word')
        _, root_wid = torch.max(root_score, dim=1)
        root_wid = root_wid.item()

        root = MolTreeNode(self.vocab.get_smiles(root_wid))
        root.wid = root_wid
        root.idx = 0
        # Stack entries are (node, slots of that node's clique).
        stack.append((root, self.vocab.get_slots(root.wid)))

        all_nodes = [root]
        h = {}  # h[(a, b)]: hidden message from node idx a to node idx b
        # range rather than the Python-2-only xrange, so this runs on Python 3.
        for step in range(MAX_DECODE_LEN):
            node_x, fa_slot = stack[-1]
            cur_h_nei = [
                h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
            ]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei,
                                        dim=0).view(1, -1, self.hidden_size)
            else:
                cur_h_nei = zero_pad

            cur_x = create_var(torch.LongTensor([node_x.wid]))
            cur_x = self.embedding(cur_x)

            # Predict stop: expand when the score is >= 0, or when the
            # Bernoulli draw is 1.
            cur_h = cur_h_nei.sum(dim=1)
            stop_hiddens = torch.cat([cur_x, cur_h], dim=1)
            stop_hiddens = F.relu(self.U_i(stop_hiddens))
            stop_score = self.aggregate(stop_hiddens, contexts, x_tree_vecs,
                                        'stop')

            if prob_decode:
                backtrack = (torch.bernoulli(
                    torch.sigmoid(stop_score)).item() == 0)
            else:
                backtrack = (stop_score.item() < 0)

            if not backtrack:  # Forward: predict the next clique
                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                            self.W_h)
                pred_score = self.aggregate(new_h, contexts, x_tree_vecs,
                                            'word')

                if prob_decode:
                    sort_wid = torch.multinomial(
                        F.softmax(pred_score, dim=1).squeeze(), 5)
                else:
                    _, sort_wid = torch.sort(pred_score,
                                             dim=1,
                                             descending=True)
                    sort_wid = sort_wid.data.squeeze()

                # Accept the first of the top-5 candidates that passes
                # have_slots against node_x's slots and can_assemble onto
                # node_x. tolist() yields plain ints, so child wids have the
                # same type as root.wid and hash by value as dict keys.
                next_wid = None
                for wid in sort_wid[:5].tolist():
                    slots = self.vocab.get_slots(wid)
                    if not have_slots(fa_slot, slots):
                        continue  # no need to build a node that cannot fit
                    # Scratch node used only for the feasibility test; the
                    # node inserted into the tree is created separately below.
                    candidate = MolTreeNode(self.vocab.get_smiles(wid))
                    if can_assemble(node_x, candidate):
                        next_wid = wid
                        next_slots = slots
                        break

                if next_wid is None:
                    backtrack = True  # No more children can be added
                else:
                    node_y = MolTreeNode(self.vocab.get_smiles(next_wid))
                    node_y.wid = next_wid
                    node_y.idx = len(all_nodes)
                    # node_x gains node_y as a neighbour only when node_y is
                    # backtracked from (see the backtrack branch).
                    node_y.neighbors.append(node_x)
                    h[(node_x.idx, node_y.idx)] = new_h[0]
                    stack.append((node_y, next_slots))
                    all_nodes.append(node_y)

            # Deliberately `if`, not `else`: a failed forward step (no
            # candidate fits) also backtracks in this same iteration.
            if backtrack:
                if len(stack) == 1:
                    break  # At root, terminate

                node_fa, _ = stack[-2]
                # The message to the parent aggregates every neighbour's
                # message except the parent's own.
                cur_h_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                    if node_y.idx != node_fa.idx
                ]
                if len(cur_h_nei) > 0:
                    cur_h_nei = torch.stack(cur_h_nei, dim=0).view(
                        1, -1, self.hidden_size)
                else:
                    cur_h_nei = zero_pad

                # cur_x is still node_x's embedding from the top of this step.
                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                            self.W_h)
                h[(node_x.idx, node_fa.idx)] = new_h[0]
                node_fa.neighbors.append(node_x)
                stack.pop()

        return root, all_nodes
コード例 #2
0
def get_subtree(tree, edge, x_node_vecs, x_mess_dict):
    """Split ``tree`` into subtrees formed by the edges that ``edge`` leaves unset.

    Every non-root node of ``tree`` is visited in ``tree.nodes`` order. A node
    is processed only when the flag for the message from its parent is falsy
    in ``edge``; the flag's index comes from ``x_mess_dict``. In that case, a
    copy of the node is linked to a copy of its parent in one of two places:

    * the subtree that already contains the parent, or
    * a new ``Subtree`` rooted at a fresh copy of the parent.

    Each original node that gets copied has ``cnode`` set to its copy.
    Parents are assumed to precede their children in ``tree.nodes`` --
    TODO confirm.

    Args:
        tree: junction tree exposing ``nodes``. Every non-root node has
            ``fa_node``, ``idx``, ``wid`` and ``smiles``.
        edge: indexable of flags. A falsy entry keeps the parent-child edge.
        x_node_vecs: not used by this function.
        x_mess_dict: maps ``(parent.idx, child.idx)`` to an index into ``edge``.

    Returns:
        ``(subtree_list, max_subtree, node_tree_idx, node_list)``:

        * ``subtree_list``: maps a root copy to its ``Subtree``.
        * ``max_subtree``: the ``Subtree`` with the most nodes (the first one
          on ties). Its nodes' ``idx`` and ``nid`` are renumbered 0..k-1.
        * ``node_tree_idx``: maps an original node to the root copy of its
          subtree.
        * ``node_list``: maps a wid to a list of (root copy, node) pairs. A
          newly created root contributes (root copy, root copy); every kept
          child contributes (root copy, original child).

    Raises:
        IndexError: if no edge was kept, so there is no subtree to choose.
    """
    subtree_list = {}   # root copy -> Subtree
    node_tree_idx = {}  # original node -> root copy of its subtree
    node_list = {}      # wid -> [(root copy, node), ...]

    tree.nodes[0].keep_neighbors = []

    for child in tree.nodes[1:]:
        parent = child.fa_node
        child_idx = child.idx
        if edge[x_mess_dict[parent.idx, child.idx]]:
            continue  # edge flagged: child does not join its parent's subtree

        if parent not in node_tree_idx:
            # Parent is in no subtree yet: open one rooted at a copy of it.
            root_copy = MolTreeNode(parent.smiles)
            root_copy.wid = parent.wid
            root_copy.idx = parent.idx
            child_copy = MolTreeNode(child.smiles)
            child_copy.wid = child.wid
            child_copy.idx = child_idx
            root_copy.neighbors = [child_copy]
            child_copy.neighbors = [root_copy]

            child.cnode = child_copy
            parent.cnode = root_copy

            subtree = Subtree(root_copy)
            subtree_list[root_copy] = subtree
            subtree.add_node(child_copy)

            node_tree_idx[parent] = root_copy
            node_tree_idx[child] = root_copy

            node_list.setdefault(parent.wid, []).append((root_copy, root_copy))
        else:
            # Parent already has a copy: hang the child's copy off it.
            parent_copy = parent.cnode
            child_copy = MolTreeNode(child.smiles)
            child_copy.wid = child.wid
            child_copy.neighbors = [parent_copy]
            child_copy.idx = child_idx

            child.cnode = child_copy
            parent_copy.neighbors.append(child_copy)
            owner = node_tree_idx[parent]
            node_tree_idx[child] = owner
            subtree_list[owner].add_node(child_copy)

        node_list.setdefault(child.wid, []).append((node_tree_idx[child], child))

    # Pick the largest subtree; the strict '>' keeps the first one on ties.
    if len(subtree_list) > 1:
        best_key, best_size = 0, 0
        for key, subtree in subtree_list.items():
            size = len(subtree.nodes)
            if size > best_size:
                best_key, best_size = key, size
        max_subtree = subtree_list[best_key]
    else:
        max_subtree = list(subtree_list.values())[0]

    # Renumber the chosen subtree's nodes consecutively from 0.
    for position, member in enumerate(max_subtree.nodes):
        member.idx = member.nid = position

    return subtree_list, max_subtree, node_tree_idx, node_list
コード例 #3
0
    def decode(self, mol_vec, prob_decode):
        """Decode a junction tree conditioned on a molecule latent vector.

        The tree is built depth-first with an explicit stack. The root clique
        is the arg-max prediction, even when ``prob_decode`` is true. After
        that, each step either adds a child clique to the node on top of the
        stack, or sends a message back to the node's parent and pops it.
        Decoding stops when the root would be popped, or after
        ``MAX_DECODE_LEN`` steps.

        Args:
            mol_vec: latent vector concatenated along dim 1 with
                (1, hidden_size) tensors, so it must be 2-D with a single row.
            prob_decode: if true, sample the stop decision (Bernoulli) and
                five candidate cliques (multinomial). Otherwise, expand while
                the stop probability is >= 0.5 and try candidates in
                descending score order.

        Returns:
            ``(root, all_nodes)``: the root ``MolTreeNode`` and every created
            node in creation order. Each node's ``idx`` is its position in
            ``all_nodes``, and its ``wid`` is a plain Python int.
        """
        stack, trace = [], []
        init_hidden = create_var(torch.zeros(1, self.hidden_size))
        zero_pad = create_var(torch.zeros(1, 1, self.hidden_size))

        # Root prediction: always greedy, regardless of prob_decode.
        root_hidden = torch.cat([init_hidden, mol_vec], dim=1)
        root_hidden = nn.ReLU()(self.W(root_hidden))
        root_score = self.W_o(root_hidden)
        _, root_wid = torch.max(root_score, dim=1)
        root_wid = root_wid.item()

        root = MolTreeNode(self.vocab.get_smiles(root_wid))
        root.wid = root_wid
        root.idx = 0
        # Stack entries are (node, slots of that node's clique).
        stack.append((root, self.vocab.get_slots(root.wid)))

        all_nodes = [root]
        h = {}  # h[(a, b)]: hidden message from node idx a to node idx b
        # range rather than the Python-2-only xrange, so this runs on Python 3.
        for step in range(MAX_DECODE_LEN):
            node_x, fa_slot = stack[-1]
            cur_h_nei = [
                h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
            ]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei,
                                        dim=0).view(1, -1, self.hidden_size)
            else:
                cur_h_nei = zero_pad

            cur_x = create_var(torch.LongTensor([node_x.wid]))
            cur_x = self.embedding(cur_x)

            # Predict stop: stop_score is a 0-d probability of expanding
            # (squeeze() removes the (1, 1) shape).
            cur_h = cur_h_nei.sum(dim=1)
            stop_hidden = torch.cat([cur_x, cur_h, mol_vec], dim=1)
            stop_hidden = nn.ReLU()(self.U(stop_hidden))
            stop_score = nn.Sigmoid()(self.U_s(stop_hidden) * 20).squeeze()

            if prob_decode:
                # Read the 0-d Bernoulli draw with .item(). Indexing a 0-d
                # tensor with [0] raises IndexError on PyTorch >= 0.4.
                backtrack = (torch.bernoulli(1.0 - stop_score.data).item() == 1)
            else:
                backtrack = (stop_score.item() < 0.5)

            if not backtrack:  # Forward: predict the next clique
                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                            self.W_h)
                pred_hidden = torch.cat([new_h, mol_vec], dim=1)
                pred_hidden = nn.ReLU()(self.W(pred_hidden))
                pred_score = nn.Softmax(dim=1)(self.W_o(pred_hidden) * 20)
                if prob_decode:
                    sort_wid = torch.multinomial(pred_score.data.squeeze(), 5)
                else:
                    _, sort_wid = torch.sort(pred_score,
                                             dim=1,
                                             descending=True)
                    sort_wid = sort_wid.data.squeeze()

                # Accept the first of the top-5 candidates that passes
                # have_slots against node_x's slots and can_assemble onto
                # node_x. tolist() yields plain ints, so child wids have the
                # same type as root.wid and hash by value as dict keys.
                next_wid = None
                for wid in sort_wid[:5].tolist():
                    slots = self.vocab.get_slots(wid)
                    if not have_slots(fa_slot, slots):
                        continue  # no need to build a node that cannot fit
                    # Scratch node used only for the feasibility test; the
                    # node inserted into the tree is created separately below.
                    candidate = MolTreeNode(self.vocab.get_smiles(wid))
                    if can_assemble(node_x, candidate):
                        next_wid = wid
                        next_slots = slots
                        break

                if next_wid is None:
                    backtrack = True  # No more children can be added
                else:
                    node_y = MolTreeNode(self.vocab.get_smiles(next_wid))
                    node_y.wid = next_wid
                    # Position in all_nodes. The old `step + 1` skipped
                    # numbers whenever a step backtracked.
                    node_y.idx = len(all_nodes)
                    # node_x gains node_y as a neighbour only when node_y is
                    # backtracked from (see the backtrack branch).
                    node_y.neighbors.append(node_x)
                    h[(node_x.idx, node_y.idx)] = new_h[0]
                    stack.append((node_y, next_slots))
                    all_nodes.append(node_y)

            # Deliberately `if`, not `else`: a failed forward step (no
            # candidate fits) also backtracks in this same iteration.
            if backtrack:
                if len(stack) == 1:
                    break  # At root, terminate

                node_fa, _ = stack[-2]
                # The message to the parent aggregates every neighbour's
                # message except the parent's own.
                cur_h_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                    if node_y.idx != node_fa.idx
                ]
                if len(cur_h_nei) > 0:
                    cur_h_nei = torch.stack(cur_h_nei, dim=0).view(
                        1, -1, self.hidden_size)
                else:
                    cur_h_nei = zero_pad

                # cur_x is still node_x's embedding from the top of this step.
                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                            self.W_h)
                h[(node_x.idx, node_fa.idx)] = new_h[0]
                node_fa.neighbors.append(node_x)
                stack.pop()

        return root, all_nodes