Example #1
    def decode(self, x_tree_vecs, prob_decode):
        assert x_tree_vecs.size(0) == 1

        stack = []
        init_hiddens = create_var(torch.zeros(1, self.hidden_size))
        zero_pad = create_var(torch.zeros(1, 1, self.hidden_size))
        contexts = create_var(torch.LongTensor(1).zero_())

        #Root Prediction
        root_score = self.aggregate(init_hiddens, contexts, x_tree_vecs,
                                    'word')
        _, root_wid = torch.max(root_score, dim=1)
        root_wid = root_wid.item()

        root = MolTreeNode(self.vocab.get_smiles(root_wid))
        root.wid = root_wid
        root.idx = 0
        stack.append((root, self.vocab.get_slots(root.wid)))

        all_nodes = [root]
        h = {}
        for step in range(MAX_DECODE_LEN):
            node_x, fa_slot = stack[-1]
            cur_h_nei = [
                h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
            ]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei,
                                        dim=0).view(1, -1, self.hidden_size)
            else:
                cur_h_nei = zero_pad

            cur_x = create_var(torch.LongTensor([node_x.wid]))
            cur_x = self.embedding(cur_x)

            #Predict stop
            cur_h = cur_h_nei.sum(dim=1)
            stop_hiddens = torch.cat([cur_x, cur_h], dim=1)
            stop_hiddens = F.relu(self.U_i(stop_hiddens))
            stop_score = self.aggregate(stop_hiddens, contexts, x_tree_vecs,
                                        'stop')

            if prob_decode:
                backtrack = (torch.bernoulli(
                    torch.sigmoid(stop_score)).item() == 0)
            else:
                backtrack = (stop_score.item() < 0)

            if not backtrack:  #Forward: Predict next clique
                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                            self.W_h)
                pred_score = self.aggregate(new_h, contexts, x_tree_vecs,
                                            'word')

                if prob_decode:
                    sort_wid = torch.multinomial(
                        F.softmax(pred_score, dim=1).squeeze(), 5)
                else:
                    _, sort_wid = torch.sort(pred_score,
                                             dim=1,
                                             descending=True)
                    sort_wid = sort_wid.data.squeeze()

                next_wid = None
                for wid in sort_wid[:5]:
                    slots = self.vocab.get_slots(wid)
                    node_y = MolTreeNode(self.vocab.get_smiles(wid))
                    if have_slots(fa_slot, slots) and can_assemble(
                            node_x, node_y):
                        next_wid = wid
                        next_slots = slots
                        break

                if next_wid is None:
                    backtrack = True  #No more children can be added
                else:
                    node_y = MolTreeNode(self.vocab.get_smiles(next_wid))
                    node_y.wid = next_wid
                    node_y.idx = len(all_nodes)
                    node_y.neighbors.append(node_x)
                    h[(node_x.idx, node_y.idx)] = new_h[0]
                    stack.append((node_y, next_slots))
                    all_nodes.append(node_y)

            if backtrack:  #Backtrack, use if instead of else
                if len(stack) == 1:
                    break  #At root, terminate

                node_fa, _ = stack[-2]
                cur_h_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                    if node_y.idx != node_fa.idx
                ]
                if len(cur_h_nei) > 0:
                    cur_h_nei = torch.stack(cur_h_nei, dim=0).view(
                        1, -1, self.hidden_size)
                else:
                    cur_h_nei = zero_pad

                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                            self.W_h)
                h[(node_x.idx, node_fa.idx)] = new_h[0]
                node_fa.neighbors.append(node_x)
                stack.pop()

        return root, all_nodes
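A minimal usage sketch for the method above (illustrative only: decoder stands for a trained instance of this class, and the random tensor is a stand-in for an encoder output whose width matches what self.aggregate expects):

# Hypothetical usage; x_tree_vecs must have batch size 1, per the assert above.
x_tree_vecs = torch.randn(1, decoder.hidden_size)
root, all_nodes = decoder.decode(x_tree_vecs, prob_decode=False)  # greedy decoding
root, all_nodes = decoder.decode(x_tree_vecs, prob_decode=True)   # sampled decoding
print(len(all_nodes), root.smiles)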
Example #2
def get_subtree(tree, edge, x_node_vecs, x_mess_dict):
    subtree_list = {}
    node_tree_idx = {}
    node_list = {}
    # ========================= Get Subtree List ===============================
    
    tree.nodes[0].keep_neighbors = []
    for node in tree.nodes[1:]:
        fa_node = node.fa_node
        node_idx = node.idx
        idx = x_mess_dict[node.fa_node.idx, node.idx]
        if not edge[idx]:
            if fa_node in node_tree_idx:
                new_node = MolTreeNode(node.smiles)
                new_node.wid = node.wid
                new_node.neighbors = [node.fa_node.cnode]
                new_node.idx = node_idx

                node.cnode = new_node
                node.fa_node.cnode.neighbors.append(new_node)
                node_tree_idx[node] = node_tree_idx[node.fa_node]

                tree_node = node_tree_idx[node.fa_node]
                subtree_list[tree_node].add_node(new_node)
            else:
                new_fa_node = MolTreeNode(node.fa_node.smiles)
                new_fa_node.wid = fa_node.wid
                new_fa_node.idx = fa_node.idx
                new_node = MolTreeNode(node.smiles)
                new_node.wid = node.wid
                new_node.idx = node_idx
                new_fa_node.neighbors = [new_node]
                new_node.neighbors = [new_fa_node]

                node.cnode = new_node
                node.fa_node.cnode = new_fa_node

                subtree_list[new_fa_node] = Subtree(new_fa_node)
                subtree_list[new_fa_node].add_node(new_node)

                node_tree_idx[fa_node] = new_fa_node
                node_tree_idx[node] = new_fa_node

                if node.fa_node.wid in node_list:
                    node_list[node.fa_node.wid].append((new_fa_node, new_fa_node))
                else:
                    node_list[node.fa_node.wid] = [(new_fa_node, new_fa_node)]

            fa_node = node_tree_idx[node]
            if node.wid in node_list:
                node_list[node.wid].append((fa_node, node))
            else:
                node_list[node.wid] = [(fa_node, node)]

    # ========================= Subtree Embedding ==============================
    max_idx, max_num = 0, 0
    if len(subtree_list) > 1:
        for idx in subtree_list:
            if len(subtree_list[idx].nodes) > max_num:
                max_num = len(subtree_list[idx].nodes)
                max_idx = idx

        max_subtree = subtree_list[max_idx]
    else:
        max_subtree = subtree_list[list(subtree_list.keys())[0]]

    for i, node in enumerate(max_subtree.nodes):
        node.idx = i
        node.nid = i

    return subtree_list, max_subtree, node_tree_idx, node_list
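get_subtree relies on a Subtree container that is not defined in this example. A minimal sketch consistent with the calls above (Subtree(new_fa_node), add_node, and the .nodes attribute read in the embedding section) might look like this; the real class may carry more state:

class Subtree(object):
    # Assumed minimal container: a root plus the nodes collected for one
    # connected component of kept (non-cut) edges.
    def __init__(self, root):
        self.root = root
        self.nodes = [root]

    def add_node(self, node):
        self.nodes.append(node)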
Example #3
    def forward(self, mol_batch, mol_vec):
        super_root = MolTreeNode("")
        super_root.idx = -1

        #Initialize
        pred_hiddens, pred_mol_vecs, pred_targets = [], [], []
        stop_hiddens, stop_targets = [], []
        traces = []
        for mol_tree in mol_batch:
            s = []
            dfs(s, mol_tree.nodes[0], super_root)
            traces.append(s)
            for node in mol_tree.nodes:
                node.neighbors = []

        #Predict Root
        pred_hiddens.append(
            create_var(torch.zeros(len(mol_batch), self.hidden_size)))
        pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch])
        pred_mol_vecs.append(mol_vec)

        max_iter = max([len(tr) for tr in traces])
        padding = create_var(torch.zeros(self.hidden_size), False)
        h = {}

        for t in range(max_iter):
            prop_list = []
            batch_list = []
            for i, plist in enumerate(traces):
                if t < len(plist):
                    prop_list.append(plist[t])
                    batch_list.append(i)

            cur_x = []
            cur_h_nei, cur_o_nei = [], []

            for node_x, real_y, _ in prop_list:
                #Neighbors for message passing (target not included)
                cur_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                    if node_y.idx != real_y.idx
                ]
                pad_len = MAX_NB - len(cur_nei)
                cur_h_nei.extend(cur_nei)
                cur_h_nei.extend([padding] * pad_len)

                #Neighbors for stop prediction (all neighbors)
                cur_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                ]
                pad_len = MAX_NB - len(cur_nei)
                cur_o_nei.extend(cur_nei)
                cur_o_nei.extend([padding] * pad_len)

                #Current clique embedding
                cur_x.append(node_x.wid)

            #Clique embedding
            cur_x = create_var(torch.LongTensor(cur_x))
            cur_x = self.embedding(cur_x)

            #Message passing
            cur_h_nei = torch.stack(cur_h_nei,
                                    dim=0).view(-1, MAX_NB, self.hidden_size)
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                        self.W_h)

            #Node Aggregate
            cur_o_nei = torch.stack(cur_o_nei,
                                    dim=0).view(-1, MAX_NB, self.hidden_size)
            cur_o = cur_o_nei.sum(dim=1)

            #Gather targets
            pred_target, pred_list = [], []
            stop_target = []
            for i, m in enumerate(prop_list):
                node_x, node_y, direction = m
                x, y = node_x.idx, node_y.idx
                h[(x, y)] = new_h[i]
                node_y.neighbors.append(node_x)
                if direction == 1:
                    pred_target.append(node_y.wid)
                    pred_list.append(i)
                stop_target.append(direction)

            #Hidden states for stop prediction
            cur_batch = create_var(torch.LongTensor(batch_list))
            cur_mol_vec = mol_vec.index_select(0, cur_batch)
            stop_hidden = torch.cat([cur_x, cur_o, cur_mol_vec], dim=1)
            stop_hiddens.append(stop_hidden)
            stop_targets.extend(stop_target)

            #Hidden states for clique prediction
            if len(pred_list) > 0:
                batch_list = [batch_list[i] for i in pred_list]
                cur_batch = create_var(torch.LongTensor(batch_list))
                pred_mol_vecs.append(mol_vec.index_select(0, cur_batch))

                cur_pred = create_var(torch.LongTensor(pred_list))
                pred_hiddens.append(new_h.index_select(0, cur_pred))
                pred_targets.extend(pred_target)

        #Last stop at root
        cur_x, cur_o_nei = [], []
        for mol_tree in mol_batch:
            node_x = mol_tree.nodes[0]
            cur_x.append(node_x.wid)
            cur_nei = [
                h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
            ]
            pad_len = MAX_NB - len(cur_nei)
            cur_o_nei.extend(cur_nei)
            cur_o_nei.extend([padding] * pad_len)

        cur_x = create_var(torch.LongTensor(cur_x))
        cur_x = self.embedding(cur_x)
        cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB,
                                                       self.hidden_size)
        cur_o = cur_o_nei.sum(dim=1)

        stop_hidden = torch.cat([cur_x, cur_o, mol_vec], dim=1)
        stop_hiddens.append(stop_hidden)
        stop_targets.extend([0] * len(mol_batch))

        #Predict next clique
        pred_hiddens = torch.cat(pred_hiddens, dim=0)
        pred_mol_vecs = torch.cat(pred_mol_vecs, dim=0)
        pred_vecs = torch.cat([pred_hiddens, pred_mol_vecs], dim=1)
        pred_vecs = nn.ReLU()(self.W(pred_vecs))
        pred_scores = self.W_o(pred_vecs)
        pred_targets = create_var(torch.LongTensor(pred_targets))

        pred_loss = self.pred_loss(pred_scores, pred_targets) / len(mol_batch)
        _, preds = torch.max(pred_scores, dim=1)
        pred_acc = torch.eq(preds, pred_targets).float()
        pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

        #Predict stop
        stop_hiddens = torch.cat(stop_hiddens, dim=0)
        stop_vecs = nn.ReLU()(self.U(stop_hiddens))
        stop_scores = self.U_s(stop_vecs).squeeze()
        stop_targets = create_var(torch.Tensor(stop_targets))

        stop_loss = self.stop_loss(stop_scores, stop_targets) / len(mol_batch)
        stops = torch.ge(stop_scores, 0).float()
        stop_acc = torch.eq(stops, stop_targets).float()
        stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

        return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()
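A hedged sketch of how this teacher-forced forward pass might sit inside a training step; the decoder and optimizer names and the surrounding data pipeline are assumptions, not part of the original code:

# Hypothetical training step: mol_batch is a list of MolTree objects,
# mol_vec the corresponding latent vectors produced by an encoder.
pred_loss, stop_loss, pred_acc, stop_acc = decoder(mol_batch, mol_vec)
loss = pred_loss + stop_loss   # clique-label loss plus topology (stop) loss
optimizer.zero_grad()
loss.backward()
optimizer.step()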
Example #4
    def get_trace(self, node):
        super_root = MolTreeNode("")
        super_root.idx = -1
        trace = []
        dfs(trace, node, super_root)
        return [(x.smiles, y.smiles, z) for x, y, z in trace]
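dfs fills trace with (node_x, node_y, direction) triples, where direction is 1 for an expand step and 0 for a backtrack, so the returned list replays the generation order as SMILES pairs. A hypothetical call, assuming a parsed MolTree:

trace = model.get_trace(mol_tree.nodes[0])
for x_smiles, y_smiles, d in trace:
    print(x_smiles, '->', y_smiles, 'expand' if d == 1 else 'backtrack')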
Example #5
    def decode(self, mol_vec, prob_decode):
        stack, trace = [], []
        init_hidden = create_var(torch.zeros(1, self.hidden_size))
        zero_pad = create_var(torch.zeros(1, 1, self.hidden_size))

        #Root Prediction
        root_hidden = torch.cat([init_hidden, mol_vec], dim=1)
        root_hidden = nn.ReLU()(self.W(root_hidden))
        root_score = self.W_o(root_hidden)
        _, root_wid = torch.max(root_score, dim=1)
        root_wid = root_wid.item()

        root = MolTreeNode(self.vocab.get_smiles(root_wid))
        root.wid = root_wid
        root.idx = 0
        stack.append((root, self.vocab.get_slots(root.wid)))

        all_nodes = [root]
        h = {}
        for step in range(MAX_DECODE_LEN):
            node_x, fa_slot = stack[-1]
            cur_h_nei = [
                h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
            ]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei,
                                        dim=0).view(1, -1, self.hidden_size)
            else:
                cur_h_nei = zero_pad

            cur_x = create_var(torch.LongTensor([node_x.wid]))
            cur_x = self.embedding(cur_x)

            #Predict stop
            cur_h = cur_h_nei.sum(dim=1)
            stop_hidden = torch.cat([cur_x, cur_h, mol_vec], dim=1)
            stop_hidden = nn.ReLU()(self.U(stop_hidden))
            stop_score = nn.Sigmoid()(self.U_s(stop_hidden) * 20).squeeze()

            if prob_decode:
                backtrack = (torch.bernoulli(1.0 -
                                             stop_score.data).item() == 1)
            else:
                backtrack = (stop_score.item() < 0.5)

            if not backtrack:  #Forward: Predict next clique
                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                            self.W_h)
                pred_hidden = torch.cat([new_h, mol_vec], dim=1)
                pred_hidden = nn.ReLU()(self.W(pred_hidden))
                pred_score = nn.Softmax(dim=1)(self.W_o(pred_hidden) * 20)
                if prob_decode:
                    sort_wid = torch.multinomial(pred_score.data.squeeze(), 5)
                else:
                    _, sort_wid = torch.sort(pred_score,
                                             dim=1,
                                             descending=True)
                    sort_wid = sort_wid.data.squeeze()

                next_wid = None
                for wid in sort_wid[:5]:
                    slots = self.vocab.get_slots(wid)
                    node_y = MolTreeNode(self.vocab.get_smiles(wid))
                    if have_slots(fa_slot, slots) and can_assemble(
                            node_x, node_y):
                        next_wid = wid
                        next_slots = slots
                        break

                if next_wid is None:
                    backtrack = True  #No more children can be added
                else:
                    node_y = MolTreeNode(self.vocab.get_smiles(next_wid))
                    node_y.wid = next_wid
                    node_y.idx = step + 1
                    node_y.neighbors.append(node_x)
                    h[(node_x.idx, node_y.idx)] = new_h[0]
                    stack.append((node_y, next_slots))
                    all_nodes.append(node_y)

            if backtrack:  #Backtrack, use if instead of else
                if len(stack) == 1:
                    break  #At root, terminate

                node_fa, _ = stack[-2]
                cur_h_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                    if node_y.idx != node_fa.idx
                ]
                if len(cur_h_nei) > 0:
                    cur_h_nei = torch.stack(cur_h_nei, dim=0).view(
                        1, -1, self.hidden_size)
                else:
                    cur_h_nei = zero_pad

                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                            self.W_h)
                h[(node_x.idx, node_fa.idx)] = new_h[0]
                node_fa.neighbors.append(node_x)
                stack.pop()

        return root, all_nodes
Example #6
    def forward(self, mol_batch, mol_vec):
        super_root = MolTreeNode('')
        super_root.idx = -1

        # Initialize
        pred_hiddens, pred_mol_vecs, pred_targets = [], [], []
        stop_hiddens, stop_targets = [], []
        traces = []
        for mol_tree in mol_batch:
            s = []
            dfs(s, mol_tree.nodes[0], super_root)
            traces.append(s)
            for node in mol_tree.nodes:
                node.neighbors = []

        pred_hiddens.append(
            create_var(torch.zeros(len(mol_batch), self.hidden_size)))
        pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch])
        pred_mol_vecs.append(mol_vec)

        max_iter = max([len(tr) for tr in traces])
        padding = create_var(torch.zeros(self.hidden_size), False)
        h = {}

        for t in range(max_iter):
            prop_list = []
            batch_list = []
            for i, plist in enumerate(traces):
                if len(plist) > t:
                    prop_list.append(plist[t])
                    batch_list.append(i)

            cur_x = []
            cur_h_nei, cur_o_nei = [], []

            for node_x, real_y, _ in prop_list:
                # Neighbors for message passing (target not included)
                cur_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                    if node_y.idx != real_y.idx
                ]
                pad_len = MAX_NB - len(cur_nei)
                cur_h_nei.extend(cur_nei)
                cur_h_nei.extend([padding] * pad_len)

                # Neighbors for stop prediction (all neighbors)
                cur_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                ]
                pad_len = MAX_NB - len(cur_nei)
                cur_o_nei.extend(cur_nei)
                cur_o_nei.extend([padding] * pad_len)

                cur_x.append(node_x.wid)

            cur_x = create_var(torch.LongTensor(cur_x))
            cur_x = self.embedding(cur_x)
            cur_h_nei = torch.stack(cur_h_nei,
                                    dim=0).view(-1, MAX_NB, self.hidden_size)
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                        self.W_h)

            cur_o_nei = torch.stack(cur_o_nei,
                                    dim=0).view(-1, MAX_NB, self.hidden_size)
            cur_o = cur_o_nei.sum(dim=1)

            pred_target, pred_list = [], []
            stop_target = []
            for i, m in enumerate(prop_list):
                node_x, node_y, direction = m
                x, y = node_x.idx, node_y.idx
                h[(x, y)] = new_h[i]
                node_y.neighbors.append(node_x)
                if direction == 1:
                    pred_target.append(node_y.wid)
                    pred_list.append(i)
                stop_target.append(direction)

            cur_batch = create_var(torch.LongTensor(batch_list))
            cur_mol_vec = mol_vec.index_select(0, cur_batch)
            stop_hidden = torch.cat([cur_x, cur_o, cur_mol_vec], dim=1)
            stop_hiddens.append(stop_hidden)
            stop_targets.extend(stop_target)

            if len(pred_list) > 0:
                batch_list = [batch_list[i] for i in pred_list]
                cur_batch = create_var(torch.LongTensor(batch_list))
                pred_mol_vecs.append(mol_vec.index_select(0, cur_batch))

                cur_pred = create_var(torch.LongTensor(pred_list))
                pred_hiddens.append(new_h.index_select(0, cur_pred))
                pred_targets.extend(pred_target)

        cur_x, cur_o_nei = [], []
        for mol_tree in mol_batch:
            node_x = mol_tree.nodes[0]
            cur_x.append(node_x.wid)
            cur_nei = [
                h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
            ]
            pad_len = MAX_NB - len(cur_nei)
            cur_o_nei.extend(cur_nei)
            cur_o_nei.extend([padding] * pad_len)

        cur_x = create_var(torch.LongTensor(cur_x))
        cur_x = self.embedding(cur_x)
        cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB,
                                                       self.hidden_size)
        cur_o = cur_o_nei.sum(dim=1)

        stop_hidden = torch.cat([cur_x, cur_o, mol_vec], dim=1)
        stop_hiddens.append(stop_hidden)
        stop_targets.extend([0] * len(mol_batch))

        pred_hiddens = torch.cat(pred_hiddens, dim=0)
        pred_mol_vecs = torch.cat(pred_mol_vecs, dim=0)
        pred_vecs = torch.cat([pred_hiddens, pred_mol_vecs], dim=1)
        pred_vecs = nn.ReLU()(self.W(pred_vecs))
        pred_scores = self.W_o(pred_vecs)
        pred_targets = create_var(torch.LongTensor(pred_targets))

        pred_loss = self.pred_loss(pred_scores, pred_targets) / len(mol_batch)
        _, preds = torch.max(pred_scores, dim=1)
        pred_acc = torch.eq(preds, pred_targets).float()
        pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

        stop_hiddens = torch.cat(stop_hiddens, dim=0)
        stop_vecs = nn.ReLU()(self.U(stop_hiddens))
        stop_scores = self.U_s(stop_vecs).squeeze()
        stop_targets = create_var(torch.Tensor(stop_targets))

        stop_loss = self.stop_loss(stop_scores, stop_targets) / len(mol_batch)
        stops = torch.ge(stop_scores, 0).float()
        stop_acc = torch.eq(stops, stop_targets).float()
        stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

        return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()
Example #7
    def soft_decode(self, x_tree_vecs, x_mol_vecs, gumbel, slope, temp):
        assert x_tree_vecs.size(0) == 1

        soft_embedding = lambda x: x.matmul(self.embedding.weight)
        if gumbel:
            sample_softmax = lambda x: F.gumbel_softmax(x, tau=temp)
        else:
            sample_softmax = lambda x: F.softmax(x / temp, dim=1)

        stack = []
        init_hiddens = create_var(torch.zeros(1, self.hidden_size))
        zero_pad = create_var(torch.zeros(1, 1, self.hidden_size))
        contexts = create_var(torch.LongTensor(1).zero_())

        #Root Prediction
        root_score = self.attention(init_hiddens, contexts, x_tree_vecs,
                                    x_mol_vecs, 'word')
        root_prob = sample_softmax(root_score)

        root = MolTreeNode("")
        root.embedding = soft_embedding(root_prob)
        root.prob = root_prob
        root.idx = 0
        stack.append(root)

        all_nodes = [root]
        all_hiddens = []
        h = {}
        for step in range(MAX_SOFT_DECODE_LEN):
            node_x = stack[-1]
            cur_h_nei = [
                h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
            ]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei,
                                        dim=0).view(1, -1, self.hidden_size)
            else:
                cur_h_nei = zero_pad

            #Predict stop
            cur_x = node_x.embedding
            cur_h = cur_h_nei.sum(dim=1)
            stop_hiddens = torch.cat([cur_x, cur_h], dim=1)
            stop_hiddens = F.relu(self.U_i(stop_hiddens))
            stop_score = self.attention(stop_hiddens, contexts, x_tree_vecs,
                                        x_mol_vecs, 'stop')
            all_hiddens.append(stop_hiddens)

            forward = 0 if stop_score.item() < 0 else 1
            stop_prob = F.hardtanh(slope * stop_score + 0.5,
                                   min_val=0,
                                   max_val=1).unsqueeze(1)
            stop_val_ste = forward + stop_prob - stop_prob.detach()

            if forward == 1:  #Forward: Predict next clique
                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                            self.W_h)
                pred_score = self.attention(new_h, contexts, x_tree_vecs,
                                            x_mol_vecs, 'word')
                pred_prob = sample_softmax(pred_score)

                node_y = MolTreeNode("")
                node_y.embedding = soft_embedding(pred_prob)
                node_y.prob = pred_prob
                node_y.idx = len(all_nodes)
                node_y.neighbors.append(node_x)

                h[(node_x.idx, node_y.idx)] = new_h[0] * stop_val_ste
                stack.append(node_y)
                all_nodes.append(node_y)
            else:
                if len(stack) == 1:  #At root, terminate
                    return torch.cat([cur_x, cur_h], dim=1), all_nodes

                node_fa = stack[-2]
                cur_h_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                    if node_y.idx != node_fa.idx
                ]
                if len(cur_h_nei) > 0:
                    cur_h_nei = torch.stack(cur_h_nei, dim=0).view(
                        1, -1, self.hidden_size)
                else:
                    cur_h_nei = zero_pad

                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                            self.W_h)
                h[(node_x.idx, node_fa.idx)] = new_h[0] * (1.0 - stop_val_ste)
                node_fa.neighbors.append(node_x)
                stack.pop()

        #Failure mode: decoding unfinished
        cur_h_nei = [h[(node_y.idx, root.idx)] for node_y in root.neighbors]
        if len(cur_h_nei) > 0:
            cur_h_nei = torch.stack(cur_h_nei,
                                    dim=0).view(1, -1, self.hidden_size)
        else:
            cur_h_nei = zero_pad
        cur_h = cur_h_nei.sum(dim=1)

        stop_hiddens = torch.cat([root.embedding, cur_h], dim=1)
        stop_hiddens = F.relu(self.U_i(stop_hiddens))
        all_hiddens.append(stop_hiddens)

        return torch.cat([root.embedding, cur_h], dim=1), all_nodes
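The stop_val_ste line above is a straight-through estimator: the forward pass uses the hard 0/1 stop decision while gradients flow through the soft hardtanh surrogate. A self-contained illustration of the trick, separate from the decoder (the score and slope values are chosen so the surrogate is not saturated):

import torch
import torch.nn.functional as F

score = torch.tensor(0.05, requires_grad=True)
hard = 0 if score.item() < 0 else 1                         # hard decision, no gradient
soft = F.hardtanh(5.0 * score + 0.5, min_val=0, max_val=1)  # differentiable surrogate
ste = hard + soft - soft.detach()                           # forward value equals hard
ste.backward()
print(ste.item(), score.grad.item())  # 1.0 and 5.0: hard value, surrogate gradient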