示例#1
0
 def get_batch_generator(self):
     """Yield ``(inputs, Y)`` pairs assembled from ``self.get_batch``.

     Training mode (``self.is_train`` truthy) never terminates. Each step
     calls ``self.get_batch()``, and the inputs dict holds only the
     query/doc arrays and their lengths.

     Evaluation mode calls ``self.get_batch(randomly=False)`` while
     ``self.point + self.batch_size <= self.total_rel_num``. It also stores
     the batch's ID pairs under ``'ID'``.

     When ``self.config['use_dpool']`` is truthy, a ``'dpool_index'`` entry
     from ``DynamicMaxPooling.dynamic_pooling_index`` (sized by
     ``text1_maxlen`` / ``text2_maxlen``) follows ``'doc_len'``.
     """

     def _pack_inputs(q, q_len, d, d_len):
         # Shared by both modes: base inputs plus the optional pooling index.
         feed = {'query': q, 'query_len': q_len, 'doc': d, 'doc_len': d_len}
         if self.config['use_dpool']:
             feed['dpool_index'] = DynamicMaxPooling.dynamic_pooling_index(
                 q_len, d_len, self.config['text1_maxlen'],
                 self.config['text2_maxlen'])
         return feed

     if self.is_train:
         while True:
             q, q_len, d, d_len, y, _ = self.get_batch()
             yield (_pack_inputs(q, q_len, d, d_len), y)
     else:
         # NOTE(review): the loop stops as soon as a whole batch no longer
         # fits before total_rel_num. If get_batch advances self.point by
         # batch_size, any trailing remainder is never yielded. Confirm that
         # dropping it is intended.
         while self.point + self.batch_size <= self.total_rel_num:
             q, q_len, d, d_len, y, id_pairs = self.get_batch(randomly=False)
             feed = _pack_inputs(q, q_len, d, d_len)
             feed['ID'] = id_pairs
             yield (feed, y)
 def get_batch_generator(self):
     """Re-emit every batch from the ``self.get_batch()`` iterable as ``(inputs, Y)``.

     ``self.get_batch()`` must produce 7-tuples
     ``(X1, X1_len, X2, X2_len, Y, ID_pairs, list_counts)``. The inputs
     dict always carries the query/doc arrays, their lengths, ``'ID'`` and
     ``'list_counts'``. A ``'dpool_index'`` entry is inserted right after
     ``'doc_len'`` when ``self.config['use_dpool']`` is truthy. The
     generator finishes when ``self.get_batch()`` is exhausted.
     """
     for batch in self.get_batch():
         q, q_len, d, d_len, y, id_pairs, list_counts = batch
         feed = {'query': q, 'query_len': q_len, 'doc': d, 'doc_len': d_len}
         if self.config['use_dpool']:
             feed['dpool_index'] = DynamicMaxPooling.dynamic_pooling_index(
                 q_len, d_len, self.config['text1_maxlen'],
                 self.config['text2_maxlen'])
         feed['ID'] = id_pairs
         feed['list_counts'] = list_counts
         yield (feed, y)
示例#3
0
 def get_batch_generator(self):
     """Endlessly yield ``(inputs, Y)`` pairs built from ``self.get_batch()``.

     ``self.get_batch()`` must return ``(X1, X1_len, X2, X2_len, Y)``. The
     inputs dict maps ``'query'``, ``'query_len'``, ``'doc'`` and
     ``'doc_len'`` to those values. When ``self.config['use_dpool']`` is
     truthy, it also gets a ``'dpool_index'`` from
     ``DynamicMaxPooling.dynamic_pooling_index``.
     """
     while True:
         q, q_len, d, d_len, y = self.get_batch()
         feed = dict(query=q, query_len=q_len, doc=d, doc_len=d_len)
         if self.config['use_dpool']:
             feed['dpool_index'] = DynamicMaxPooling.dynamic_pooling_index(
                 q_len, d_len, self.config['text1_maxlen'],
                 self.config['text2_maxlen'])
         yield feed, y
示例#4
0
 def get_batch_generator(self):
     """Yield ``(inputs, Y)`` pairs until ``self.get_batch()`` returns a falsy value.

     Each truthy sample is unpacked as
     ``(X1, X1_len, X2, X2_len, Y, ID_pairs)``. The inputs dict carries the
     query/doc arrays, their lengths and ``'ID'``. When
     ``self.config['use_dpool']`` is truthy, ``'dpool_index'`` is inserted
     ahead of ``'ID'``.
     """
     sample = self.get_batch()
     while sample:
         q, q_len, d, d_len, y, id_pairs = sample
         feed = {'query': q, 'query_len': q_len, 'doc': d, 'doc_len': d_len}
         if self.config['use_dpool']:
             feed['dpool_index'] = DynamicMaxPooling.dynamic_pooling_index(
                 q_len, d_len, self.config['text1_maxlen'],
                 self.config['text2_maxlen'])
         feed['ID'] = id_pairs
         yield (feed, y)
         # Fetch the next sample only after the consumer resumes us.
         sample = self.get_batch()
示例#5
0
 def get_dpool_index(self, _len1, _len2):
     '''
     Build the dynamic pooling index for a single text pair.

     The two scalar lengths are wrapped in one-element lists, so
     DynamicMaxPooling.dynamic_pooling_index receives them as a batch of one.
     The maximum lengths come from self.config['text1_maxlen'] and
     self.config['text2_maxlen'].

     @param _len1: int length of text1 terms
     @param _len2: int length of text2 terms
     @return: the value returned by DynamicMaxPooling.dynamic_pooling_index
              (documented originally as np.array(index))
     '''
     cfg = self.config
     return DynamicMaxPooling.dynamic_pooling_index(
         [_len1], [_len2], cfg['text1_maxlen'], cfg['text2_maxlen'])
