예제 #1
0
    def forward(self, root_batch):
        """Encode a batch of junction trees by iterative message passing.

        For each root, a layer-by-layer propagation order is computed
        (bottom-up followed by top-down); messages along directed edges are
        updated level by level with a tree GRU.

        Args:
            root_batch: list of root nodes; each node exposes ``.idx``
                (node id), ``.wid`` (vocabulary id) and ``.neighbors``.

        Returns:
            h: dict mapping a directed edge ``(src_idx, dst_idx)`` to its
                message vector of size ``hidden_size``.
            root_vecs: aggregated per-root representations for the batch.
        """
        # One propagation schedule per tree in the batch.
        orders = [get_prop_order(root) for root in root_batch]

        h = {}
        # Zero vector used to pad each node's neighbor-message list to MAX_NB.
        padding = create_var(torch.zeros(self.hidden_size), False)
        max_depth = max(len(order) for order in orders)

        for depth in range(max_depth):
            # All (node, target) pairs at this depth, across every tree that
            # actually has a level `depth`.
            prop_list = []
            for order in orders:
                if depth < len(order):
                    prop_list.extend(order[depth])

            cur_x = []
            cur_h_nei = []
            for node_x, node_y in prop_list:
                x_idx, y_idx = node_x.idx, node_y.idx
                cur_x.append(node_x.wid)

                # Incoming messages from every neighbor of x except y itself,
                # padded up to exactly MAX_NB entries.
                nei_msgs = [
                    h[(node_z.idx, x_idx)]
                    for node_z in node_x.neighbors
                    if node_z.idx != y_idx
                ]
                nei_msgs.extend([padding] * (MAX_NB - len(nei_msgs)))
                cur_h_nei.extend(nei_msgs)

            # Embed node labels: shape (len(prop_list), hidden_size).
            cur_x = self.embedding(create_var(torch.LongTensor(cur_x)))
            # Shape (len(prop_list), MAX_NB, hidden_size).
            cur_h_nei = torch.cat(cur_h_nei,
                                  dim=0).view(-1, MAX_NB, self.hidden_size)

            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                        self.W_h)
            for i, (node_x, node_y) in enumerate(prop_list):
                h[(node_x.idx, node_y.idx)] = new_h[i]

        # Aggregate edge messages into one vector per root.
        root_vecs = node_aggregate(root_batch, h, self.embedding, self.W)
        return h, root_vecs
예제 #2
0
    def forward(self, root_batch):
        """Encode a batch of junction trees by iterative message passing.

        Args:
            root_batch: list of root nodes; each node exposes ``.idx``
                (node id), ``.wid`` (vocabulary id) and ``.neighbors``.

        Returns:
            h: dict mapping a directed edge ``(src_idx, dst_idx)`` to its
                message vector of size ``hidden_size``.
            root_vecs: aggregated per-root representations for the batch.
        """
        orders = []
        for root in root_batch:
            order = get_prop_order(root)
            orders.append(order)

        h = {}
        max_depth = max([len(x) for x in orders])
        # Zero vector used to pad each node's neighbor-message list to MAX_NB.
        padding = create_var(torch.zeros(self.hidden_size), False)

        for t in range(max_depth):
            # Collect all (node, target) pairs at depth t across the batch.
            prop_list = []
            for order in orders:
                if t < len(order):
                    prop_list.extend(order[t])

            cur_x = []
            cur_h_nei = []
            for node_x, node_y in prop_list:
                x, y = node_x.idx, node_y.idx
                cur_x.append(node_x.wid)

                # Incoming messages from all neighbors of x except y itself.
                h_nei = []
                for node_z in node_x.neighbors:
                    z = node_z.idx
                    if z == y: continue
                    h_nei.append(h[(z, x)])

                # Pad the neighbor list up to exactly MAX_NB entries.
                pad_len = MAX_NB - len(h_nei)
                h_nei.extend([padding] * pad_len)
                cur_h_nei.extend(h_nei)

            # Embed node labels: shape (len(prop_list), hidden_size).
            cur_x = create_var(torch.LongTensor(cur_x))
            cur_x = self.embedding(cur_x)
            # Shape (len(prop_list), MAX_NB, hidden_size).
            cur_h_nei = torch.cat(cur_h_nei, dim=0).view(-1, MAX_NB, self.hidden_size)

            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            for i, m in enumerate(prop_list):
                x, y = m[0].idx, m[1].idx
                h[(x, y)] = new_h[i]

        root_vecs = node_aggregate(root_batch, h, self.embedding, self.W)

        return h, root_vecs
예제 #3
0
    def forward(self, root_batch):
        """Encode a batch of junction trees by iterative message passing.

        Device-aware variant: tensors are created directly on ``self.device``.

        Args:
            root_batch: list of root nodes; each node exposes ``.idx``
                (node id), ``.wid`` (vocabulary id) and ``.neighbors``.

        Returns:
            h: dict mapping a directed edge ``(src_idx, dst_idx)`` to its
                message vector of size ``hidden_size``.
            root_vecs: aggregated per-root representations for the batch.
        """
        orders = []
        for root in root_batch:
            order = get_prop_order(root)
            orders.append(order)

        h = {}
        max_depth = max([len(x) for x in orders])
        # Zero vector used to pad each node's neighbor-message list to MAX_NB.
        padding = torch.zeros(self.hidden_size).to(self.device)

        # BUGFIX: `xrange` does not exist in Python 3, and this method already
        # uses Python-3-era APIs (torch.tensor, Tensor.to) — use `range`.
        for t in range(max_depth):
            prop_list = []
            for order in orders:
                if t < len(order):
                    prop_list.extend(order[t])

            cur_x = []
            cur_h_nei = []
            for node_x, node_y in prop_list:
                x, y = node_x.idx, node_y.idx
                cur_x.append(node_x.wid)

                # Incoming messages from all neighbors of x except y itself.
                h_nei = []
                for node_z in node_x.neighbors:
                    z = node_z.idx
                    if z == y: continue
                    h_nei.append(h[(z, x)])

                pad_len = MAX_NB - len(h_nei)
                h_nei.extend([padding] * pad_len)
                cur_h_nei.extend(h_nei)

            cur_x = torch.tensor(cur_x, dtype=torch.long).to(self.device)
            cur_x = self.embedding(cur_x)
            # Shape (len(prop_list), MAX_NB, hidden_size).
            cur_h_nei = torch.cat(cur_h_nei,
                                  dim=0).view(-1, MAX_NB, self.hidden_size)

            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                        self.W_h, self.device)
            for i, m in enumerate(prop_list):
                x, y = m[0].idx, m[1].idx
                h[(x, y)] = new_h[i]

        root_vecs = node_aggregate(root_batch, h, self.embedding, self.W,
                                   self.device)

        return h, root_vecs
