def __getitem__(self, batch_index):
    ''' Build the batch at position `batch_index` from saved `.npy` features.

    The batch covers
    `self.indexes[batch_index * batch_size:(batch_index + 1) * batch_size]`;
    every index selects a `(filename, label, seg)` entry of
    `self.examples_meta`. When `self.shuffle` is set, the entries are shuffled
    within the batch before loading.

    Args:
      batch_index: position of the requested batch.

    Returns:
      A `(features, one_hot_label)` pair. `features` is a dict with
      `'inputs'` (float64 array of the stacked clip features) and `'labels'`
      (int32 array of class ids). `one_hot_label` is the result of
      `tf.keras.utils.to_categorical` on the class ids with
      `num_classes=len(self.classes)`.

    Raises:
      AssertionError: if `self._file_suffix` is not `'.npy'`, or a clip does
        not have `self.sample_to_frame(self.example_len)` frames.
    '''
    assert self._file_suffix == '.npy'
    logging.debug(f''' batch_index: {batch_index}, num_batches: {len(self)}, num exapmples: {self.num_examples}, label dict: {self.classes}''')

    first = batch_index * self.batch_size
    last = (batch_index + 1) * self.batch_size
    batch_meta = [self.examples_meta[i] for i in self.indexes[first:last]]
    if self.shuffle:
        random.shuffle(batch_meta)
    logging.debug(f"batch metah: {batch_meta}")

    feats = []
    labels = []
    for filename, label, seg in batch_meta:
        feat = np.load(filename)
        # Original note: shape is [nframe, feat_size, 3].
        feat = feat_lib.add_delta_delta(feat, self._feature_size, order=2)
        if self.feature_shape[-1] == 1:
            # Single-channel input: keep only the first channel.
            feat = feat[:, :, 0:1]

        # Every element of seg goes through sample_to_frame; afterwards
        # seg[0]:seg[1] is the slice to keep along axis 0 and seg[2] is the
        # number of zero rows to append before slicing.
        seg = list(map(self.sample_to_frame, seg))
        if seg[2]:
            feat = np.pad(
                feat, [(0, seg[2]), (0, 0), (0, 0)], mode='constant')
        feat = feat[seg[0]:seg[1], :, :]
        assert len(feat) == self.sample_to_frame(
            self.example_len), "{} {} {} {} {} {}".format(
                filename, seg, len(feat), self.example_len,
                self.sample_to_frame(self.example_len), seg[2])

        feats.append(feat)
        # Convert the string label to its integer class id.
        labels.append(self.class_id(label))

    label_ids = np.array(labels, dtype=np.int32)
    features = {
        'inputs': np.array(feats, dtype=np.float64),
        'labels': label_ids,
    }
    one_hot_label = tf.keras.utils.to_categorical(
        label_ids, num_classes=len(self.classes))
    return features, one_hot_label
def generate_data(self):
    ''' Yield the examples of one pass over `self.data_items`, one at a time.

    Each call increments `self._epoch` and shuffles `self.data_items` in
    place. Files are then processed in the shuffled order: all examples of a
    file are prepared first and only then yielded, in clip order.

    With `self._file_suffix == '.wav'` the inputs are slices of the raw
    samples returned by `feat_lib.load_wav`; otherwise the file is read with
    `np.load` and the inputs are slices of the delta features derived from
    it. When text is enabled in `self.taskconf`, only the clip with
    `clip_id == 0` of each file is emitted.

    Yields:
      `(inputs, text_ids, label_id, filename, clip_id, soft_label)` tuples.
      `soft_label` is `self.teacher(inputs)` when `self.use_distilling` is
      set, otherwise a list of `taskconf['classes']['num']` zeros.

    Raises:
      AssertionError: if a clip does not have the expected length.
    '''
    use_text = self.taskconf['text']['enable']
    self._epoch += 1  # one increment per pass over the data
    np.random.shuffle(self.data_items)  # in place: reorders self.data_items

    for filename, examples in self.data_items:
        # Text ids are computed once per file and shared by all its clips.
        if use_text:
            # The text path is the file name without its last extension.
            text = _load_text('.'.join(filename.split('.')[:-1]))
            text2id = self._word_table_lookup(text)
        else:
            text2id = np.array([0] * self._max_text_len)

        file_examples = []
        if self._file_suffix == '.wav':
            _, raw_samples = feat_lib.load_wav(filename)
            for label, seg, clip_id in examples:
                samples = raw_samples
                if seg[2]:
                    # Append seg[2] zeros before slicing.
                    samples = np.pad(samples, [0, seg[2]], mode='constant')
                samples = samples[seg[0]:seg[1]]
                assert len(samples) == self.example_len, "{} {}".format(
                    filename, seg)

                labelid = self.class_id(label)

                if self.use_distilling:
                    # NOTE(review): the original passed `feat`, which is
                    # never bound in this branch (UnboundLocalError).
                    # `samples` is the only model input available here;
                    # confirm the teacher accepts raw samples.
                    soft_label = self.teacher(samples)
                else:
                    soft_label = [0] * self.taskconf['classes']['num']

                # With text enabled only the first clip of a file is kept.
                if not use_text or clip_id == 0:
                    file_examples.append((samples, text2id, labelid,
                                          filename, clip_id, soft_label))
        else:
            raw_feat = np.load(filename)
            if self._feature_type:
                # Original note: shape is [nframe, feat_size, 3].
                fbank = feat_lib.add_delta_delta(
                    raw_feat, self._feature_size, order=2)
                if self._input_channels == 1:
                    fbank = fbank[:, :, 0:1]
            else:
                fbank = feat_lib.delta_delta(raw_feat)

            for label, seg, clip_id in examples:
                feat = fbank
                # Every element of seg goes through sample_to_frame; then
                # seg[0]:seg[1] is the slice to keep along axis 0 and seg[2]
                # is the number of zero rows to append before slicing.
                seg = list(map(self.sample_to_frame, seg))
                if seg[2]:
                    feat = np.pad(
                        feat, [(0, seg[2]), (0, 0), (0, 0)], mode='constant')
                feat = feat[seg[0]:seg[1], :, :]
                assert len(feat) == self.sample_to_frame(
                    self.example_len), "{} {} {} {} {} {}".format(
                        filename, seg, len(feat), self.example_len,
                        self.sample_to_frame(self.example_len), seg[2])

                if self.use_distilling:
                    soft_label = self.teacher(feat)
                else:
                    soft_label = [0] * self.taskconf['classes']['num']

                # Convert the string label to its integer class id.
                labelid = self.class_id(label)

                # With text enabled only the first clip of a file is kept.
                if not use_text or clip_id == 0:
                    file_examples.append((feat, text2id, labelid, filename,
                                          clip_id, soft_label))

        yield from file_examples

    logging.info("Out of range")
    # The generator ends by falling off the end. Raising StopIteration here
    # would be turned into RuntimeError by the interpreter (PEP 479).