コード例 #1
0
ファイル: LSTM.py プロジェクト: zhilangtaosha/EA
 def get_batch_data(self, x, y, doc_len, batch_size, keep_prob):
     """Generate one feed dictionary per mini-batch.

     Row indices come from ``batch_index(len(y), batch_size, 1)``. Each
     per-sample input (``x``, ``y``, ``doc_len``) is indexed with them and
     mapped onto the attribute of ``self`` with the same name (presumably
     graph placeholders -- confirm against the model definition).

     Args:
         x: Input samples, indexable by the batch indices.
         y: Labels; ``len(y)`` is the sample count handed to ``batch_index``.
         doc_len: Per-sample document lengths, indexed like ``x``.
         batch_size: Forwarded to ``batch_index``.
         keep_prob: Scalar fed unchanged with every batch (presumably a
             dropout keep probability -- confirm).

     Yields:
         ``(feed, batch_len)`` where ``feed`` maps ``self`` attributes to
         batch values and ``batch_len`` is ``len`` of the index batch.
     """
     # NOTE(review): the literal 1 is batch_index's third positional
     # argument; its meaning is defined by batch_index, outside this file.
     for batch_ids in batch_index(len(y), batch_size, 1):
         feed = {
             self.x: x[batch_ids],
             self.y: y[batch_ids],
             self.doc_len: doc_len[batch_ids],
         }
         # Scalar entries are identical for every batch.
         feed[self.keep_prob] = keep_prob
         yield feed, len(batch_ids)
コード例 #2
0
 def get_batch_data(self, x, y, y_sen, sen_len, doc_len, aspect_id,
                    batch_size, keep_prob1, keep_prob2, test=False):
     """Generate one feed dictionary per mini-batch.

     Row indices come from ``batch_index(len(y), batch_size, 1, test)``.
     Every per-sample input is indexed with them and mapped onto an
     attribute of ``self`` (presumably graph placeholders -- confirm).

     Args:
         x: Input samples, indexable by the batch indices.
         y: Labels fed to ``self.y_doc``; ``len(y)`` is the sample count
             handed to ``batch_index``.
         y_sen: Per-sample values fed to ``self.y_sen``.
         sen_len: Per-sample values fed to ``self.sen_len``.
         doc_len: Per-sample values fed to ``self.doc_len``.
         aspect_id: Per-sample values fed to ``self.aspect_id``.
         batch_size: Forwarded to ``batch_index``.
         keep_prob1: Scalar fed unchanged to ``self.keep_prob1`` each batch.
         keep_prob2: Scalar fed unchanged to ``self.keep_prob2`` each batch.
         test: Forwarded to ``batch_index`` (presumably disables shuffling
             for evaluation -- confirm against batch_index).

     Yields:
         ``(feed, batch_len)`` where ``batch_len`` is ``len`` of the index
         batch.
     """
     # NOTE(review): the literal 1 is batch_index's third positional
     # argument; its meaning is defined by batch_index, outside this file.
     for batch_ids in batch_index(len(y), batch_size, 1, test):
         feed = {
             self.x: x[batch_ids],
             self.y_doc: y[batch_ids],
             self.y_sen: y_sen[batch_ids],
             self.sen_len: sen_len[batch_ids],
             self.doc_len: doc_len[batch_ids],
             self.aspect_id: aspect_id[batch_ids],
         }
         # Scalar entries are identical for every batch.
         feed[self.keep_prob1] = keep_prob1
         feed[self.keep_prob2] = keep_prob2
         yield feed, len(batch_ids)
コード例 #3
0
 def get_batch_data(self, x, y, doc_len, topic, batch_size, keep_prob,
                    test=False):
     """Generate one feed dictionary per mini-batch.

     Row indices come from ``batch_index(len(y), batch_size, 1, test)``.
     Each per-sample input (``x``, ``y``, ``doc_len``, ``topic``) is
     indexed with them and mapped onto the attribute of ``self`` with the
     same name (presumably graph placeholders -- confirm).

     Args:
         x: Input samples, indexable by the batch indices.
         y: Labels; ``len(y)`` is the sample count handed to ``batch_index``.
         doc_len: Per-sample document lengths, indexed like ``x``.
         topic: Per-sample values fed to ``self.topic``.
         batch_size: Forwarded to ``batch_index``.
         keep_prob: Scalar fed unchanged with every batch (presumably a
             dropout keep probability -- confirm).
         test: Forwarded to ``batch_index`` (presumably disables shuffling
             for evaluation -- confirm against batch_index).

     Yields:
         ``(feed, batch_len)`` where ``batch_len`` is ``len`` of the index
         batch.
     """
     # NOTE(review): the literal 1 is batch_index's third positional
     # argument; its meaning is defined by batch_index, outside this file.
     for batch_ids in batch_index(len(y), batch_size, 1, test):
         feed = {
             self.x: x[batch_ids],
             self.y: y[batch_ids],
             self.doc_len: doc_len[batch_ids],
             self.topic: topic[batch_ids],
         }
         # Scalar entry is identical for every batch.
         feed[self.keep_prob] = keep_prob
         yield feed, len(batch_ids)