예제 #4
0
    def forward(self, holder, depth):
        """Run `depth` rounds of batched message passing over a flat graph.

        Args:
            holder: tuple of (fnode, fmess, node_graph, mess_graph, scope)
                index tensors describing the batched trees; ``scope`` holds
                (start, length) pairs locating each tree's nodes.
            depth: number of GRU message-passing iterations.

        Returns:
            Tensor of stacked root vectors, one row per tree in the batch.
        """
        fnode = create_var(holder[0])
        fmess = create_var(holder[1])
        node_graph = create_var(holder[2])
        mess_graph = create_var(holder[3])
        scope = holder[4]

        fnode = self.embedding(fnode)
        x = index_select_ND(fnode, 0, fmess)
        h = create_var(torch.zeros(mess_graph.size(0), self.hidden_size))

        # Message slot 0 is a padding entry; mask keeps it zero every round.
        mask = torch.ones(h.size(0), 1)
        mask[0] = 0  #first vector is padding
        mask = create_var(mask)

        # BUGFIX: `range` instead of Python-2-only `xrange` (same behavior
        # on Python 2, required on Python 3).
        for it in range(depth):
            h_nei = index_select_ND(h, 0, mess_graph)
            h = GRU(x, h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            h = h * mask

        # Node representation = own embedding concatenated with the sum of
        # incoming messages; the first node of each scope is the root.
        mess_nei = index_select_ND(h, 0, node_graph)
        node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
        root_vecs = [node_vecs[st] for st, le in scope]
        return torch.stack(root_vecs, dim=0)
예제 #5
0
    def forward(self, mol_batch, x_tree_vecs):
        """Teacher-forced training pass of the tree decoder.

        Replays the DFS generation trace of every tree, accumulating hidden
        states and targets for two classifiers: next-clique ("word")
        prediction and stop prediction.

        Args:
            mol_batch: list of MolTree objects; ``nodes[0]`` is the root.
            x_tree_vecs: tree context vectors consumed by ``self.aggregate``.

        Returns:
            (pred_loss, stop_loss, pred_acc, stop_acc): per-molecule-averaged
            losses and the corresponding accuracies as Python floats.
        """
        pred_hiddens, pred_contexts, pred_targets = [], [], []
        stop_hiddens, stop_contexts, stop_targets = [], [], []
        traces = []
        for mol_tree in mol_batch:
            s = []
            dfs(s, mol_tree.nodes[0], -1)
            traces.append(s)
            # Neighbors are rebuilt incrementally while replaying the trace.
            for node in mol_tree.nodes:
                node.neighbors = []

        #Predict Root
        batch_size = len(mol_batch)
        pred_hiddens.append(
            create_var(torch.zeros(len(mol_batch), self.hidden_size)))
        pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch])
        pred_contexts.append(create_var(torch.LongTensor(range(batch_size))))

        max_iter = max([len(tr) for tr in traces])
        padding = create_var(torch.zeros(self.hidden_size), False)
        h = {}

        # BUGFIX: `range` instead of Python-2-only `xrange`.
        for t in range(max_iter):
            prop_list = []
            batch_list = []
            for i, plist in enumerate(traces):
                if t < len(plist):
                    prop_list.append(plist[t])
                    batch_list.append(i)

            cur_x = []
            cur_h_nei, cur_o_nei = [], []

            for node_x, real_y, _ in prop_list:
                #Neighbors for message passing (target not included)
                cur_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                    if node_y.idx != real_y.idx
                ]
                pad_len = MAX_NB - len(cur_nei)
                cur_h_nei.extend(cur_nei)
                cur_h_nei.extend([padding] * pad_len)

                #Neighbors for stop prediction (all neighbors)
                cur_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                ]
                pad_len = MAX_NB - len(cur_nei)
                cur_o_nei.extend(cur_nei)
                cur_o_nei.extend([padding] * pad_len)

                #Current clique embedding
                cur_x.append(node_x.wid)

            #Clique embedding
            cur_x = create_var(torch.LongTensor(cur_x))
            cur_x = self.embedding(cur_x)

            #Message passing
            cur_h_nei = torch.stack(cur_h_nei,
                                    dim=0).view(-1, MAX_NB, self.hidden_size)
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                        self.W_h)

            #Node Aggregate
            cur_o_nei = torch.stack(cur_o_nei,
                                    dim=0).view(-1, MAX_NB, self.hidden_size)
            cur_o = cur_o_nei.sum(dim=1)

            #Gather targets
            pred_target, pred_list = [], []
            stop_target = []
            for i, m in enumerate(prop_list):
                node_x, node_y, direction = m
                x, y = node_x.idx, node_y.idx
                h[(x, y)] = new_h[i]
                node_y.neighbors.append(node_x)
                # direction == 1 means an expand step: a word target exists.
                if direction == 1:
                    pred_target.append(node_y.wid)
                    pred_list.append(i)
                stop_target.append(direction)

            #Hidden states for stop prediction
            cur_batch = create_var(torch.LongTensor(batch_list))
            stop_hidden = torch.cat([cur_x, cur_o], dim=1)
            stop_hiddens.append(stop_hidden)
            stop_contexts.append(cur_batch)
            stop_targets.extend(stop_target)

            #Hidden states for clique prediction
            if len(pred_list) > 0:
                batch_list = [batch_list[i] for i in pred_list]
                cur_batch = create_var(torch.LongTensor(batch_list))
                pred_contexts.append(cur_batch)

                cur_pred = create_var(torch.LongTensor(pred_list))
                pred_hiddens.append(new_h.index_select(0, cur_pred))
                pred_targets.extend(pred_target)

        #Last stop at root
        cur_x, cur_o_nei = [], []
        for mol_tree in mol_batch:
            node_x = mol_tree.nodes[0]
            cur_x.append(node_x.wid)
            cur_nei = [
                h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
            ]
            pad_len = MAX_NB - len(cur_nei)
            cur_o_nei.extend(cur_nei)
            cur_o_nei.extend([padding] * pad_len)

        cur_x = create_var(torch.LongTensor(cur_x))
        cur_x = self.embedding(cur_x)
        cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB,
                                                       self.hidden_size)
        cur_o = cur_o_nei.sum(dim=1)

        stop_hidden = torch.cat([cur_x, cur_o], dim=1)
        stop_hiddens.append(stop_hidden)
        stop_contexts.append(create_var(torch.LongTensor(range(batch_size))))
        stop_targets.extend([0] * len(mol_batch))

        #Predict next clique
        pred_contexts = torch.cat(pred_contexts, dim=0)
        pred_hiddens = torch.cat(pred_hiddens, dim=0)
        pred_scores = self.aggregate(pred_hiddens, pred_contexts, x_tree_vecs,
                                     'word')
        pred_targets = create_var(torch.LongTensor(pred_targets))

        pred_loss = self.pred_loss(pred_scores, pred_targets) / len(mol_batch)
        _, preds = torch.max(pred_scores, dim=1)
        pred_acc = torch.eq(preds, pred_targets).float()
        pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

        #Predict stop
        stop_contexts = torch.cat(stop_contexts, dim=0)
        stop_hiddens = torch.cat(stop_hiddens, dim=0)
        stop_hiddens = F.relu(self.U_i(stop_hiddens))
        stop_scores = self.aggregate(stop_hiddens, stop_contexts, x_tree_vecs,
                                     'stop')
        stop_scores = stop_scores.squeeze(-1)
        stop_targets = create_var(torch.Tensor(stop_targets))

        stop_loss = self.stop_loss(stop_scores, stop_targets) / len(mol_batch)
        stops = torch.ge(stop_scores, 0).float()
        stop_acc = torch.eq(stops, stop_targets).float()
        stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

        return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()
