예제 #1
0
    def forward(self, mol_batch, x_tree_vecs):
        """Compute teacher-forced decoder losses for a batch of junction trees.

        Each tree is traversed depth-first (via ``dfs``) and all traversals are
        replayed in lock-step, one traversal step per tree per iteration. At
        every step a message for the traversed edge is computed with ``GRU``;
        those hidden states feed two prediction tasks, both scored through
        ``self.aggregate`` against ``x_tree_vecs``:

        * cluster ("word") prediction: the root cluster of every tree, plus
          the cluster (``wid``) of the node being moved to on every step whose
          ``direction == 1``;
        * stop prediction: the traversal's ``direction`` flag at every step,
          plus a final target of 0 at every root.

        Args:
            mol_batch: list of trees exposing ``.nodes`` (``nodes[0]`` is the
                root); nodes expose ``idx``, ``wid`` and ``neighbors``.
                NOTE: every node's ``neighbors`` list is reset to ``[]`` and
                rebuilt during the replay, so the trees are mutated in place.
            x_tree_vecs: context passed to ``self.aggregate`` and indexed by
                batch position (exact shape not visible here - TODO confirm).

        Returns:
            ``(pred_loss, stop_loss, pred_acc, stop_acc)``: both losses are
            divided by the batch size; both accuracies are Python floats.
        """
        pred_hiddens, pred_contexts, pred_targets = [], [], []
        stop_hiddens, stop_contexts, stop_targets = [], [], []

        # One DFS trace per tree. Neighbor lists are cleared here and rebuilt
        # as the traces are replayed: after edge x -> y is processed, x is
        # appended to y.neighbors.
        traces = []
        for mol_tree in mol_batch:
            s = []
            dfs(s, mol_tree.nodes[0], -1)
            traces.append(s)
            for node in mol_tree.nodes:
                node.neighbors = []

        # Predict root: zero hidden state per tree, target = root cluster id.
        batch_size = len(mol_batch)
        pred_hiddens.append(
            create_var(torch.zeros(batch_size, self.hidden_size)))
        pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch])
        pred_contexts.append(create_var(torch.LongTensor(range(batch_size))))

        max_iter = max([len(tr) for tr in traces])
        # Zero vector used to pad every neighbor list up to MAX_NB entries.
        padding = create_var(torch.zeros(self.hidden_size), False)
        # h[(x_idx, y_idx)] = hidden state computed for directed edge x -> y.
        h = {}

        # `range` instead of the Python-2-only `xrange` so this also runs on
        # Python 3 (behaves the same on Python 2).
        for t in range(max_iter):
            # Collect step t of every trace that is still long enough,
            # remembering which batch element each step belongs to.
            prop_list = []
            batch_list = []
            for i, plist in enumerate(traces):
                if t < len(plist):
                    prop_list.append(plist[t])
                    batch_list.append(i)

            cur_x = []
            cur_h_nei, cur_o_nei = [], []

            for node_x, real_y, _ in prop_list:
                # Neighbors for message passing (target not included)
                cur_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                    if node_y.idx != real_y.idx
                ]
                pad_len = MAX_NB - len(cur_nei)
                cur_h_nei.extend(cur_nei)
                cur_h_nei.extend([padding] * pad_len)

                # Neighbors for stop prediction (all neighbors)
                cur_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                ]
                pad_len = MAX_NB - len(cur_nei)
                cur_o_nei.extend(cur_nei)
                cur_o_nei.extend([padding] * pad_len)

                # Current clique embedding
                cur_x.append(node_x.wid)

            # Clique embedding
            cur_x = create_var(torch.LongTensor(cur_x))
            cur_x = self.embedding(cur_x)

            # Message passing
            cur_h_nei = torch.stack(cur_h_nei,
                                    dim=0).view(-1, MAX_NB, self.hidden_size)
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                        self.W_h)

            # Node aggregate: sum of all incoming neighbor states.
            cur_o_nei = torch.stack(cur_o_nei,
                                    dim=0).view(-1, MAX_NB, self.hidden_size)
            cur_o = cur_o_nei.sum(dim=1)

            # Gather targets (teacher forcing: the ground-truth traversal
            # decides which edges exist, regardless of model predictions).
            pred_target, pred_list = [], []
            stop_target = []
            for i, m in enumerate(prop_list):
                node_x, node_y, direction = m
                x, y = node_x.idx, node_y.idx
                h[(x, y)] = new_h[i]
                node_y.neighbors.append(node_x)
                if direction == 1:
                    pred_target.append(node_y.wid)
                    pred_list.append(i)
                stop_target.append(direction)

            # Hidden states for stop prediction
            cur_batch = create_var(torch.LongTensor(batch_list))
            stop_hidden = torch.cat([cur_x, cur_o], dim=1)
            stop_hiddens.append(stop_hidden)
            stop_contexts.append(cur_batch)
            stop_targets.extend(stop_target)

            # Hidden states for clique prediction (only steps with
            # direction == 1 contribute).
            if len(pred_list) > 0:
                batch_list = [batch_list[i] for i in pred_list]
                cur_batch = create_var(torch.LongTensor(batch_list))
                pred_contexts.append(cur_batch)

                cur_pred = create_var(torch.LongTensor(pred_list))
                pred_hiddens.append(new_h.index_select(0, cur_pred))
                pred_targets.extend(pred_target)

        # Last stop at root: every root must predict "stop" (target 0).
        cur_x, cur_o_nei = [], []
        for mol_tree in mol_batch:
            node_x = mol_tree.nodes[0]
            cur_x.append(node_x.wid)
            cur_nei = [
                h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
            ]
            pad_len = MAX_NB - len(cur_nei)
            cur_o_nei.extend(cur_nei)
            cur_o_nei.extend([padding] * pad_len)

        cur_x = create_var(torch.LongTensor(cur_x))
        cur_x = self.embedding(cur_x)
        cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB,
                                                       self.hidden_size)
        cur_o = cur_o_nei.sum(dim=1)

        stop_hidden = torch.cat([cur_x, cur_o], dim=1)
        stop_hiddens.append(stop_hidden)
        stop_contexts.append(create_var(torch.LongTensor(range(batch_size))))
        stop_targets.extend([0] * batch_size)

        # Predict next clique
        pred_contexts = torch.cat(pred_contexts, dim=0)
        pred_hiddens = torch.cat(pred_hiddens, dim=0)
        pred_scores = self.aggregate(pred_hiddens, pred_contexts, x_tree_vecs,
                                     'word')
        pred_targets = create_var(torch.LongTensor(pred_targets))

        pred_loss = self.pred_loss(pred_scores, pred_targets) / batch_size
        _, preds = torch.max(pred_scores, dim=1)
        pred_acc = torch.eq(preds, pred_targets).float()
        pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

        # Predict stop
        stop_contexts = torch.cat(stop_contexts, dim=0)
        stop_hiddens = torch.cat(stop_hiddens, dim=0)
        stop_hiddens = F.relu(self.U_i(stop_hiddens))
        stop_scores = self.aggregate(stop_hiddens, stop_contexts, x_tree_vecs,
                                     'stop')
        stop_scores = stop_scores.squeeze(-1)
        stop_targets = create_var(torch.Tensor(stop_targets))

        stop_loss = self.stop_loss(stop_scores, stop_targets) / batch_size
        # A score >= 0 counts as predicting target 1.
        stops = torch.ge(stop_scores, 0).float()
        stop_acc = torch.eq(stops, stop_targets).float()
        stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

        return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()
    def forward(self, junc_tree_batch, tree_vecs):
        """Compute teacher-forced decoder losses for a batch of junction trees.

        Every tree is traversed depth-first and all traversals are replayed in
        lock-step, one traversal step per tree per iteration. Along each
        traversed edge a hidden message is computed with ``GRU``; the messages
        feed two prediction tasks, both scored by ``self.aggregate`` against
        ``tree_vecs``:

        * cluster ("word") prediction: the root cluster of every tree, plus the
          cluster (``wid``) of the node being moved to on every step whose
          ``direction == 1``;
        * stop (topological) prediction: the traversal's ``direction`` flag at
          every step, plus a final target of 0 at every root.

        Args:
            junc_tree_batch: List[MolJuncTree]
                The junction trees of the molecules in the current batch.
                Every node's ``neighbors`` list is reset and rebuilt during the
                traversal, so the trees are mutated in place.

            tree_vecs: torch.tensor (shape: batch_size x hidden_size)
                The vector representations of the junction trees in the
                batch, indexed by batch position.

        Returns:
            Tuple ``(pred_loss, stop_loss, pred_acc, stop_acc)``. Both losses
            are divided by the batch size; both accuracies are Python floats.
        """
        # initialize
        pred_hiddens, pred_contexts, pred_targets = [], [], []
        stop_hiddens, stop_contexts, stop_targets = [], [], []

        # list to store dfs traversals of the junction trees in the batch
        traces = []

        for junc_tree in junc_tree_batch:
            stack = []
            # root node has no parent node,
            # so we use a virtual node with idx = -1
            self.dfs(stack, junc_tree.nodes[0], -1)
            traces.append(stack)

            # neighbor lists are rebuilt while the traversal is replayed:
            # after edge x -> y is processed, x is appended to y.neighbors
            for node in junc_tree.nodes:
                node.neighbors = []

        # predict root
        batch_size = len(junc_tree_batch)
        pred_hiddens.append(
            create_var(torch.zeros(batch_size, self.hidden_size)))

        # vocabulary indices of the root cluster of every junction tree
        # in the batch
        pred_targets.extend(
            [junc_tree.nodes[0].wid for junc_tree in junc_tree_batch])

        pred_contexts.append(create_var(torch.LongTensor(range(batch_size))))

        # number of steps needed to finish the dfs traversal of the
        # junction tree with the longest traversal
        max_iter = max([len(tr) for tr in traces])

        # padding vector used in place of messages from non-existent neighbors
        padding = create_var(torch.zeros(self.hidden_size), False)

        # dictionary to store hidden edge message vectors,
        # keyed by (source node idx, target node idx)
        h = {}

        # `step` rather than `iter`, which would shadow the builtin
        for step in range(max_iter):
            # list to store edge tuples that will be considered in this step
            edge_tuple_list = []

            # batch ids of the junc_trees / tree_vecs whose edge_tuple
            # is being considered in this step
            batch_list = []

            for idx, dfs_traversal in enumerate(traces):
                # only traversals that still have an edge at this step
                # take part in it
                if step < len(dfs_traversal):
                    edge_tuple_list.append(dfs_traversal[step])
                    batch_list.append(idx)

            cur_x = []
            cur_h_nei, cur_o_nei = [], []

            for node_x, real_y, _ in edge_tuple_list:
                # neighbors for message passing (target not included):
                # hidden edge message vectors from predecessor neighbor nodes
                cur_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                    if node_y.idx != real_y.idx
                ]
                # assumes a node has at most MAX_NUM_NEIGHBORS neighbors;
                # shorter lists are padded with zero vectors standing in for
                # messages from non-existent neighbors
                pad_len = MAX_NUM_NEIGHBORS - len(cur_nei)
                cur_h_nei.extend(cur_nei)
                cur_h_nei.extend([padding] * pad_len)

                # neighbors for stop (topological) prediction (all neighbors):
                # hidden edge messages from all neighbor nodes
                cur_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                ]
                pad_len = MAX_NUM_NEIGHBORS - len(cur_nei)
                cur_o_nei.extend(cur_nei)
                cur_o_nei.extend([padding] * pad_len)

                # current cluster embedding
                cur_x.append(node_x.wid)

            # cluster embedding
            cur_x = create_var(torch.LongTensor(cur_x))
            cur_x = self.vocab_embedding(cur_x)

            # implement message passing
            cur_h_nei = torch.stack(cur_h_nei,
                                    dim=0).view(-1, MAX_NUM_NEIGHBORS,
                                                self.hidden_size)
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                        self.W_h)

            # node aggregate: sum of all incoming neighbor messages
            cur_o_nei = torch.stack(cur_o_nei,
                                    dim=0).view(-1, MAX_NUM_NEIGHBORS,
                                                self.hidden_size)
            cur_o = cur_o_nei.sum(dim=1)

            # gather targets
            pred_target, pred_list = [], []
            stop_target = []

            # teacher forcing: the ground-truth traversal decides which
            # edges are added, regardless of the model's predictions
            for idx, edge_tuple in enumerate(edge_tuple_list):
                node_x, node_y, direction = edge_tuple
                x, y = node_x.idx, node_y.idx
                h[(x, y)] = new_h[idx]
                node_y.neighbors.append(node_x)
                if direction == 1:
                    pred_target.append(node_y.wid)
                    pred_list.append(idx)
                stop_target.append(direction)

            # hidden states for stop (topological) prediction
            cur_batch = create_var(torch.LongTensor(batch_list))
            stop_hidden = torch.cat([cur_x, cur_o], dim=1)
            stop_hiddens.append(stop_hidden)
            stop_contexts.append(cur_batch)
            stop_targets.extend(stop_target)

            # hidden states for cluster prediction
            # (only steps with direction == 1 contribute)
            if len(pred_list) > 0:
                batch_list = [batch_list[idx] for idx in pred_list]
                cur_batch = create_var(torch.LongTensor(batch_list))
                pred_contexts.append(cur_batch)

                cur_pred = create_var(torch.LongTensor(pred_list))
                pred_hiddens.append(new_h.index_select(0, cur_pred))
                pred_targets.extend(pred_target)

        # last stop at root: every root must predict "stop" (target 0)
        cur_x, cur_o_nei = [], []
        for junc_tree in junc_tree_batch:
            node_x = junc_tree.nodes[0]
            cur_x.append(node_x.wid)
            cur_nei = [
                h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
            ]
            pad_len = MAX_NUM_NEIGHBORS - len(cur_nei)
            cur_o_nei.extend(cur_nei)
            cur_o_nei.extend([padding] * pad_len)

        cur_x = create_var(torch.LongTensor(cur_x))
        cur_x = self.vocab_embedding(cur_x)
        cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NUM_NEIGHBORS,
                                                       self.hidden_size)
        cur_o = cur_o_nei.sum(dim=1)

        stop_hidden = torch.cat([cur_x, cur_o], dim=1)
        stop_hiddens.append(stop_hidden)
        stop_contexts.append(create_var(torch.LongTensor(range(batch_size))))
        stop_targets.extend([0] * len(junc_tree_batch))

        # predict next cluster
        pred_contexts = torch.cat(pred_contexts, dim=0)
        pred_hiddens = torch.cat(pred_hiddens, dim=0)
        pred_scores = self.aggregate(pred_hiddens, pred_contexts, tree_vecs,
                                     'word')
        pred_targets = create_var(torch.LongTensor(pred_targets))

        pred_loss = self.pred_loss(pred_scores,
                                   pred_targets) / len(junc_tree_batch)
        _, preds = torch.max(pred_scores, dim=1)
        pred_acc = torch.eq(preds, pred_targets).float()
        pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

        # predict stop
        stop_contexts = torch.cat(stop_contexts, dim=0)
        stop_hiddens = torch.cat(stop_hiddens, dim=0)
        stop_hiddens = F.relu(self.U_i(stop_hiddens))
        stop_scores = self.aggregate(stop_hiddens, stop_contexts, tree_vecs,
                                     'stop')
        stop_scores = stop_scores.squeeze(-1)
        stop_targets = create_var(torch.Tensor(stop_targets))

        stop_loss = self.stop_loss(stop_scores,
                                   stop_targets) / len(junc_tree_batch)
        # a score >= 0 counts as predicting target 1
        stops = torch.ge(stop_scores, 0).float()
        stop_acc = torch.eq(stops, stop_targets).float()
        stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

        return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()
