def __init__(self, anno_files_path, vocab_path, vocab_size, vf_dir, vf_size, flag_shuffle=True): self.cu = CommonUtiler() self.anno_files_path = anno_files_path self.vocab_path = vocab_path self.vocab, _ = self.cu.load_vocabulary(vocab_path) assert len(self.vocab) == vocab_size assert self.vocab['<pad>'] == 0 self.vf_dir = vf_dir self.vf_size = vf_size self.flag_shuffle = flag_shuffle self._load_data()
class mRNNCocoBucketDataProvider(object): """mRNN TensorFlow Data Provider with Buckets on MS COCO.""" def __init__(self, anno_files_path, vocab_path, vocab_size, vf_dir, vf_size, flag_shuffle=True): self.cu = CommonUtiler() self.anno_files_path = anno_files_path self.vocab_path = vocab_path self.vocab, _ = self.cu.load_vocabulary(vocab_path) assert len(self.vocab) == vocab_size assert self.vocab['<pad>'] == 0 self.vf_dir = vf_dir self.vf_size = vf_size self.flag_shuffle = flag_shuffle self._load_data() def generate_batches(self, batch_size, buckets): """Return a list generator of mini-batches of training data.""" # create Batches batches = [] for max_seq_len in buckets: batches.append( Batch(batch_size, max_seq_len, self.vf_size, self.vocab['<bos>'])) # shuffle if necessary if self.flag_shuffle: np.random.shuffle(self._data_pointer) # scan data queue for ind_i, ind_s in self._data_pointer: sentence = self._data_queue[ind_i]['sentences'][ind_s] visual_features = self._data_queue[ind_i]['visual_features'] if len(sentence) >= buckets[-1]: feed_res = batches[-1].feed_and_vomit(visual_features, sentence) ind_buc = len(buckets) - 1 else: for (ind_b, batch) in enumerate(batches): if len(sentence) < batch.max_seq_len: feed_res = batches[ind_b].feed_and_vomit( visual_features, sentence) ind_buc = ind_b break if feed_res: yield (ind_buc, ) + feed_res batches[ind_buc].empty() def _load_data(self, verbose=True): logger.info('Loading data') vocab = self.vocab self._data_queue = [] self._data_pointer = [] ind_img = 0 num_failed = 0 for anno_file_path in self.anno_files_path: annos = np.load(anno_file_path).tolist() for (ind_a, anno) in enumerate(annos): data = {} # Load visual features feat_path = os.path.join( self.vf_dir, anno['file_path'], anno['file_name'].split('.')[0] + '.txt') if os.path.exists(feat_path): vf = np.loadtxt(feat_path) else: num_failed += 1 continue data['visual_features'] = vf # Encode sentences data['sentences'] = [] for (ind_s, sentence) in enumerate(anno['sentences']): sentence_encode = self.cu.encode_sentence( sentence, vocab, flag_add_bos=False) self._data_pointer.append((ind_img, ind_s)) data['sentences'].append(np.array(sentence_encode)) self._data_queue.append(data) ind_img += 1 if verbose and (ind_a + 1) % 5000 == 0: logger.info('Load %d/%d annotation from file %s', ind_a + 1, len(annos), anno_file_path) logger.info( 'Load %d images, %d sentences from %d files, %d image failed', len(self._data_queue), len(self._data_pointer), len(self.anno_files_path), num_failed)