예제 #6
0
    def decode(self, x_tree_vecs, prob_decode):
        """Greedy/stochastic decoding of a single junction tree.

        Starting from a predicted root clique, alternately predicts whether to
        expand (stop classifier) and which clique to attach next (word
        classifier), backtracking along a DFS stack until the root is reached.

        Args:
            x_tree_vecs: tree context vector, batch size must be 1.
            prob_decode: if True, sample stop/word decisions; otherwise argmax.

        Returns:
            (root, all_nodes): the root MolTreeNode and every decoded node.
        """
        assert x_tree_vecs.size(0) == 1

        stack = []
        init_hiddens = create_var(torch.zeros(1, self.hidden_size))
        zero_pad = create_var(torch.zeros(1, 1, self.hidden_size))
        contexts = create_var(torch.LongTensor(1).zero_())

        #Root Prediction
        root_score = self.aggregate(init_hiddens, contexts, x_tree_vecs,
                                    'word')
        _, root_wid = torch.max(root_score, dim=1)
        root_wid = root_wid.item()

        root = MolTreeNode(self.vocab.get_smiles(root_wid))
        root.wid = root_wid
        root.idx = 0
        stack.append((root, self.vocab.get_slots(root.wid)))

        all_nodes = [root]
        h = {}
        # BUGFIX: `range` instead of Python-2-only `xrange`.
        for step in range(MAX_DECODE_LEN):
            node_x, fa_slot = stack[-1]
            cur_h_nei = [
                h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
            ]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei,
                                        dim=0).view(1, -1, self.hidden_size)
            else:
                cur_h_nei = zero_pad

            cur_x = create_var(torch.LongTensor([node_x.wid]))
            cur_x = self.embedding(cur_x)

            #Predict stop
            cur_h = cur_h_nei.sum(dim=1)
            stop_hiddens = torch.cat([cur_x, cur_h], dim=1)
            stop_hiddens = F.relu(self.U_i(stop_hiddens))
            stop_score = self.aggregate(stop_hiddens, contexts, x_tree_vecs,
                                        'stop')

            if prob_decode:
                backtrack = (torch.bernoulli(
                    torch.sigmoid(stop_score)).item() == 0)
            else:
                backtrack = (stop_score.item() < 0)

            if not backtrack:  #Forward: Predict next clique
                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                            self.W_h)
                pred_score = self.aggregate(new_h, contexts, x_tree_vecs,
                                            'word')

                if prob_decode:
                    sort_wid = torch.multinomial(
                        F.softmax(pred_score, dim=1).squeeze(), 5)
                else:
                    _, sort_wid = torch.sort(pred_score,
                                             dim=1,
                                             descending=True)
                    sort_wid = sort_wid.data.squeeze()

                # Try the top-5 candidates until one is chemically attachable.
                next_wid = None
                for wid in sort_wid[:5]:
                    slots = self.vocab.get_slots(wid)
                    node_y = MolTreeNode(self.vocab.get_smiles(wid))
                    if have_slots(fa_slot, slots) and can_assemble(
                            node_x, node_y):
                        next_wid = wid
                        next_slots = slots
                        break

                if next_wid is None:
                    backtrack = True  #No more children can be added
                else:
                    node_y = MolTreeNode(self.vocab.get_smiles(next_wid))
                    node_y.wid = next_wid
                    node_y.idx = len(all_nodes)
                    node_y.neighbors.append(node_x)
                    h[(node_x.idx, node_y.idx)] = new_h[0]
                    stack.append((node_y, next_slots))
                    all_nodes.append(node_y)

            if backtrack:  #Backtrack, use if instead of else
                if len(stack) == 1:
                    break  #At root, terminate

                node_fa, _ = stack[-2]
                cur_h_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                    if node_y.idx != node_fa.idx
                ]
                if len(cur_h_nei) > 0:
                    cur_h_nei = torch.stack(cur_h_nei, dim=0).view(
                        1, -1, self.hidden_size)
                else:
                    cur_h_nei = zero_pad

                # Send the child->parent message before popping the stack.
                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                            self.W_h)
                h[(node_x.idx, node_fa.idx)] = new_h[0]
                node_fa.neighbors.append(node_x)
                stack.pop()

        return root, all_nodes
예제 #7
0
    def decode(self, mol_vec, prob_decode):
        """Greedy/stochastic decoding of a single junction tree from mol_vec.

        Older variant that conditions on ``mol_vec`` via concatenation rather
        than an attention-style ``aggregate`` call; otherwise the same DFS
        expand/backtrack loop.

        Args:
            mol_vec: latent molecule vector, batch size 1.
            prob_decode: if True, sample stop/word decisions; otherwise argmax.

        Returns:
            (root, all_nodes): the root MolTreeNode and every decoded node.
        """
        stack, trace = [], []
        init_hidden = create_var(torch.zeros(1, self.hidden_size))
        zero_pad = create_var(torch.zeros(1, 1, self.hidden_size))

        #Root Prediction
        root_hidden = torch.cat([init_hidden, mol_vec], dim=1)
        root_hidden = nn.ReLU()(self.W(root_hidden))
        root_score = self.W_o(root_hidden)
        _, root_wid = torch.max(root_score, dim=1)
        root_wid = root_wid.item()

        root = MolTreeNode(self.vocab.get_smiles(root_wid))
        root.wid = root_wid
        root.idx = 0
        stack.append((root, self.vocab.get_slots(root.wid)))

        all_nodes = [root]
        h = {}
        # BUGFIX: `range` instead of Python-2-only `xrange`.
        for step in range(MAX_DECODE_LEN):
            node_x, fa_slot = stack[-1]
            cur_h_nei = [
                h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
            ]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei,
                                        dim=0).view(1, -1, self.hidden_size)
            else:
                cur_h_nei = zero_pad

            cur_x = create_var(torch.LongTensor([node_x.wid]))
            cur_x = self.embedding(cur_x)

            #Predict stop
            cur_h = cur_h_nei.sum(dim=1)
            stop_hidden = torch.cat([cur_x, cur_h, mol_vec], dim=1)
            stop_hidden = nn.ReLU()(self.U(stop_hidden))
            # *20 sharpens the sigmoid toward a hard 0/1 decision.
            stop_score = nn.Sigmoid()(self.U_s(stop_hidden) * 20).squeeze()

            if prob_decode:
                backtrack = (torch.bernoulli(1.0 - stop_score.data)[0] == 1)
            else:
                backtrack = (stop_score.item() < 0.5)

            if not backtrack:  #Forward: Predict next clique
                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                            self.W_h)
                pred_hidden = torch.cat([new_h, mol_vec], dim=1)
                pred_hidden = nn.ReLU()(self.W(pred_hidden))
                pred_score = nn.Softmax(dim=1)(self.W_o(pred_hidden) * 20)
                if prob_decode:
                    sort_wid = torch.multinomial(pred_score.data.squeeze(), 5)
                else:
                    _, sort_wid = torch.sort(pred_score,
                                             dim=1,
                                             descending=True)
                    sort_wid = sort_wid.data.squeeze()

                # Try the top-5 candidates until one is chemically attachable.
                next_wid = None
                for wid in sort_wid[:5]:
                    slots = self.vocab.get_slots(wid)
                    node_y = MolTreeNode(self.vocab.get_smiles(wid))
                    if have_slots(fa_slot, slots) and can_assemble(
                            node_x, node_y):
                        next_wid = wid
                        next_slots = slots
                        break

                if next_wid is None:
                    backtrack = True  #No more children can be added
                else:
                    node_y = MolTreeNode(self.vocab.get_smiles(next_wid))
                    node_y.wid = next_wid
                    node_y.idx = step + 1
                    node_y.neighbors.append(node_x)
                    h[(node_x.idx, node_y.idx)] = new_h[0]
                    stack.append((node_y, next_slots))
                    all_nodes.append(node_y)

            if backtrack:  #Backtrack, use if instead of else
                if len(stack) == 1:
                    break  #At root, terminate

                node_fa, _ = stack[-2]
                cur_h_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                    if node_y.idx != node_fa.idx
                ]
                if len(cur_h_nei) > 0:
                    cur_h_nei = torch.stack(cur_h_nei, dim=0).view(
                        1, -1, self.hidden_size)
                else:
                    cur_h_nei = zero_pad

                # Send the child->parent message before popping the stack.
                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                            self.W_h)
                h[(node_x.idx, node_fa.idx)] = new_h[0]
                node_fa.neighbors.append(node_x)
                stack.pop()

        return root, all_nodes
