def get_test_batches_with_buckets_and_target(self, src_path, dst_path, tokens_per_batch):
    """Generate test batches of (features, targets) grouped by input-length buckets."""
    buckets = [i for i in range(50, 10000, 5)]

    def select_bucket(sl):
        for l1 in buckets:
            if sl < l1:
                return l1
        raise Exception("The sequence is too long: ({})".format(sl))

    # Load the uttid -> target-token mapping from the transcript file.
    uttid_target_map = {}
    with codecs.open(dst_path, 'r', 'utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            splits = re.split(r'\s+', line)
            uttid = splits[0].strip()
            target = splits[1:]
            uttid_target_map[uttid] = target
    logging.info('loaded dst_path=' + str(dst_path) + ',size=' + str(len(uttid_target_map)))

    # Each cache holds: [feats, targets, total src tokens, total dst tokens].
    caches = {}
    for bucket in buckets:
        caches[bucket] = [[], [], 0, 0]

    scp_reader = zark.ArkReader(src_path)
    count = 0
    while True:
        uttid, input, loop = scp_reader.read_next_utt()
        if loop:
            break
        target = uttid_target_map.get(uttid)
        if target is None:
            logging.warning('uttid=' + str(uttid) + ',target is None')
            continue
        input_len = len(input)
        target_len = len(target)
        bucket = select_bucket(input_len)
        caches[bucket][0].append(input)
        caches[bucket][1].append(target)
        caches[bucket][2] += input_len
        caches[bucket][3] += target_len
        count += 1

        if caches[bucket][2] > tokens_per_batch:
            feat_batch, feat_batch_mask = self._create_feat_batch(caches[bucket][0])
            target_batch, target_batch_mask = self._create_target_batch(caches[bucket][1], self.dst2idx)
            yield feat_batch, target_batch
            caches[bucket] = [[], [], 0, 0]

    # Flush the remaining utterances in every bucket.
    for bucket in buckets:
        if len(caches[bucket][0]) > 0:
            logging.info('get_test_batches_with_buckets_and_target, len(caches[bucket][0])=' + str(len(caches[bucket][0])))
            feat_batch, feat_batch_mask = self._create_feat_batch(caches[bucket][0])
            target_batch, target_batch_mask = self._create_target_batch(caches[bucket][1], self.dst2idx)
            yield feat_batch, target_batch

    logging.info('get_test_batches_with_buckets_and_target, loaded count=' + str(count))

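# Usage sketch (illustrative only): `reader`, the file paths, and `model.eval_step`
# below are hypothetical names, not part of this module.
#
#   for feat_batch, target_batch in reader.get_test_batches_with_buckets_and_target(
#           'data/test/feats.scp', 'data/test/text', tokens_per_batch=10000):
#       loss = model.eval_step(feat_batch, target_batch)
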
def get_test_batches_with_buckets(self, src_path, tokens_per_batch):
    """Generate test batches of features grouped by input-length buckets."""
    buckets = [i for i in range(50, 10000, 10)]

    def select_bucket(sl):
        for l1 in buckets:
            if sl < l1:
                return l1
        raise Exception("The sequence is too long: ({})".format(sl))

    # Each cache holds: [feats, uttids, total src tokens].
    caches = {}
    for bucket in buckets:
        caches[bucket] = [[], [], 0]

    scp_reader = zark.ArkReader(src_path)
    count = 0
    while True:
        uttid, input, loop = scp_reader.read_next_utt()
        if loop:
            break
        input_len = len(input)
        bucket = select_bucket(input_len)
        caches[bucket][0].append(input)
        caches[bucket][1].append(uttid)
        caches[bucket][2] += input_len
        count += 1

        if caches[bucket][2] > tokens_per_batch:
            feat_batch, feat_batch_mask = self._create_feat_batch(caches[bucket][0])
            yield feat_batch, caches[bucket][1]
            caches[bucket] = [[], [], 0]

    # Clean remaining sentences.
    for bucket in buckets:
        if len(caches[bucket][0]) > 0:
            logging.info('get_test_batches_with_buckets, len(caches[bucket][0])=' + str(len(caches[bucket][0])))
            feat_batch, feat_batch_mask = self._create_feat_batch(caches[bucket][0])
            yield feat_batch, caches[bucket][1]

    logging.info('get_test_batches_with_buckets, loaded count=' + str(count))

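# Note on bucketing: select_bucket returns the smallest boundary strictly greater
# than the utterance length (e.g. with range(50, 10000, 10), a 173-frame utterance
# lands in bucket 180). A bucket is flushed as soon as its accumulated frame count
# exceeds tokens_per_batch, so each yielded batch contains utterances of similar
# length and the padding added by _create_feat_batch stays small.
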
def get_test_batches(self, src_path, batch_size):
    """Generate fixed-size test batches of features in reading order."""
    scp_reader = zark.ArkReader(src_path)
    cache = []
    uttids = []
    while True:
        uttid, feat, loop = scp_reader.read_next_utt()
        if loop:
            break
        cache.append(feat)
        uttids.append(uttid)
        if len(cache) >= batch_size:
            feat_batch, feat_batch_mask = self._create_feat_batch(cache)
            # yield feat_batch, feat_batch_mask, uttids
            yield feat_batch, uttids
            cache = []
            uttids = []

    if cache:
        feat_batch, feat_batch_mask = self._create_feat_batch(cache)
        # yield feat_batch, feat_batch_mask, uttids
        yield feat_batch, uttids

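# Usage sketch (illustrative only): fixed-size batching without buckets; the path,
# batch size, and `model.decode` are hypothetical names.
#
#   for feat_batch, uttids in reader.get_test_batches('data/test/feats.scp', batch_size=16):
#       for uttid, hyp in zip(uttids, model.decode(feat_batch)):
#           print(uttid, hyp)
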
def get_training_batches_with_buckets(self, shuffle=True):
    """Generate training batches according to the bucket setting."""
    # buckets = [(i, i) for i in range(5, 1000000, 3)]
    buckets = [(i, i) for i in range(self._config.bucket_min, self._config.bucket_max,
                                     self._config.bucket_step)]

    def select_bucket(sl, dl):
        for l1, l2 in buckets:
            if sl < l1 and dl < l2:
                return l1, l2
        raise Exception("The sequence is too long: ({}, {})".format(sl, dl))

    # Shuffle the training files.
    src_path = self._config.train.src_path
    dst_path = self._config.train.dst_path
    max_length = self._config.train.max_length
    if shuffle:
        logging.info('Shuffle files %s and %s.' % (src_path, dst_path))
        src_shuf_path, dst_shuf_path = self.shuffle([src_path, dst_path], self._config.model_dir)
        logging.info('Shuffled files %s and %s.' % (src_shuf_path, dst_shuf_path))
        self._tmps.add(src_shuf_path)
        self._tmps.add(dst_shuf_path)
    else:
        src_shuf_path = src_path
        dst_shuf_path = dst_path

    # Each cache holds: [src sentences, dst sentences, src tokens, dst tokens].
    caches = {}
    for bucket in buckets:
        caches[bucket] = [[], [], 0, 0]

    # Load the uttid -> target-token mapping from the (shuffled) transcript file.
    uttid_target_map = {}
    with codecs.open(dst_shuf_path, 'r', 'utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            splits = re.split(r'\s+', line)
            uttid = splits[0].strip()
            target = splits[1:]
            uttid_target_map[uttid] = target
    logging.info('loaded dst_shuf_path=' + str(dst_shuf_path) + ',size=' + str(len(uttid_target_map)))

    # Keep a pool of medium-length samples for padding out nearly empty buckets later.
    num_random_caches = 5000
    num_cache_max_length = 600
    num_cache_min_length = 100
    num_cache_target_min_length = 4
    random_caches = []

    count = 0
    scp_reader = zark.ArkReader(src_shuf_path)
    while True:
        uttid, input, looped = scp_reader.read_next_utt()
        if looped:
            break
        target = uttid_target_map.get(uttid)
        if target is None:
            logging.warning('uttid=' + str(uttid) + ',target is None')
            continue
        input_len = len(input)
        target_len = len(target)
        if input_len > max_length or target_len > max_length:
            logging.warning('uttid=' + str(uttid) + ',input_len=' + str(input_len) +
                            ' > max_length=' + str(max_length))
            continue
        count += 1
        if target_len == 0:
            continue

        bucket = select_bucket(input_len, target_len)
        caches[bucket][0].append(input)
        caches[bucket][1].append(target)
        caches[bucket][2] += input_len
        caches[bucket][3] += target_len

        if len(random_caches) < num_random_caches and num_cache_min_length <= input_len <= num_cache_max_length \
                and target_len >= num_cache_target_min_length:
            random_caches.append([input, target])

        if max(caches[bucket][2], caches[bucket][3]) > self._config.train.tokens_per_batch:
            feat_batch, feat_batch_mask = self._create_feat_batch(caches[bucket][0])
            target_batch, target_batch_mask = self._create_target_batch(caches[bucket][1], self.dst2idx)
            # yield (feat_batch, feat_batch_mask, target_batch, target_batch_mask)
            yield (feat_batch, target_batch, len(caches[bucket][0]))
            caches[bucket] = [[], [], 0, 0]

    # Clean remaining sentences. Ensure each device gets at least one sample.
    for bucket in buckets:
        if len(caches[bucket][0]) > 0:
            src_len = len(caches[bucket][0])
            if self._config.min_count_in_bucket is None:
                default_min_count_in_bucket = 20
                logging.info('min_count_in_bucket=' + str(self._config.min_count_in_bucket) +
                             ',use default_min_count_in_bucket=' + str(default_min_count_in_bucket))
                self._config.min_count_in_bucket = default_min_count_in_bucket
            left_count = self._config.min_count_in_bucket - src_len
            if left_count > 0 and random_caches:
                # Pad the bucket by sampling from the cached pool; bound the index
                # by the pool's actual size, which may be smaller than num_random_caches.
                for idx in range(left_count):
                    rand_idx = np.random.randint(0, len(random_caches))
                    input, target = random_caches[rand_idx]
                    caches[bucket][0].append(input)
                    caches[bucket][1].append(target)
                    caches[bucket][2] += len(input)
                    caches[bucket][3] += len(target)
            dst_len = len(caches[bucket][0])
            logging.info('get_training_batches_with_buckets, src_len=' + str(src_len) +
                         ',dst_len=' + str(dst_len) + ',bucket=' + str(bucket) +
                         ',max(caches[bucket][2], caches[bucket][3])=' +
                         str(max(caches[bucket][2], caches[bucket][3])))
            feat_batch, feat_batch_mask = self._create_feat_batch(caches[bucket][0])
            target_batch, target_batch_mask = self._create_target_batch(caches[bucket][1], self.dst2idx)
            # yield (feat_batch, feat_batch_mask, target_batch, target_batch_mask)
            yield (feat_batch, target_batch, len(caches[bucket][0]))

    logging.info('loaded count=' + str(count))

    # Remove shuffled files when the epoch is finished.
    if shuffle:
        os.remove(src_shuf_path)
        os.remove(dst_shuf_path)
        self._tmps.remove(src_shuf_path)
        self._tmps.remove(dst_shuf_path)

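# Usage sketch (illustrative only): one pass over the generator is one training
# epoch; the third element of each tuple is the number of utterances in the batch.
# `reader`, `model`, and the epoch count below are hypothetical names.
#
#   for epoch in range(10):
#       for feat_batch, target_batch, n_utts in reader.get_training_batches_with_buckets(shuffle=True):
#           loss = model.train_step(feat_batch, target_batch)
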