예제 #3
0
    def forward(self, mol_batch, mol_vec):
        """Compute teacher-forced decoder losses for a batch of junction trees.

        Each tree is traversed depth-first from its root (with a virtual
        ``super_root`` of idx -1 as the root's parent). All traversals are
        replayed in lock-step, computing a ``GRU`` message for every traversed
        edge. Two prediction heads are trained:

        * cluster prediction: ``[hidden, mol_vec row] -> W -> ReLU -> W_o``,
          for every root and for every step with ``direction == 1``;
        * stop prediction: ``[cluster embedding, summed neighbor messages,
          mol_vec row] -> U -> ReLU -> U_s``, for every step (target =
          ``direction``) plus a final target of 0 at every root.

        Args:
            mol_batch: list of trees exposing ``.nodes`` (``nodes[0]`` is the
                root); nodes expose ``idx``, ``wid`` and ``neighbors``. Every
                node's ``neighbors`` list is reset and rebuilt here, so the
                trees are mutated in place.
            mol_vec: tensor whose rows are selected per tree by batch position
                (``index_select(0, ...)``) and concatenated along dim 1.

        Returns:
            ``(pred_loss, stop_loss, pred_acc, stop_acc)``: both losses are
            divided by the batch size; both accuracies are Python floats.
        """
        super_root = MolTreeNode('')
        super_root.idx = -1

        # Initialization
        pred_hiddens, pred_mol_vecs, pred_targets = [], [], []
        stop_hiddens, stop_targets = [], []
        traces = []
        for mol_tree in mol_batch:
            s = []
            dfs(s, mol_tree.nodes[0], super_root)
            traces.append(s)
            # Neighbor lists are rebuilt during the replay: after edge
            # x -> y is processed, x is appended to y.neighbors.
            for node in mol_tree.nodes:
                node.neighbors = []

        # Root prediction: zero hidden state paired with each tree's mol_vec.
        pred_hiddens.append(
            create_var(torch.zeros(len(mol_batch), self.hidden_size)))
        pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch])
        pred_mol_vecs.append(mol_vec)

        max_iter = max([len(tr) for tr in traces])
        # Zero vector used to pad every neighbor list up to MAX_NB entries.
        padding = create_var(torch.zeros(self.hidden_size), False)
        # h[(x_idx, y_idx)] = hidden state computed for directed edge x -> y.
        h = {}

        for t in range(max_iter):
            prop_list = []
            batch_list = []
            for i, plist in enumerate(traces):
                if len(plist) > t:
                    prop_list.append(plist[t])
                    batch_list.append(i)

            cur_x = []
            cur_h_nei, cur_o_nei = [], []

            for node_x, real_y, _ in prop_list:
                # Messages for message passing: every neighbor except the
                # target of this step. (Leftover debug prints removed - they
                # dumped every message tensor on every training step.)
                cur_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                    if node_y.idx != real_y.idx
                ]
                pad_len = MAX_NB - len(cur_nei)
                cur_h_nei.extend(cur_nei)
                cur_h_nei.extend([padding] * pad_len)

                # Messages for stop prediction: all neighbors.
                cur_nei = [
                    h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
                ]
                pad_len = MAX_NB - len(cur_nei)
                cur_o_nei.extend(cur_nei)
                cur_o_nei.extend([padding] * pad_len)

                cur_x.append(node_x.wid)

            cur_x = create_var(torch.LongTensor(cur_x))
            cur_x = self.embedding(cur_x)
            cur_h_nei = torch.stack(cur_h_nei,
                                    dim=0).view(-1, MAX_NB, self.hidden_size)
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                        self.W_h)

            cur_o_nei = torch.stack(cur_o_nei,
                                    dim=0).view(-1, MAX_NB, self.hidden_size)
            cur_o = cur_o_nei.sum(dim=1)

            # Teacher forcing: record targets and grow the neighbor lists
            # along the ground-truth traversal.
            pred_target, pred_list = [], []
            stop_target = []
            for i, m in enumerate(prop_list):
                node_x, node_y, direction = m
                x, y = node_x.idx, node_y.idx
                h[(x, y)] = new_h[i]
                node_y.neighbors.append(node_x)
                if direction == 1:
                    pred_target.append(node_y.wid)
                    pred_list.append(i)
                stop_target.append(direction)

            cur_batch = create_var(torch.LongTensor(batch_list))
            cur_mol_vec = mol_vec.index_select(0, cur_batch)
            stop_hidden = torch.cat([cur_x, cur_o, cur_mol_vec], dim=1)
            stop_hiddens.append(stop_hidden)
            stop_targets.extend(stop_target)

            if len(pred_list) > 0:
                batch_list = [batch_list[i] for i in pred_list]
                cur_batch = create_var(torch.LongTensor(batch_list))
                pred_mol_vecs.append(mol_vec.index_select(0, cur_batch))

                cur_pred = create_var(torch.LongTensor(pred_list))
                pred_hiddens.append(new_h.index_select(0, cur_pred))
                pred_targets.extend(pred_target)

        # Last stop at root: every root must predict "stop" (target 0).
        cur_x, cur_o_nei = [], []
        for mol_tree in mol_batch:
            node_x = mol_tree.nodes[0]
            cur_x.append(node_x.wid)
            cur_nei = [
                h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors
            ]
            pad_len = MAX_NB - len(cur_nei)
            cur_o_nei.extend(cur_nei)
            cur_o_nei.extend([padding] * pad_len)

        cur_x = create_var(torch.LongTensor(cur_x))
        cur_x = self.embedding(cur_x)
        cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB,
                                                       self.hidden_size)
        cur_o = cur_o_nei.sum(dim=1)

        stop_hidden = torch.cat([cur_x, cur_o, mol_vec], dim=1)
        stop_hiddens.append(stop_hidden)
        stop_targets.extend([0] * len(mol_batch))

        pred_hiddens = torch.cat(pred_hiddens, dim=0)
        pred_mol_vecs = torch.cat(pred_mol_vecs, dim=0)
        pred_vecs = torch.cat([pred_hiddens, pred_mol_vecs], dim=1)
        pred_vecs = torch.relu(self.W(pred_vecs))
        pred_scores = self.W_o(pred_vecs)
        pred_targets = create_var(torch.LongTensor(pred_targets))

        pred_loss = self.pred_loss(pred_scores, pred_targets) / len(mol_batch)
        _, preds = torch.max(pred_scores, dim=1)
        pred_acc = torch.eq(preds, pred_targets).float()
        pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

        stop_hiddens = torch.cat(stop_hiddens, dim=0)
        stop_vecs = torch.relu(self.U(stop_hiddens))
        # squeeze only the trailing dim: a bare .squeeze() would collapse a
        # single-row (1, 1) score tensor to 0-d and mismatch the (1,) targets.
        stop_scores = self.U_s(stop_vecs).squeeze(-1)
        stop_targets = create_var(torch.Tensor(stop_targets))

        stop_loss = self.stop_loss(stop_scores, stop_targets) / len(mol_batch)
        # A score >= 0 counts as predicting target 1.
        stops = torch.ge(stop_scores, 0).float()
        stop_acc = torch.eq(stops, stop_targets).float()
        stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

        return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()