예제 #8
0
    def forward(self, mol_batch, mol_vec):
        """Teacher-forced training pass of the tree decoder (mol_vec variant).

        Replays the DFS generation trace of each tree, accumulating hidden
        states and targets for next-clique ("word") prediction and stop
        prediction, both conditioned on the molecule vector.

        Args:
            mol_batch: list of MolTree objects; ``nodes[0]`` is the root.
            mol_vec: latent molecule vectors, one row per tree in the batch.

        Returns:
            (pred_loss, stop_loss, pred_acc, stop_acc): per-molecule-averaged
            losses and the corresponding accuracies as Python floats.
        """
        # Virtual parent for the root so the DFS trace is uniform.
        super_root = MolTreeNode('')
        super_root.idx = -1

        # Initialization: accumulators for the two classifiers.
        pred_hiddens, pred_mol_vecs, pred_targets = [], [], []
        stop_hiddens, stop_targets = [], []
        traces = []
        for mol_tree in mol_batch:
            s = []
            dfs(s, mol_tree.nodes[0], super_root)
            traces.append(s)
            # Neighbors are rebuilt incrementally while replaying the trace.
            for node in mol_tree.nodes:
                node.neighbors = []

        pred_hiddens.append(
            create_var(torch.zeros(len(mol_batch), self.hidden_size)))
        pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch])
        pred_mol_vecs.append(mol_vec)

        max_iter = max([len(tr) for tr in traces])
        padding = create_var(torch.zeros(self.hidden_size), False)
        h = {}

        for t in range(max_iter):
            prop_list = []
            batch_list = []
            for i, plist in enumerate(traces):
                if len(plist) > t:
                    prop_list.append(plist[t])
                    batch_list.append(i)

            cur_x = []
            cur_h_nei, cur_o_nei = [], []

            for node_x, real_y, _ in prop_list:
                # Neighbors for message passing (target not included).
                # BUGFIX: restored the comprehension and removed leftover
                # debug print statements from the training hot loop.
                cur_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                    if node_y.idx != real_y.idx
                ]
                pad_len = MAX_NB - len(cur_nei)
                cur_h_nei.extend(cur_nei)
                cur_h_nei.extend([padding] * pad_len)

                # Neighbors for stop prediction (all neighbors).
                cur_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                ]
                pad_len = MAX_NB - len(cur_nei)
                cur_o_nei.extend(cur_nei)
                cur_o_nei.extend([padding] * pad_len)

                cur_x.append(node_x.wid)

            cur_x = create_var(torch.LongTensor(cur_x))
            cur_x = self.embedding(cur_x)
            cur_h_nei = torch.stack(cur_h_nei,
                                    dim=0).view(-1, MAX_NB, self.hidden_size)
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                        self.W_h)

            cur_o_nei = torch.stack(cur_o_nei,
                                    dim=0).view(-1, MAX_NB, self.hidden_size)
            cur_o = cur_o_nei.sum(dim=1)

            pred_target, pred_list = [], []
            stop_target = []
            for i, m in enumerate(prop_list):
                node_x, node_y, direction = m
                x, y = node_x.idx, node_y.idx
                h[(x, y)] = new_h[i]
                node_y.neighbors.append(node_x)
                # direction == 1 means an expand step: a word target exists.
                if direction == 1:
                    pred_target.append(node_y.wid)
                    pred_list.append(i)
                stop_target.append(direction)

            cur_batch = create_var(torch.LongTensor(batch_list))
            cur_mol_vec = mol_vec.index_select(0, cur_batch)
            stop_hidden = torch.cat([cur_x, cur_o, cur_mol_vec], dim=1)
            stop_hiddens.append(stop_hidden)
            stop_targets.extend(stop_target)

            if len(pred_list) > 0:
                batch_list = [batch_list[i] for i in pred_list]
                cur_batch = create_var(torch.LongTensor(batch_list))
                pred_mol_vecs.append(mol_vec.index_select(0, cur_batch))

                cur_pred = create_var(torch.LongTensor(pred_list))
                pred_hiddens.append(new_h.index_select(0, cur_pred))
                pred_targets.extend(pred_target)

        # Last stop at root.
        cur_x, cur_o_nei = [], []
        for mol_tree in mol_batch:
            node_x = mol_tree.nodes[0]
            cur_x.append(node_x.wid)
            cur_nei = [
                h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
            ]
            pad_len = MAX_NB - len(cur_nei)
            cur_o_nei.extend(cur_nei)
            cur_o_nei.extend([padding] * pad_len)

        cur_x = create_var(torch.LongTensor(cur_x))
        cur_x = self.embedding(cur_x)
        cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB,
                                                       self.hidden_size)
        cur_o = cur_o_nei.sum(dim=1)

        stop_hidden = torch.cat([cur_x, cur_o, mol_vec], dim=1)
        stop_hiddens.append(stop_hidden)
        stop_targets.extend([0] * len(mol_batch))

        # Predict next clique.
        pred_hiddens = torch.cat(pred_hiddens, dim=0)
        pred_mol_vecs = torch.cat(pred_mol_vecs, dim=0)
        pred_vecs = torch.cat([pred_hiddens, pred_mol_vecs], dim=1)
        pred_vecs = nn.ReLU()(self.W(pred_vecs))
        pred_scores = self.W_o(pred_vecs)
        pred_targets = create_var(torch.LongTensor(pred_targets))

        pred_loss = self.pred_loss(pred_scores, pred_targets) / len(mol_batch)
        _, preds = torch.max(pred_scores, dim=1)
        pred_acc = torch.eq(preds, pred_targets).float()
        pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

        # Predict stop.
        stop_hiddens = torch.cat(stop_hiddens, dim=0)
        stop_vecs = nn.ReLU()(self.U(stop_hiddens))
        stop_scores = self.U_s(stop_vecs).squeeze()
        stop_targets = create_var(torch.Tensor(stop_targets))

        stop_loss = self.stop_loss(stop_scores, stop_targets) / len(mol_batch)
        stops = torch.ge(stop_scores, 0).float()
        stop_acc = torch.eq(stops, stop_targets).float()
        stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

        return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()