示例#6
0
 def get_batch_generator(self):
     """Endlessly yield ``(inputs, Y)`` pairs carrying word and POS inputs.

     ``self.get_batch()`` must return
     ``(X1, XP1, X1_len, XP1_len, X2, XP2, X2_len, XP2_len, Y)``. The inputs
     dict holds the query/doc arrays, their ``_pos`` counterparts and all
     four lengths.

     When ``self.config['use_dpool']`` is truthy, the dict also holds two
     pooling indices:

     * ``'dpool_index'``: built from the text lengths with ``text1_maxlen``
       / ``text2_maxlen``.
     * ``'dpool_pos_index'``: built from the POS lengths with
       ``pos1_maxlen`` / ``pos2_maxlen``.
     """
     while True:
         (X1, XP1, X1_len, XP1_len,
          X2, XP2, X2_len, XP2_len, Y) = self.get_batch()
         inputs = {
             'query': X1,
             'query_pos': XP1,
             'query_len': X1_len,
             'query_pos_len': XP1_len,
             'doc': X2,
             'doc_pos': XP2,
             'doc_len': X2_len,
             'doc_pos_len': XP2_len,
         }
         if self.config['use_dpool']:
             inputs['dpool_index'] = DynamicMaxPooling.dynamic_pooling_index(
                 X1_len, X2_len, self.config['text1_maxlen'],
                 self.config['text2_maxlen'])
             inputs['dpool_pos_index'] = DynamicMaxPooling.dynamic_pooling_index(
                 XP1_len, XP2_len, self.config['pos1_maxlen'],
                 self.config['pos2_maxlen'])
         yield (inputs, Y)
示例#7
0
 def get_batch_generator(self):
     """Turn each 7-tuple from the ``self.get_batch()`` iterable into ``(inputs, Y)``.

     The inputs dict always includes the query/doc arrays, their lengths,
     ``'ID'`` and ``'list_counts'``. A ``'dpool_index'`` entry is added
     before ``'ID'`` when ``self.config['use_dpool']`` is truthy.
     """
     for q, q_len, d, d_len, y, id_pairs, list_counts in self.get_batch():
         feed = {'query': q, 'query_len': q_len, 'doc': d, 'doc_len': d_len}
         if self.config['use_dpool']:
             feed['dpool_index'] = DynamicMaxPooling.dynamic_pooling_index(
                 q_len, d_len, self.config['text1_maxlen'],
                 self.config['text2_maxlen'])
         feed['ID'] = id_pairs
         feed['list_counts'] = list_counts
         yield feed, y
示例#8
0
 def get_batch_generator(self):
     """Yield ``(inputs, Y)`` pairs, stopping at the first falsy ``self.get_batch()`` result.

     Samples are unpacked as ``(X1, X1_len, X2, X2_len, Y, ID_pairs)``.
     The inputs dict carries the query/doc arrays, their lengths and
     ``'ID'``. When ``self.config['use_dpool']`` is truthy, ``'dpool_index'``
     is added before ``'ID'``.
     """
     while True:
         sample = self.get_batch()
         if not sample:
             return
         q, q_len, d, d_len, y, id_pairs = sample
         feed = {'query': q, 'query_len': q_len, 'doc': d, 'doc_len': d_len}
         if self.config['use_dpool']:
             feed['dpool_index'] = DynamicMaxPooling.dynamic_pooling_index(
                 q_len, d_len, self.config['text1_maxlen'],
                 self.config['text2_maxlen'])
         feed['ID'] = id_pairs
         yield (feed, y)
示例#9
0
 def get_batch_generator(self):
     """Produce an infinite stream of ``(inputs, Y)`` pairs.

     Every iteration unpacks ``self.get_batch()`` as
     ``(X1, X1_len, X2, X2_len, Y)``. It adds ``'dpool_index'`` to the
     inputs when ``self.config['use_dpool']`` is truthy.
     """
     while True:
         batch = self.get_batch()
         q, q_len, d, d_len, y = batch
         feed = {'query': q, 'query_len': q_len}
         feed['doc'] = d
         feed['doc_len'] = d_len
         if self.config['use_dpool']:
             feed['dpool_index'] = DynamicMaxPooling.dynamic_pooling_index(
                 q_len, d_len, self.config['text1_maxlen'],
                 self.config['text2_maxlen'])
         yield (feed, y)
示例#10
0
 def get_batch_generator(self):
     """Yield ``(inputs, Y)`` pairs in training or evaluation mode.

     Training (``self.is_train`` truthy) loops forever over
     ``self.get_batch()`` and omits ``'ID'``. Evaluation calls
     ``self.get_batch(randomly=False)`` while
     ``self.point + self.batch_size <= self.total_rel_num`` and includes the
     ID pairs under ``'ID'``. In both modes ``'dpool_index'`` is added when
     ``self.config['use_dpool']`` is truthy.
     """
     training = self.is_train
     # In training mode the `or` short-circuits, so the loop never ends and
     # self.point / self.total_rel_num are never read.
     while training or self.point + self.batch_size <= self.total_rel_num:
         if training:
             batch = self.get_batch()
         else:
             batch = self.get_batch(randomly=False)
         q, q_len, d, d_len, y, id_pairs = batch
         feed = {'query': q, 'query_len': q_len, 'doc': d, 'doc_len': d_len}
         if self.config['use_dpool']:
             feed['dpool_index'] = DynamicMaxPooling.dynamic_pooling_index(
                 q_len, d_len, self.config['text1_maxlen'],
                 self.config['text2_maxlen'])
         if not training:
             # Only evaluation batches expose the ID pairs.
             feed['ID'] = id_pairs
         yield (feed, y)