예제 #4
0
파일: jtnn_dec.py 프로젝트: futianfan/CORE
    def forward(self, mol_batch, x_tree_vecs, x_mol_vecs, origin_word):
        """Compute teacher-forced decoder losses for a batch of target trees.

        mol_batch: Y, where (X, Y) is a training pair, X is the input and Y
        is the target.

        Every tree in ``mol_batch`` is traversed depth-first and all
        traversals are replayed in lock-step, computing a ``GRU`` message for
        every traversed edge. Two prediction tasks are trained:

        * cluster prediction, scored by ``self.mixture_attention`` and trained
          with ``self.nllloss`` on the log of the (floored) scores;
        * stop prediction, scored by ``self.attention(..., 'stop')``.

        Args:
            mol_batch: list of target trees exposing ``.nodes`` (``nodes[0]``
                is the root). Every node's ``neighbors`` list is reset and
                rebuilt here, so the trees are mutated in place.
            x_tree_vecs, x_mol_vecs: input-side encodings passed to the
                attention modules.
            origin_word: passed through to ``self.mixture_attention``; the
                original notes describe it only as "input" - TODO confirm.

        Returns:
            ``(pred_loss, stop_loss, pred_acc, stop_acc)``: both losses are
            divided by the batch size; both accuracies are Python floats.
        """
        pred_hiddens, pred_contexts, pred_targets = [], [], []
        stop_hiddens, stop_contexts, stop_targets = [], [], []
        traces = []
        for mol_tree in mol_batch:
            s = []
            dfs(s, mol_tree.nodes[0], -1)
            traces.append(s)
            # Neighbor lists are rebuilt during the replay: after edge
            # x -> y is processed, x is appended to y.neighbors.
            for node in mol_tree.nodes:
                node.neighbors = []

        # Predict Root
        batch_size = len(mol_batch)
        pred_hiddens.append(create_var(torch.zeros(batch_size, self.hidden_size)))
        pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch])
        pred_contexts.append(create_var(torch.LongTensor(range(batch_size))))

        max_iter = max([len(tr) for tr in traces])
        # Zero vector used to pad every neighbor list up to MAX_NB entries.
        padding = create_var(torch.zeros(self.hidden_size), False)
        # h[(x_idx, y_idx)] = hidden state computed for directed edge x -> y.
        h = {}

        # `range` instead of the Python-2-only `xrange` (same behavior on
        # Python 2, and it also runs on Python 3).
        for t in range(max_iter):
            prop_list = []
            batch_list = []
            for i, plist in enumerate(traces):
                if t < len(plist):
                    prop_list.append(plist[t])
                    batch_list.append(i)

            cur_x = []
            cur_h_nei, cur_o_nei = [], []

            for node_x, real_y, _ in prop_list:
                # Neighbors for message passing (target not included)
                cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors if node_y.idx != real_y.idx]
                pad_len = MAX_NB - len(cur_nei)
                cur_h_nei.extend(cur_nei)
                cur_h_nei.extend([padding] * pad_len)

                # Neighbors for stop prediction (all neighbors)
                cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
                pad_len = MAX_NB - len(cur_nei)
                cur_o_nei.extend(cur_nei)
                cur_o_nei.extend([padding] * pad_len)

                # Current clique embedding
                cur_x.append(node_x.wid)

            # Clique embedding
            cur_x = create_var(torch.LongTensor(cur_x))
            cur_x = self.embedding(cur_x)

            # Message passing
            cur_h_nei = torch.stack(cur_h_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)

            # Node aggregate: sum of all incoming neighbor messages
            cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
            cur_o = cur_o_nei.sum(dim=1)

            # Gather targets (teacher forcing along the ground-truth traversal)
            pred_target, pred_list = [], []
            stop_target = []
            for i, m in enumerate(prop_list):
                node_x, node_y, direction = m
                x, y = node_x.idx, node_y.idx
                h[(x, y)] = new_h[i]
                node_y.neighbors.append(node_x)
                if direction == 1:
                    pred_target.append(node_y.wid)
                    pred_list.append(i)
                stop_target.append(direction)

            # Hidden states for stop prediction
            cur_batch = create_var(torch.LongTensor(batch_list))
            stop_hidden = torch.cat([cur_x, cur_o], dim=1)
            stop_hiddens.append(stop_hidden)
            stop_contexts.append(cur_batch)
            stop_targets.extend(stop_target)

            # Hidden states for clique prediction (direction == 1 steps only)
            if len(pred_list) > 0:
                batch_list = [batch_list[i] for i in pred_list]
                cur_batch = create_var(torch.LongTensor(batch_list))
                pred_contexts.append(cur_batch)

                cur_pred = create_var(torch.LongTensor(pred_list))
                pred_hiddens.append(new_h.index_select(0, cur_pred))
                pred_targets.extend(pred_target)

        # Last stop at root: every root must predict "stop" (target 0)
        cur_x, cur_o_nei = [], []
        for mol_tree in mol_batch:
            node_x = mol_tree.nodes[0]
            cur_x.append(node_x.wid)
            cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
            pad_len = MAX_NB - len(cur_nei)
            cur_o_nei.extend(cur_nei)
            cur_o_nei.extend([padding] * pad_len)

        cur_x = create_var(torch.LongTensor(cur_x))
        cur_x = self.embedding(cur_x)
        cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
        cur_o = cur_o_nei.sum(dim=1)

        stop_hidden = torch.cat([cur_x, cur_o], dim=1)
        stop_hiddens.append(stop_hidden)
        stop_contexts.append(create_var(torch.LongTensor(range(batch_size))))
        stop_targets.extend([0] * len(mol_batch))

        # Predict next clique
        pred_contexts = torch.cat(pred_contexts, dim=0)
        pred_hiddens = torch.cat(pred_hiddens, dim=0)
        pred_scores = self.mixture_attention(pred_hiddens, pred_contexts, x_tree_vecs, x_mol_vecs, origin_word)
        if torch.isnan(pred_scores).any():
            # Function-call form prints the same text on Python 2 and is
            # valid syntax on Python 3 (the old print statement was not).
            print("forward nan")

        # Example shapes recorded in one debugging run (batch of 32):
        #   pred_hiddens (420, 300), pred_contexts (420,) int,
        #   x_tree_vecs (32, 21, 300), x_mol_vecs (32, 28, 300),
        #   pred_scores (420, 780), len(pred_targets) == 420.
        pred_targets = create_var(torch.LongTensor(pred_targets))

        # pred_scores are presumably probabilities (log + NLL below). They are
        # floored at minimum_threshold_before_log before the log, since
        # log(0) would be -inf.
        score_floor = torch.ones_like(pred_scores) * minimum_threshold_before_log
        log_scores = torch.log(torch.max(pred_scores, score_floor))
        pred_loss = self.nllloss(log_scores, pred_targets) / len(mol_batch)
        _, preds = torch.max(pred_scores, dim=1)
        pred_acc = torch.eq(preds, pred_targets).float()
        pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

        # Predict stop
        stop_contexts = torch.cat(stop_contexts, dim=0)
        stop_hiddens = torch.cat(stop_hiddens, dim=0)
        stop_hiddens = F.relu(self.U_i(stop_hiddens))
        stop_scores = self.attention(stop_hiddens, stop_contexts, x_tree_vecs, x_mol_vecs, 'stop')
        stop_scores = stop_scores.squeeze(-1)
        stop_targets = create_var(torch.Tensor(stop_targets))

        stop_loss = self.stop_loss(stop_scores, stop_targets) / len(mol_batch)
        # A score >= 0 counts as predicting target 1.
        stops = torch.ge(stop_scores, 0).float()
        stop_acc = torch.eq(stops, stop_targets).float()
        stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

        return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()