예제 #9
0
파일: decoder.py 프로젝트: SNWK/DivideTree
    def decode(self, x_tree_vecs, prob_decode):
        """Decode a single tree of feature nodes (DivideTree variant).

        Unlike the vocabulary-based decoder, each step regresses a node
        feature vector via ``self.aggregate(..., 'word')`` instead of
        classifying over SMILES cliques.

        Args:
            x_tree_vecs: tree context vector, batch size must be 1.
            prob_decode: if True, sample the stop decision; otherwise
                threshold the stop score at 0.

        Returns:
            (root, all_nodes): the root TreeNode and every decoded node.
        """
        assert x_tree_vecs.size(0) == 1

        stack = []
        init_hiddens = create_var(torch.zeros(1, self.hidden_size))
        zero_pad = create_var(torch.zeros(1, 1, self.hidden_size))
        contexts = create_var(torch.LongTensor(1).zero_())

        #Root Prediction
        root_feature = self.aggregate(init_hiddens, contexts, x_tree_vecs,
                                      'word')

        root = TreeNode(root_feature)
        root.idx = 0
        root.graphid = 0
        stack.append((root, root_feature))

        all_nodes = [root]
        h = {}
        # BUGFIX: `range` instead of Python-2-only `xrange`.
        for step in range(MAX_DECODE_LEN):
            node_x, fa_slot = stack[-1]
            cur_h_nei = [
                h[(node_y.graphid, node_x.graphid)]
                for node_y in node_x.neighbors
            ]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei,
                                        dim=0).view(1, -1, self.hidden_size)
            else:
                cur_h_nei = zero_pad
            # todo: node_x.feature is a vector; wrapping it in a LongTensor
            # index for the embedding looks suspect — confirm intended input.
            cur_x = create_var(torch.LongTensor([node_x.feature]))
            cur_x = self.embedding(cur_x)

            #Predict stop
            cur_h = cur_h_nei.sum(dim=1)
            stop_hiddens = torch.cat([cur_x, cur_h], dim=1)
            stop_hiddens = F.relu(self.U_i(stop_hiddens))
            stop_score = self.aggregate(stop_hiddens, contexts, x_tree_vecs,
                                        'stop')

            if prob_decode:
                backtrack = (torch.bernoulli(
                    torch.sigmoid(stop_score)).item() == 0)
            else:
                backtrack = (stop_score.item() < 0)

            if not backtrack:  #Forward: Predict next clique
                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                            self.W_h)
                pred_feature = self.aggregate(new_h, contexts, x_tree_vecs,
                                              'word')

                node_y = TreeNode(pred_feature)
                node_y.idx = len(all_nodes)
                node_y.graphid = len(all_nodes)
                node_y.neighbors.append(node_x)
                h[(node_x.graphid, node_y.graphid)] = new_h[0]
                stack.append((node_y, pred_feature))
                all_nodes.append(node_y)

            if backtrack:  #Backtrack, use if instead of else
                if len(stack) == 1:
                    break  #At root, terminate

                node_fa, _ = stack[-2]
                cur_h_nei = [
                    h[(node_y.graphid, node_x.graphid)]
                    for node_y in node_x.neighbors if node_y.idx != node_fa.idx
                ]
                if len(cur_h_nei) > 0:
                    cur_h_nei = torch.stack(cur_h_nei, dim=0).view(
                        1, -1, self.hidden_size)
                else:
                    cur_h_nei = zero_pad

                # Send the child->parent message before popping the stack.
                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                            self.W_h)
                h[(node_x.graphid, node_fa.graphid)] = new_h[0]
                node_fa.neighbors.append(node_x)
                stack.pop()

        return root, all_nodes
    def decode(self, tree_vecs):
        """
        Description: Given the tree vector, predict the corresponding junction-tree.

        Decoding is a depth-first traversal: at each step the model either
        expands a new child cluster ("forward") or returns to the parent
        ("backtrack"), driven by a stop score computed from GRU message
        passing over the partially-built tree. Decoding ends after
        MAX_DECODE_LEN steps or when the traversal backtracks past the root.

        Args:
            tree_vecs: torch.tensor (shape: 1 x hidden_size)
                Vector representation of a single junction-tree
                (the assert below enforces batch size 1).

        Returns:
            root: MolJuncTreeNode
                The root node of the decoded junction-tree.

            all_nodes: List[MolJuncTreeNode]
                The list of all the nodes in the decoded junction-tree.
        """
        # decoding handles exactly one tree at a time
        assert tree_vecs.size(0) == 1

        # stack holds (node, slots) pairs along the current DFS path
        stack = []
        init_hiddens = create_var(torch.zeros(1, self.hidden_size))
        # placeholder message for a node that has no neighbors yet
        zero_pad = create_var(torch.zeros(1, 1, self.hidden_size))
        # single batch index (0) used by aggregate to select tree_vecs
        contexts = create_var(torch.LongTensor(1).zero_())

        # root Prediction: choose the root cluster label from the tree vector alone
        root_score = self.aggregate(init_hiddens, contexts, tree_vecs, 'word')
        _, root_wid = torch.max(root_score, dim=1)
        root_wid = root_wid.item()

        root = MolJuncTreeNode(self.vocab.get_smiles(root_wid))
        root.wid = root_wid
        root.idx = 0
        stack.append((root, self.vocab.get_slots(root.wid)))

        all_nodes = [root]
        # h maps a directed edge (src_idx, dst_idx) -> hidden message vector
        h = {}
        for step in range(MAX_DECODE_LEN):
            node_x, parent_slot = stack[-1]
            # incoming messages from all current neighbors of node_x
            cur_h_nei = [
                h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
            ]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei,
                                        dim=0).view(1, -1, self.hidden_size)
            else:
                cur_h_nei = zero_pad

            cur_x = create_var(torch.LongTensor([node_x.wid]))
            cur_x = self.vocab_embedding(cur_x)

            # Predict stop: decide whether to expand a child or backtrack
            cur_h = cur_h_nei.sum(dim=1)
            stop_hidden = torch.cat([cur_x, cur_h], dim=1)
            stop_hiddens = F.relu(self.U_i(stop_hidden))
            stop_score = self.aggregate(stop_hiddens, contexts, tree_vecs,
                                        'stop')

            # negative stop score => backtrack to the parent
            backtrack = (stop_score.item() < 0)

            # go down forward: predict next cluster
            if not backtrack:
                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                            self.W_h)
                pred_score = self.aggregate(new_h, contexts, tree_vecs, 'word')

                _, sort_wid = torch.sort(pred_score, dim=1, descending=True)
                sort_wid = sort_wid.data.squeeze()

                # try the 5 highest-scoring candidate clusters and accept the
                # first one that is chemically compatible with the parent
                # (slot match + assembly feasibility)
                next_wid = None
                for wid in sort_wid[:5]:
                    slots = self.vocab.get_slots(wid)
                    node_y = MolJuncTreeNode(self.vocab.get_smiles(wid))
                    if self.have_slots(parent_slot,
                                       slots) and self.can_assemble(
                                           node_x, node_y):
                        next_wid = wid
                        next_slots = slots
                        break

                # no more children can be added
                if next_wid is None:
                    backtrack = True
                else:
                    node_y = MolJuncTreeNode(self.vocab.get_smiles(next_wid))
                    node_y.wid = next_wid
                    node_y.idx = len(all_nodes)
                    node_y.neighbors.append(node_x)
                    # record the message along the new edge x -> y
                    h[(node_x.idx, node_y.idx)] = new_h[0]
                    stack.append((node_y, next_slots))
                    all_nodes.append(node_y)

            # backtrack, use if instead of else
            # (the forward branch above may flip backtrack to True when no
            # candidate cluster fits)
            if backtrack:
                if len(stack) == 1:
                    # back to root, terminate
                    break

                parent_node, _ = stack[-2]
                # messages from all neighbors except the parent we return to
                cur_h_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                    if node_y.idx != parent_node.idx
                ]
                if len(cur_h_nei) > 0:
                    cur_h_nei = torch.stack(cur_h_nei, dim=0).view(
                        1, -1, self.hidden_size)
                else:
                    cur_h_nei = zero_pad

                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                            self.W_h)
                # message sent back to the parent completes the edge
                h[(node_x.idx, parent_node.idx)] = new_h[0]
                parent_node.neighbors.append(node_x)
                stack.pop()

        return root, all_nodes
    def forward(self, junc_tree_batch, tree_vecs):
        """
        Teacher-forced training pass of the junction-tree decoder.

        Replays the ground-truth DFS traversal of every junction-tree and
        accumulates hidden states / targets for two prediction tasks:
        * 'word': the cluster-vocabulary label of each newly expanded child,
        * 'stop': whether the traversal expands a child (1) or backtracks (0).

        Args:
            junc_tree_batch: List[MolJuncTree]
                The list of junction-trees for all the molecules, across the entire dataset.

            tree_vecs: torch.tensor (shape: batch_size x hidden_size)
                The vector represenations of all the junction-trees, for all the molecules, across the entire dataset.

        Returns:
            (pred_loss, stop_loss, pred_acc, stop_acc): batch-averaged losses
            and accuracies for the cluster ('word') and topological ('stop')
            prediction tasks.
        """
        # initialize
        pred_hiddens, pred_contexts, pred_targets = [], [], []
        stop_hiddens, stop_contexts, stop_targets = [], [], []

        # list to store dfs traversals for molecular trees of all molecules
        traces = []

        for junc_tree in junc_tree_batch:
            stack = []
            # root node has no parent node,
            # so we use a virtual node with idx = -1
            self.dfs(stack, junc_tree.nodes[0], -1)
            traces.append(stack)

            # neighbor lists are rebuilt incrementally while replaying the traversal
            for node in junc_tree.nodes:
                node.neighbors = []

        # predict root
        batch_size = len(junc_tree_batch)
        pred_hiddens.append(
            create_var(torch.zeros(batch_size, self.hidden_size)))

        # list of indices of cluster vocabulary items,
        # for the root node of all junction trees, across the entire dataset.
        pred_targets.extend(
            [junc_tree.nodes[0].wid for junc_tree in junc_tree_batch])

        pred_contexts.append(create_var(torch.LongTensor(range(batch_size))))

        # number of traversals to go through, to ensure that dfs traversal is completed for the
        # junction-tree with the largest size / height.
        max_iter = max([len(tr) for tr in traces])

        # padding vector for putting in place of messages from non-existant neighbors
        padding = create_var(torch.zeros(self.hidden_size), False)

        # dictionary to store hidden edge message vectors
        h = {}

        # fix: loop variable renamed from `iter` to `t` so the builtin
        # `iter` is not shadowed inside this method
        for t in range(max_iter):
            # list to store edge tuples that will be considered in this iteration.
            edge_tuple_list = []

            # batch id of all junc_trees / tree_vecs whose edge_tuple
            # in being considered in this timestep
            batch_list = []

            for idx, dfs_traversal in enumerate(traces):
                # keep appending traversal orders for a particular depth level,
                # from a given traversal_order list,
                # until the list is not empty
                if t < len(dfs_traversal):
                    edge_tuple_list.append(dfs_traversal[t])
                    batch_list.append(idx)

            cur_x = []
            cur_h_nei, cur_o_nei = [], []

            for node_x, real_y, _ in edge_tuple_list:
                # neighbors for message passing (target not included)
                # hidden edge message vectors from predecessor neighbor nodes
                cur_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                    if node_y.idx != real_y.idx
                ]
                # a node can have at most MAX_NUM_NEIGHBORS neighbors;
                # if it has fewer, append zero vectors as messages from
                # non-existent neighbors
                pad_len = MAX_NUM_NEIGHBORS - len(cur_nei)
                cur_h_nei.extend(cur_nei)
                cur_h_nei.extend([padding] * pad_len)

                # neighbors for stop (topological) prediction (all neighbors)
                # hidden edge messages from all neighbor nodes
                cur_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                ]
                pad_len = MAX_NUM_NEIGHBORS - len(cur_nei)
                cur_o_nei.extend(cur_nei)
                cur_o_nei.extend([padding] * pad_len)

                # current cluster embedding
                cur_x.append(node_x.wid)

            # cluster embedding
            cur_x = create_var(torch.LongTensor(cur_x))
            cur_x = self.vocab_embedding(cur_x)

            # implement message passing
            cur_h_nei = torch.stack(cur_h_nei,
                                    dim=0).view(-1, MAX_NUM_NEIGHBORS,
                                                self.hidden_size)
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                        self.W_h)

            # node aggregate
            cur_o_nei = torch.stack(cur_o_nei,
                                    dim=0).view(-1, MAX_NUM_NEIGHBORS,
                                                self.hidden_size)
            cur_o = cur_o_nei.sum(dim=1)

            # gather targets
            pred_target, pred_list = [], []
            stop_target = []

            # teacher forcing: commit the ground-truth edge message and
            # neighbor relation regardless of what the model predicted
            for idx, edge_tuple in enumerate(edge_tuple_list):
                node_x, node_y, direction = edge_tuple
                x, y = node_x.idx, node_y.idx
                h[(x, y)] = new_h[idx]
                node_y.neighbors.append(node_x)
                if direction == 1:
                    pred_target.append(node_y.wid)
                    pred_list.append(idx)
                stop_target.append(direction)

            # hidden states for stop (topological) prediction
            cur_batch = create_var(torch.LongTensor(batch_list))
            stop_hidden = torch.cat([cur_x, cur_o], dim=1)
            stop_hiddens.append(stop_hidden)
            stop_contexts.append(cur_batch)
            stop_targets.extend(stop_target)

            # hidden states for cluster prediction (forward edges only)
            if len(pred_list) > 0:
                batch_list = [batch_list[idx] for idx in pred_list]
                cur_batch = create_var(torch.LongTensor(batch_list))
                pred_contexts.append(cur_batch)

                cur_pred = create_var(torch.LongTensor(pred_list))
                pred_hiddens.append(new_h.index_select(0, cur_pred))
                pred_targets.extend(pred_target)

        # last stop at root: every tree ends with a stop decision at its root
        cur_x, cur_o_nei = [], []
        for junc_tree in junc_tree_batch:
            node_x = junc_tree.nodes[0]
            cur_x.append(node_x.wid)
            cur_nei = [
                h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
            ]
            pad_len = MAX_NUM_NEIGHBORS - len(cur_nei)
            cur_o_nei.extend(cur_nei)
            cur_o_nei.extend([padding] * pad_len)

        cur_x = create_var(torch.LongTensor(cur_x))
        cur_x = self.vocab_embedding(cur_x)
        cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NUM_NEIGHBORS,
                                                       self.hidden_size)
        cur_o = cur_o_nei.sum(dim=1)

        stop_hidden = torch.cat([cur_x, cur_o], dim=1)
        stop_hiddens.append(stop_hidden)
        stop_contexts.append(create_var(torch.LongTensor(range(batch_size))))
        stop_targets.extend([0] * len(junc_tree_batch))

        # predict next cluster
        pred_contexts = torch.cat(pred_contexts, dim=0)
        pred_hiddens = torch.cat(pred_hiddens, dim=0)
        pred_scores = self.aggregate(pred_hiddens, pred_contexts, tree_vecs,
                                     'word')
        pred_targets = create_var(torch.LongTensor(pred_targets))

        pred_loss = self.pred_loss(pred_scores,
                                   pred_targets) / len(junc_tree_batch)
        _, preds = torch.max(pred_scores, dim=1)
        pred_acc = torch.eq(preds, pred_targets).float()
        pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

        # predict stop
        stop_contexts = torch.cat(stop_contexts, dim=0)
        stop_hiddens = torch.cat(stop_hiddens, dim=0)
        stop_hiddens = F.relu(self.U_i(stop_hiddens))
        stop_scores = self.aggregate(stop_hiddens, stop_contexts, tree_vecs,
                                     'stop')
        stop_scores = stop_scores.squeeze(-1)
        stop_targets = create_var(torch.Tensor(stop_targets))

        stop_loss = self.stop_loss(stop_scores,
                                   stop_targets) / len(junc_tree_batch)
        # a non-negative score counts as a predicted "expand" decision
        stops = torch.ge(stop_scores, 0).float()
        stop_acc = torch.eq(stops, stop_targets).float()
        stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

        return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()
# Example #12
# 0
    def soft_decode(self, x_tree_vecs, x_mol_vecs, gumbel, slope, temp):
        """
        Differentiable ("soft") tree decoding for a single molecule.

        Instead of committing to discrete cluster labels, every decoded node
        carries a probability distribution over the vocabulary and a soft
        embedding (probability-weighted mix of embedding rows), and the hard
        expand/backtrack decision is relaxed with a straight-through
        estimator so gradients flow through the decoding path.

        Args:
            x_tree_vecs: torch.tensor (shape: 1 x hidden_size)
                Tree vector of a single molecule (batch size 1 enforced).
            x_mol_vecs: torch.tensor
                Molecule (graph) vectors consumed by self.attention.
            gumbel: bool
                Sample with Gumbel-softmax if True, otherwise use a
                temperature-scaled softmax.
            slope: float
                Slope of the hard-tanh straight-through relaxation of the
                stop decision.
            temp: float
                Softmax / Gumbel temperature.

        Returns:
            (hidden, all_nodes): concatenation of the root embedding with its
            aggregated neighbor messages, and the list of decoded nodes.
        """
        assert x_tree_vecs.size(0) == 1

        # expected embedding under a predicted vocabulary distribution
        def soft_embedding(x):
            return x.matmul(self.embedding.weight)

        if gumbel:
            def sample_softmax(x):
                return F.gumbel_softmax(x, tau=temp)
        else:
            def sample_softmax(x):
                return F.softmax(x / temp, dim=1)

        # stack holds the current DFS path of decoded nodes
        stack = []
        init_hiddens = create_var(torch.zeros(1, self.hidden_size))
        # placeholder message for a node with no neighbors yet
        zero_pad = create_var(torch.zeros(1, 1, self.hidden_size))
        contexts = create_var(torch.LongTensor(1).zero_())

        # Root prediction: soft distribution over the cluster vocabulary
        root_score = self.attention(init_hiddens, contexts, x_tree_vecs,
                                    x_mol_vecs, 'word')
        root_prob = sample_softmax(root_score)

        root = MolTreeNode("")
        root.embedding = soft_embedding(root_prob)
        root.prob = root_prob
        root.idx = 0
        stack.append(root)

        all_nodes = [root]
        all_hiddens = []
        # h maps a directed edge (src_idx, dst_idx) -> hidden message vector
        h = {}
        # fix: `xrange` is Python 2 only; use `range`
        for step in range(MAX_SOFT_DECODE_LEN):
            node_x = stack[-1]
            cur_h_nei = [
                h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
            ]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei,
                                        dim=0).view(1, -1, self.hidden_size)
            else:
                cur_h_nei = zero_pad

            # Predict stop
            cur_x = node_x.embedding
            cur_h = cur_h_nei.sum(dim=1)
            stop_hiddens = torch.cat([cur_x, cur_h], dim=1)
            stop_hiddens = F.relu(self.U_i(stop_hiddens))
            stop_score = self.attention(stop_hiddens, contexts, x_tree_vecs,
                                        x_mol_vecs, 'stop')
            all_hiddens.append(stop_hiddens)

            # hard decision (1 = expand a child, 0 = backtrack) with a
            # straight-through soft value for gradient flow
            expand = 0 if stop_score.item() < 0 else 1
            stop_prob = F.hardtanh(slope * stop_score + 0.5,
                                   min_val=0,
                                   max_val=1).unsqueeze(1)
            stop_val_ste = expand + stop_prob - stop_prob.detach()

            if expand == 1:  # Forward: predict next clique
                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                            self.W_h)
                pred_score = self.attention(new_h, contexts, x_tree_vecs,
                                            x_mol_vecs, 'word')
                pred_prob = sample_softmax(pred_score)

                node_y = MolTreeNode("")
                node_y.embedding = soft_embedding(pred_prob)
                node_y.prob = pred_prob
                node_y.idx = len(all_nodes)
                node_y.neighbors.append(node_x)

                # message gated by the soft stop value (straight-through)
                h[(node_x.idx, node_y.idx)] = new_h[0] * stop_val_ste
                stack.append(node_y)
                all_nodes.append(node_y)
            else:
                if len(stack) == 1:  # At root, terminate
                    return torch.cat([cur_x, cur_h], dim=1), all_nodes

                node_fa = stack[-2]
                # messages from all neighbors except the parent we return to
                cur_h_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                    if node_y.idx != node_fa.idx
                ]
                if len(cur_h_nei) > 0:
                    cur_h_nei = torch.stack(cur_h_nei, dim=0).view(
                        1, -1, self.hidden_size)
                else:
                    cur_h_nei = zero_pad

                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                            self.W_h)
                # backward message gated by the complementary stop value
                h[(node_x.idx, node_fa.idx)] = new_h[0] * (1.0 - stop_val_ste)
                node_fa.neighbors.append(node_x)
                stack.pop()

        # Failure mode: decoding unfinished after MAX_SOFT_DECODE_LEN steps;
        # return the root summary built from whatever was decoded
        cur_h_nei = [h[(node_y.idx, root.idx)] for node_y in root.neighbors]
        if len(cur_h_nei) > 0:
            cur_h_nei = torch.stack(cur_h_nei,
                                    dim=0).view(1, -1, self.hidden_size)
        else:
            cur_h_nei = zero_pad
        cur_h = cur_h_nei.sum(dim=1)

        stop_hiddens = torch.cat([root.embedding, cur_h], dim=1)
        stop_hiddens = F.relu(self.U_i(stop_hiddens))
        all_hiddens.append(stop_hiddens)

        return torch.cat([root.embedding, cur_h], dim=1), all_nodes
# Example #13
# 0
# File: jtnn_dec.py  Project: futianfan/CORE
    def forward(self, mol_batch, x_tree_vecs, x_mol_vecs, origin_word):
        """
        Teacher-forced training pass of the tree decoder (CORE variant).

        mol_batch: Y; where (X, Y), X is input, Y is target.

        Replays the ground-truth DFS traversal of each tree, gathering hidden
        states and targets for clique ('word') and stop predictions. Clique
        scores come from mixture_attention over tree/molecule vectors and the
        input words.

        Returns:
            (pred_loss, stop_loss, pred_acc, stop_acc): batch-averaged losses
            and accuracies for the clique and stop prediction tasks.
        """
        pred_hiddens, pred_contexts, pred_targets = [], [], []
        stop_hiddens, stop_contexts, stop_targets = [], [], []
        # ground-truth DFS traversal of each tree
        traces = []
        for mol_tree in mol_batch:
            s = []
            dfs(s, mol_tree.nodes[0], -1)
            traces.append(s)
            # neighbor lists are rebuilt incrementally during the replay below
            for node in mol_tree.nodes:
                node.neighbors = []

        # Predict Root
        batch_size = len(mol_batch)
        pred_hiddens.append(
            create_var(torch.zeros(len(mol_batch), self.hidden_size)))
        pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch])
        pred_contexts.append(create_var(torch.LongTensor(range(batch_size))))

        max_iter = max([len(tr) for tr in traces])
        # zero message standing in for missing neighbors
        padding = create_var(torch.zeros(self.hidden_size), False)
        # hidden edge message vectors, keyed by (src_idx, dst_idx)
        h = {}

        # fix: `xrange` is Python 2 only; use `range`
        for t in range(max_iter):
            prop_list = []
            batch_list = []
            for i, plist in enumerate(traces):
                if t < len(plist):
                    prop_list.append(plist[t])
                    batch_list.append(i)

            cur_x = []
            cur_h_nei, cur_o_nei = [], []

            for node_x, real_y, _ in prop_list:
                # Neighbors for message passing (target not included)
                cur_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                    if node_y.idx != real_y.idx
                ]
                pad_len = MAX_NB - len(cur_nei)
                cur_h_nei.extend(cur_nei)
                cur_h_nei.extend([padding] * pad_len)

                # Neighbors for stop prediction (all neighbors)
                cur_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                ]
                pad_len = MAX_NB - len(cur_nei)
                cur_o_nei.extend(cur_nei)
                cur_o_nei.extend([padding] * pad_len)

                # Current clique embedding
                cur_x.append(node_x.wid)

            # Clique embedding
            cur_x = create_var(torch.LongTensor(cur_x))
            cur_x = self.embedding(cur_x)

            # Message passing
            cur_h_nei = torch.stack(cur_h_nei,
                                    dim=0).view(-1, MAX_NB, self.hidden_size)
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                        self.W_h)

            # Node Aggregate
            cur_o_nei = torch.stack(cur_o_nei,
                                    dim=0).view(-1, MAX_NB, self.hidden_size)
            cur_o = cur_o_nei.sum(dim=1)

            # Gather targets (teacher forcing: commit ground-truth edges)
            pred_target, pred_list = [], []
            stop_target = []
            for i, m in enumerate(prop_list):
                node_x, node_y, direction = m
                x, y = node_x.idx, node_y.idx
                h[(x, y)] = new_h[i]
                node_y.neighbors.append(node_x)
                if direction == 1:
                    pred_target.append(node_y.wid)
                    pred_list.append(i)
                stop_target.append(direction)

            # Hidden states for stop prediction
            cur_batch = create_var(torch.LongTensor(batch_list))
            stop_hidden = torch.cat([cur_x, cur_o], dim=1)
            stop_hiddens.append(stop_hidden)
            stop_contexts.append(cur_batch)
            stop_targets.extend(stop_target)

            # Hidden states for clique prediction (forward edges only)
            if len(pred_list) > 0:
                batch_list = [batch_list[i] for i in pred_list]
                cur_batch = create_var(torch.LongTensor(batch_list))
                pred_contexts.append(cur_batch)

                cur_pred = create_var(torch.LongTensor(pred_list))
                pred_hiddens.append(new_h.index_select(0, cur_pred))
                pred_targets.extend(pred_target)

        # Last stop at root
        cur_x, cur_o_nei = [], []
        for mol_tree in mol_batch:
            node_x = mol_tree.nodes[0]
            cur_x.append(node_x.wid)
            cur_nei = [
                h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
            ]
            pad_len = MAX_NB - len(cur_nei)
            cur_o_nei.extend(cur_nei)
            cur_o_nei.extend([padding] * pad_len)

        cur_x = create_var(torch.LongTensor(cur_x))
        cur_x = self.embedding(cur_x)
        cur_o_nei = torch.stack(cur_o_nei,
                                dim=0).view(-1, MAX_NB, self.hidden_size)
        cur_o = cur_o_nei.sum(dim=1)

        stop_hidden = torch.cat([cur_x, cur_o], dim=1)
        stop_hiddens.append(stop_hidden)
        stop_contexts.append(create_var(torch.LongTensor(range(batch_size))))
        stop_targets.extend([0] * len(mol_batch))

        # Predict next clique
        pred_contexts = torch.cat(pred_contexts, dim=0)
        pred_hiddens = torch.cat(pred_hiddens, dim=0)
        pred_scores = self.mixture_attention(pred_hiddens, pred_contexts,
                                             x_tree_vecs, x_mol_vecs,
                                             origin_word)
        if torch.isnan(pred_scores).any():
            # fix: Python 2 print statement -> Python 3 print function
            print("forward nan")

        pred_targets = create_var(torch.LongTensor(pred_targets))

        # mixture_attention yields probabilities, so NLL is taken on their
        # log; clamp from below to avoid log(0) -> -inf
        pred_loss = self.nllloss(
            torch.log(
                torch.max(
                    pred_scores,
                    torch.ones_like(pred_scores) *
                    minimum_threshold_before_log)),
            pred_targets) / len(mol_batch)
        _, preds = torch.max(pred_scores, dim=1)
        pred_acc = torch.eq(preds, pred_targets).float()
        pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

        # Predict stop
        stop_contexts = torch.cat(stop_contexts, dim=0)
        stop_hiddens = torch.cat(stop_hiddens, dim=0)
        stop_hiddens = F.relu(self.U_i(stop_hiddens))
        stop_scores = self.attention(stop_hiddens, stop_contexts, x_tree_vecs,
                                     x_mol_vecs, 'stop')
        stop_scores = stop_scores.squeeze(-1)
        stop_targets = create_var(torch.Tensor(stop_targets))

        stop_loss = self.stop_loss(stop_scores, stop_targets) / len(mol_batch)
        # a non-negative score counts as a predicted "expand" decision
        stops = torch.ge(stop_scores, 0).float()
        stop_acc = torch.eq(stops, stop_targets).float()
        stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

        return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()