def generate_seq(self, seq_idx):
    """
    :param int seq_idx:
    :rtype: DatasetSeq
    """
    data = self.data[seq_idx]
    return DatasetSeq(
        seq_idx=seq_idx,
        features=data["data"],
        targets={target: data[target] for target in self.target_list})
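# All of the methods in this section build the same object. Judging purely from
# the call sites here, DatasetSeq takes seq_idx, features (an array, or a dict
# of arrays keyed by data key), an optional targets dict, and an optional
# seq_tag. A construction sketch; the import path below is an assumption:
import numpy
from returnn.datasets.basic import DatasetSeq  # assumed module path

features = numpy.zeros((100, 40), dtype="float32")  # (time, feature)
targets = {"classes": numpy.zeros((100,), dtype="int32")}  # per-frame labels
seq = DatasetSeq(seq_idx=0, seq_tag="seq-0", features=features, targets=targets)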
def _collect_single_seq(self, seq_idx): """ :param int seq_idx: sequence id :rtype: DatasetSeq """ if seq_idx >= len(self.seq_order): return None real_seq_index = self.seq_order[seq_idx] seq_name = self.all_seq_names[real_seq_index] curr_triplet = self.curr_epoch_triplets[seq_idx] targets = {} for id, sample in enumerate(curr_triplet): real_sample_seq_idx = sample sample_seq_name = self.all_seq_names[real_sample_seq_idx] sample_seq_file_index = self.file_indices[real_sample_seq_idx] norm_sample_seq_name = self._normalize_seq_name(sample_seq_name) for name, parsers in self.all_parsers.items(): targets['%s_%d' % (name, id)] = parsers[sample_seq_file_index].get_data( norm_sample_seq_name) features = targets['%s_%d' % (self.input_stream_name, 0)] return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_name, features=features, targets=targets)
def _collect_single_seq(self, seq_idx): """ :param int seq_idx: :rtype: DatasetSeq | None :returns DatasetSeq or None if seq_idx >= num_seqs. """ assert self.cur_seq_list is not None, "call init_seq_order" if seq_idx >= len(self.cur_seq_list): return None seq_tag = self.cur_seq_list[seq_idx] sub_seq_tags = seq_tag.split(self.seq_tag_delim) sub_seq_idxs = self.cur_sub_seq_idxs[seq_idx] assert len(sub_seq_tags) == len(sub_seq_idxs) features = {key: [] for key in self.get_data_keys()} if seq_idx == 0: # some extra check, but enough to do for first seq only sub_dataset_keys = self.sub_dataset.get_data_keys() for key in self.remove_in_between_postfix: assert key in sub_dataset_keys, "%s: remove_in_between_postfix key %r not in sub dataset data-keys %r" % ( self, key, sub_dataset_keys) for sub_seq_idx, sub_seq_tag in zip(sub_seq_idxs, sub_seq_tags): self.sub_dataset.load_seqs(sub_seq_idx, sub_seq_idx + 1) sub_dataset_tag = self.sub_dataset.get_tag(sub_seq_idx) assert sub_dataset_tag == sub_seq_tag, "%s: expected tag %r for sub seq idx %i but got %r, part of seq %i %r" % ( self, sub_seq_tag, sub_seq_idx, sub_dataset_tag, seq_idx, seq_tag) for key in self.get_data_keys(): data = self.sub_dataset.get_data(sub_seq_idx, key) if key in self.remove_in_between_postfix and sub_seq_idx != sub_seq_idxs[-1]: assert data.ndim == 1 and data[-1] == self.remove_in_between_postfix[key] data = data[:-1] features[key].append(data) features = {key: numpy.concatenate(values, axis=0) for (key, values) in features.items()} return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_tag, features=features)
def generate_seq(self, seq_idx):
    """
    :param int seq_idx:
    :rtype: DatasetSeq
    """
    seq_len = self.get_random_seq_len()
    input_seq = self.generate_input_seq(seq_len)
    output_seq = self.make_output_seq(input_seq)
    features = class_idx_seq_to_1_of_k(input_seq, num_classes=len(self._input_classes))
    targets = numpy.array(output_seq)
    return DatasetSeq(seq_idx=seq_idx, features=features, targets=targets)
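# class_idx_seq_to_1_of_k is used above but not defined in this section. A
# minimal sketch, assuming it maps a 1-D index sequence to a (time, num_classes)
# one-hot matrix (the standard "1-of-k" encoding the name suggests):
import numpy

def class_idx_seq_to_1_of_k(seq, num_classes):
    seq = numpy.asarray(seq, dtype="int32")
    one_hot = numpy.zeros((len(seq), num_classes), dtype="float32")
    one_hot[numpy.arange(len(seq)), seq] = 1.0  # set one entry per time step
    return one_hot

print(class_idx_seq_to_1_of_k([1, 0, 2], num_classes=3))
# [[0. 1. 0.]
#  [1. 0. 0.]
#  [0. 0. 1.]]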
def get_dataset_seq_for_name(self, name, seq_idx=-1): """ :param str name: :param int seq_idx: :rtype: DatasetSeq """ data = {key: d.read(name) for (key, d) in self.data.items()} # type: typing.Dict[str,numpy.ndarray] return DatasetSeq(seq_idx=seq_idx, seq_tag=name, features=data["data"], targets=data)
def generate_seq(self, seq_idx):
    """
    :param int seq_idx:
    :rtype: DatasetSeq
    """
    seq_len = self.seq_len
    i1 = seq_idx
    i2 = i1 + seq_len * self.num_inputs
    features = np.array(range(i1, i2)).reshape((seq_len, self.num_inputs))
    i1, i2 = i2, i2 + seq_len
    targets = np.array(range(i1, i2))
    return DatasetSeq(seq_idx=seq_idx, features=features, targets=targets)
def get_dataset_seq_for_name(self, name, seq_idx=-1):
    """
    :param str name:
    :param int seq_idx:
    :rtype: DatasetSeq
    """
    data = {key: d.read(name) for (key, d) in self.data.items()}  # type: dict[str,numpy.ndarray]
    return DatasetSeq(seq_idx=seq_idx, seq_tag=name, features=data["data"], targets=data)
def _collect_single_seq(self, seq_idx):
    """
    :param int seq_idx:
    :rtype: DatasetSeq
    """
    dataset_idx = self._get_dataset_for_seq_idx(seq_idx)
    dataset = self.datasets[dataset_idx]
    dataset_seq_idx = seq_idx + self.dataset_seq_idx_offsets[dataset_idx]
    seq_tag = dataset.get_tag(dataset_seq_idx)
    features = dataset.get_input_data(dataset_seq_idx)
    targets = {k: dataset.get_targets(k, dataset_seq_idx) for k in dataset.get_target_list()}
    return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_tag, features=features, targets=targets)
def generate_seq(self, seq_idx): """ :type seq_idx: int :rtype: DatasetSeq """ seq_len = self.get_random_seq_len() seq = [self.random.randint(0, self.nsymbols) for i in range(seq_len)] seq_np = numpy.array(seq, dtype="int8") return DatasetSeq(seq_idx=seq_idx, features=seq_np, targets={"classes": seq_np})
def _collect_single_seq(self, seq_idx): """ :type seq_idx: int :rtype: DatasetSeq """ seq_tag = self.seq_list_ordered[seq_idx] features = self._get_data(seq_idx, "data") targets = {target: self._get_data(seq_idx, target) for target in self.target_list} return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_tag, features=features, targets=targets)
def generate_seq(self, seq_idx):
    """
    :param int seq_idx:
    :rtype: DatasetSeq
    """
    seq_len = self.seq_len
    i1 = seq_idx
    i2 = i1 + seq_len * self.num_inputs
    features = numpy.array(
        [((i % self.input_max_value) + self.input_shift) * self.input_scale
         for i in range(i1, i2)]).reshape((seq_len, self.num_inputs))
    i1, i2 = i2, i2 + seq_len
    targets = numpy.array([i % self.num_outputs["classes"][0] for i in range(i1, i2)])
    return DatasetSeq(seq_idx=seq_idx, features=features, targets=targets)
def _collect_single_seq(self, seq_idx):
    """
    :param int seq_idx:
    :rtype: DatasetSeq
    """
    seq_name = self.get_tag(seq_idx)
    data = {key: self.dataset.get_data(seq_idx=seq_idx, key=key) for key in self.dataset.get_data_keys()}
    data["cluster_idx"] = numpy.array([self.cluster_map[seq_name]], dtype=self.cluster_idx_dtype)
    return DatasetSeq(seq_idx=seq_idx, features=data["data"], targets=data)
def _collect_single_seq(self, seq_idx):
    """
    :param int seq_idx:
    :rtype: DatasetSeq | None
    """
    if seq_idx >= self._num_seqs:
        return None
    line_nr = self._get_line_nr(seq_idx)
    features = self._get_data(key="data", line_nr=line_nr)
    targets = self._get_data(key="classes", line_nr=line_nr)
    assert features is not None and targets is not None
    return DatasetSeq(seq_idx=seq_idx, seq_tag="line-%i" % line_nr, features=features, targets=targets)
def _collect_single_seq(self, seq_idx): """Returns the sequence specified by the index seq_idx. Normalization is applied to the input features if mean and variance have been specified during dataset creating (see the constructor). :type seq_idx: int :rtype: DatasetSeq | None :returns: None if seq_idx >= num_seqs or the corresponding sequence. """ if self._seq_index_list is None: self.init_seq_order() if seq_idx >= len(self._seq_index_list): return None # map the seq_idx to the shuffled sequence indices shuf_seq_idx = self._seq_index_list[seq_idx] partition_offset = int( np.sum([ self._get_partition_size(i1) for i1 in range(self._current_partition) ])) shuf_seq_idx += partition_offset seqMapping = self._seqMap[shuf_seq_idx] fileIdx = seqMapping[0] datasetName = seqMapping[1] fileHandler = self._fileHandlers[fileIdx] inputFeatures = fileHandler['inputs'][datasetName][...] targets = None if 'outputs' in fileHandler: targets = fileHandler['outputs'][datasetName][...] # optional normalization if self._normData is not None: assert isinstance(self._normData, NormalizationData) if self._flag_normalizeInputs: inputFeatures = StereoHdfDataset._normalizeVector( inputFeatures, self._normData.inputMean, self._normData.inputVariance) if self._flag_normalizeTargets: targets = StereoHdfDataset._normalizeVector( targets, self._normData.outputMean, self._normData.outputVariance) # enforce float32 to enable Theano optimizations inputFeatures = inputFeatures.astype(np.float32) if (targets is not None) and targets.shape[1] > 1: targets = targets.astype(np.float32) elif targets.shape[1] == 1: targets = np.reshape(targets.astype(np.int32), (targets.shape[0], )) return DatasetSeq(seq_idx, inputFeatures, targets)
def producer_add_data(self, data, seq_tag=None):
    """
    :param numpy.ndarray data:
    :param str|None seq_tag:
    """
    with self.condition:
        if seq_tag is None:
            seq_tag = "seq-%i" % self.producer_seq_idx
        seq = DatasetSeq(features=data, seq_idx=self.producer_seq_idx, seq_tag=seq_tag)
        self.producer_seq_idx += 1
        self.producer_data.append(seq)
        self.condition.notify()
def _add_data(self, data, original_tag):
    """
    :type data: dict[str,numpy.ndarray]
    :type original_tag: str
    """
    features = data["data"]
    if not self.added_data:
        seq_idx = 0
        assert self.expected_load_seq_start == 0
    else:
        seq_idx = self.added_data[-1].seq_idx + 1
    tag = "%s.%i" % (original_tag, seq_idx)
    seq = DatasetSeq(seq_idx=seq_idx, features=features, targets=data, seq_tag=tag)
    self._num_timesteps_accumulated += seq.num_frames
    self.added_data += [seq]
def _collect_single_seq_from_buffer(self, wavFileId, seq_idx):
    """Returns the sequence specified by the index seq_idx.

    :type wavFileId: int
    :type seq_idx: int
    :rtype: DatasetSeq | None
    :returns: DatasetSeq, or None if seq_idx >= num_seqs.
    """
    inputFeatures = self._getInputFeatures(wavFileId)
    outputFeatures = self._getOutputFeatures(wavFileId)
    inputFeatures = inputFeatures.astype(np.float32)
    if outputFeatures is not None:
        outputFeatures = outputFeatures.astype(np.float32)
    return DatasetSeq(seq_idx, inputFeatures, outputFeatures)
def _collect_single_seq(self, seq_idx): """ :type seq_idx: int :rtype: DatasetSeq """ if not self.is_less_than_num_seqs(seq_idx): return None dataset_idx, dataset_seq_idx = self.dataset_seq_idxs[seq_idx] dataset_key = self.dataset_idxs[dataset_idx] dataset = self.datasets[dataset_key] seq_tag = dataset.get_tag(dataset_seq_idx) features = self._get_data(dataset_key, dataset_seq_idx, "data") targets = {target: self._get_data(dataset_key, dataset_seq_idx, target) for target in self.target_list} return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_tag, features=features, targets=targets)
def _collect_single_seq(self, seq_idx): """Returns the sequence specified by the index seq_idx. Normalization is applied to the input features if mean and variance have been specified during dataset creating (see the constructor). :type seq_idx: int :rtype: DatasetSeq | None :returns: None if seq_idx >= num_seqs or the corresponding sequence. """ if seq_idx >= self.num_seqs: return None seqMapping = self._seqMap[seq_idx] fileIdx = seqMapping[0] datasetName = seqMapping[1] fileHandler = self._fileHandlers[fileIdx] inputFeatures = fileHandler['inputs'][datasetName][...] targets = None if 'outputs' in fileHandler: targets = fileHandler['outputs'][datasetName][...] # optional normalization if self._normData is not None: assert isinstance(self._normData, NormalizationData) # inputs if self._flag_normalizeInputs: inputFeatures = StereoHdfDataset._normalizeVector( inputFeatures, self._normData.inputMean, self._normData.inputVariance, ) # outputs if self._flag_normalizeTargets: targets = StereoHdfDataset._normalizeVector( targets, self._normData.outputMean, self._normData.outputVariance, ) # enforce float32 to enable Theano optimizations inputFeatures = inputFeatures.astype(np.float32) if targets is not None: targets = targets.astype(np.float32) return DatasetSeq(seq_idx, inputFeatures, targets)
def _collect_single_seq(self, seq_idx): """ :type seq_idx: int :rtype: DatasetSeq """ if seq_idx >= len(self.seq_order): return None real_seq_index = self.seq_order[seq_idx] file_index = self.file_indices[real_seq_index] seq_name = self.all_seq_names[real_seq_index] norm_seq_name = self._normalize_seq_name(seq_name) targets = {name: parsers[file_index].get_data(norm_seq_name) for name, parsers in self.all_parsers.items()} features = targets[self.input_stream_name] return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_name, features=features, targets=targets)
def _collect_single_seq(self, seq_idx): """this method implements stacking the features :type seq_idx: int :param seq_idx: index of a sequence :rtype: DatasetSeq :return: DatasetSeq """ if seq_idx >= self.num_seqs: return None originalSeq = super(DatasetWithTimeContext, self)._collect_single_seq(seq_idx) inputFeatures = originalSeq.get_data('data') frames, bins = inputFeatures.shape leftContext = deque() rightContext = deque() inFeatWithContext = [] for i in range(self._tau): leftContext.append(np.zeros(bins)) if i + 1 < frames: rightContext.append(inputFeatures[i + 1, ...]) else: rightContext.append(np.zeros(bins)) for t in range(frames): f = inputFeatures[t, ...] newFeature = np.concatenate([ np.concatenate(leftContext, axis=0), f, np.concatenate(rightContext, axis=0) ], axis=0) inFeatWithContext.append(newFeature) leftContext.popleft() leftContext.append(f) rightContext.popleft() if t + 1 + self._tau < frames: rightContext.append(inputFeatures[t + 1 + self._tau, ...]) else: rightContext.append(np.zeros(bins)) inputFeatures = np.array(inFeatWithContext) targets = None if 'classes' in originalSeq.get_data_keys(): targets = originalSeq.get_data('classes') return DatasetSeq(seq_idx, inputFeatures, targets)
def _collect_single_seq(self, seq_idx): """ :param int seq_idx: :rtype: DatasetSeq | None :returns DatasetSeq or None if seq_idx >= num_seqs. """ assert self.cur_seq_list is not None, "call init_seq_order" if seq_idx >= len(self.cur_seq_list): return None seq_tag = self.cur_seq_list[seq_idx] sub_seq_idxs = self.cur_sub_seq_idxs[seq_idx] features = {key: [] for key in self.get_data_keys()} for sub_seq_idx in sub_seq_idxs: self.sub_dataset.load_seqs(sub_seq_idx, sub_seq_idx + 1) for key in self.get_data_keys(): data = self.sub_dataset.get_data(seq_idx, key) features[key].append(data) features = {key: numpy.concatenate(values, axis=0) for (key, values) in features.items()} return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_tag, features=features)
def _add_cache_seq(self, seq_idx, features, targets):
    """
    :param int seq_idx:
    :param numpy.ndarray features:
    :param targets:
    """
    last_seq_idx = self._get_cache_last_seq_idx()
    assert seq_idx == last_seq_idx + 1
    self.cached_seqs += [DatasetSeq(seq_idx, features, targets)]
def addNewData(self, features, targets=None, segmentName=None):
    """
    Adds a new seq. This is called via the Sprint main thread.

    :param numpy.ndarray features: format (input-feature,time) (via Sprint)
    :param dict[str,numpy.ndarray|str] targets: format (time) (idx of output-feature)
    :returns: the sorted seq index
    :rtype: int
    """
    # is in format (feature,time)
    assert self.num_inputs == features.shape[0]
    T = features.shape[1]
    # must be in format: (time,feature)
    features = features.transpose()
    assert features.shape == (T, self.num_inputs)

    if targets is None:
        targets = {}
    if not isinstance(targets, dict):
        targets = {"classes": targets}
    if "classes" in targets:
        # 'classes' is always the alignment, in format (time,)
        assert targets["classes"].shape == (T,), (
            "Number of targets %s does not equal number of features %s" % (targets["classes"].shape, (T,)))

    # Maybe convert some targets.
    if self.target_maps:
        for key, tmap in self.target_maps.items():
            assert key in targets
            v = tmap[targets[key]]
            v = numpy.asarray(v)
            if v.ndim == 0:
                v = numpy.zeros((T,), dtype=v.dtype) + v  # add time dimension
            targets[key] = v

    # Maybe remove some targets.
    for key in self._target_black_list:
        if key in targets:
            del targets[key]

    # Check if all targets are valid.
    for key, v in sorted(list(targets.items())):
        if isinstance(v, numpy.ndarray):
            continue  # ok
        if isinstance(v, unicode):
            v = v.encode("utf8")
        if isinstance(v, (str, bytes)):
            v = map(ord, v)
            v = numpy.array(v, dtype="uint8")
            targets[key] = v
            continue
        print >> log.v3, "SprintDataset, we will ignore the target %r because it is not a numpy array: %r" % (key, v)
        self._target_black_list += [key]
        del targets[key]

    with self.lock:
        # This gets called in the Sprint main thread.
        # If this is used together with SprintInterface.getSegmentList(), we are always in a state where
        # we just yielded a segment name, thus we are always in a Sprint epoch and thus ready for data.
        assert self.ready_for_data
        assert not self.reached_final_seq
        assert not self.sprintFinalized

        seq_idx = self.next_seq_to_be_added

        if self.predefined_seq_list_order:
            # Note: Only in ExternSprintDataset, we can reliably set the seq order for now.
            assert self.predefined_seq_list_order[seq_idx] == segmentName, "seq-order not as expected"

        self.next_seq_to_be_added += 1
        self._num_timesteps += T
        self.cond.notify_all()

        if seq_idx > self.requested_load_seq_end - 1 + self.SprintCachedSeqsMax:
            print >> log.v5, "SprintDataset addNewData: seq=%i, len=%i. Cache filled, waiting to get loaded..." % (seq_idx, T)
            while seq_idx > self.requested_load_seq_end - 1 + self.SprintCachedSeqsMin:
                assert not self.reached_final_seq
                assert seq_idx + 1 == self.next_seq_to_be_added
                self.cond.wait()

        self.added_data += [DatasetSeq(seq_idx, features, targets, seq_tag=segmentName)]
        self.cond.notify_all()
        return seq_idx
def _collect_single_seq(self, seq_idx): """ :type seq_idx: int :rtype: DatasetSeq | None :returns DatasetSeq or None if seq_idx >= num_seqs. """ while True: if self.next_orth_idx >= len(self.orths_epoch): assert self.next_seq_idx <= seq_idx, "We expect that we iterate through all seqs." if self.num_skipped > 0: print("LmDataset: reached end, skipped %i sequences" % self.num_skipped) return None assert self.next_seq_idx == seq_idx, "We expect that we iterate through all seqs." orth = self.orths_epoch[self.seq_order[self.next_orth_idx]] self.next_orth_idx += 1 if orth == "</s>": continue # special sentence end symbol. empty seq, ignore. if self.seq_gen: try: phones = self.seq_gen.generate_seq(orth) except KeyError as e: if self.log_skipped_seqs: print( "LmDataset: skipping sequence %r because of missing lexicon entry: %s" % (orth, e), file=log.v4) self._reduce_log_skipped_seqs() if self.error_on_invalid_seq: raise Exception( "LmDataset: invalid seq %r, missing lexicon entry %r" % (orth, e)) self.num_skipped += 1 continue # try another seq data = self.seq_gen.seq_to_class_idxs(phones, dtype=self.dtype) elif self.orth_symbols: orth_syms = parse_orthography(orth, **self.parse_orth_opts) while True: orth_syms = sum( [self.orth_replace_map.get(s, [s]) for s in orth_syms], []) i = 0 while i < len(orth_syms) - 1: if orth_syms[i:i + 2] == [" ", " "]: orth_syms[i:i + 2] = [" "] # collapse two spaces else: i += 1 if self.auto_replace_unknown_symbol: try: map(self.orth_symbols_map.__getitem__, orth_syms) except KeyError as e: orth_sym = e.message if self.log_auto_replace_unknown_symbols: print( "LmDataset: unknown orth symbol %r, adding to orth_replace_map as %r" % (orth_sym, self.unknown_symbol), file=log.v3) self._reduce_log_auto_replace_unknown_symbols() self.orth_replace_map[orth_sym] = [ self.unknown_symbol ] if self.unknown_symbol is not None else [] continue # try this seq again with updated orth_replace_map break self.num_unknown += orth_syms.count(self.unknown_symbol) if self.word_based: orth_debug_str = repr(orth_syms) else: orth_debug_str = repr("".join(orth_syms)) try: data = numpy.array(map(self.orth_symbols_map.__getitem__, orth_syms), dtype=self.dtype) except KeyError as e: if self.log_skipped_seqs: print( "LmDataset: skipping sequence %s because of missing orth symbol: %s" % (orth_debug_str, e), file=log.v4) self._reduce_log_skipped_seqs() if self.error_on_invalid_seq: raise Exception( "LmDataset: invalid seq %s, missing orth symbol %s" % (orth_debug_str, e)) self.num_skipped += 1 continue # try another seq else: assert False targets = {} for i in range(self.add_random_phone_seqs): assert self.seq_gen # not implemented atm for orths phones = self.seq_gen.generate_garbage_seq( target_len=data.shape[0]) targets["random%i" % i] = self.seq_gen.seq_to_class_idxs( phones, dtype=self.dtype) if self.add_delayed_seq_data: targets["delayed"] = numpy.concatenate(([ self.orth_symbols_map[self.delayed_seq_data_start_symbol] ], data[:-1])).astype(self.dtype) assert targets["delayed"].shape == data.shape self.next_seq_idx = seq_idx + 1 return DatasetSeq(seq_idx=seq_idx, features=data, targets=targets)
def add_new_data(self, features, targets=None, segment_name=None):
    """
    Adds a new seq. This is called via the Sprint main thread.

    :param numpy.ndarray features: format (input-feature,time) (via Sprint)
    :param dict[str,numpy.ndarray|str] targets: format (time) (idx of output-feature)
    :param str|None segment_name:
    :returns: the sorted seq index
    :rtype: int
    """
    # is in format (feature,time)
    assert self.num_inputs == features.shape[0]
    num_frames = features.shape[1]
    # must be in format: (time,feature)
    features = features.transpose()
    assert features.shape == (num_frames, self.num_inputs)
    if self.input_stddev != 1:
        features /= self.input_stddev
    if self.window > 1:
        features = self.sliding_window(features)
        assert features.shape == (num_frames, self.num_inputs * self.window)

    if targets is None:
        targets = {}
    if not isinstance(targets, dict):
        targets = {"classes": targets}
    if "classes" in targets:
        # 'classes' is always the alignment, in format (time,)
        assert targets["classes"].shape == (num_frames,), (
            "Number of targets %s does not equal number of features %s" % (targets["classes"].shape, (num_frames,)))
    if "orth" in targets:
        targets["orth"] = targets["orth"].decode("utf8").strip()
    if "orth" in targets and self.orth_post_process:
        targets["orth"] = self.orth_post_process(targets["orth"])
    if self.bpe:
        assert "orth" in targets
        orth = targets["orth"]
        assert isinstance(orth, (str, unicode))
        assert "bpe" not in targets
        targets["bpe"] = numpy.array(self.bpe.get_seq(orth), dtype="int32")
    if self.orth_vocab:
        assert "orth" in targets
        orth = targets["orth"]
        assert isinstance(orth, (str, unicode))
        assert "orth_classes" not in targets
        targets["orth_classes"] = numpy.array(self.orth_vocab.get_seq(orth), dtype="int32")

    # Maybe convert some targets.
    if self.target_maps:
        for key, target_map in self.target_maps.items():
            assert key in targets
            v = target_map[targets[key]]
            v = numpy.asarray(v)
            if v.ndim == 0:
                v = numpy.zeros((num_frames,), dtype=v.dtype) + v  # add time dimension
            targets[key] = v

    # Maybe remove some targets.
    for key in self._target_black_list:
        if key in targets:
            del targets[key]

    # Check if all targets are valid.
    for key, v in sorted(list(targets.items())):
        if isinstance(v, numpy.ndarray):
            continue  # ok
        if isinstance(v, unicode):
            v = v.encode("utf8")
        if isinstance(v, (str, bytes)):
            if PY3:
                assert isinstance(v, bytes)
                v = list(v)
            else:
                v = list(map(ord, v))
            v = numpy.array(v, dtype="uint8")
            targets[key] = v
            if self.str_add_final_zero:
                v = numpy.append(v, numpy.array([0], dtype=v.dtype))
                assert key + "0" not in targets
                targets[key + "0"] = v
            continue
        print("%s, we will ignore the target %r because it is not a numpy array: %r" % (self, key, v), file=log.v3)
        self._target_black_list += [key]
        del targets[key]

    with self.lock:
        # This gets called in the Sprint main thread.
        # If this is used together with SprintInterface.getSegmentList(), we are always in a state where
        # we just yielded a segment name, thus we are always in a Sprint epoch and thus ready for data.
        assert self.ready_for_data
        assert not self.reached_final_seq
        assert not self.sprintFinalized

        seq_idx = self.next_seq_to_be_added

        if self.predefined_seq_list_order:
            # Note: Only in ExternSprintDataset, we can reliably set the seq order for now.
            assert seq_idx < len(self.predefined_seq_list_order), "seq_idx %i, expected predef num seqs %i" % (
                seq_idx, len(self.predefined_seq_list_order))
            expected_seq_name = self.predefined_seq_list_order[seq_idx]
            if expected_seq_name != segment_name:
                if segment_name in self.predefined_seq_list_order:
                    raise Exception("seq_idx %i expected to be tag %r but got tag %r; tag %r is at idx %i" % (
                        seq_idx, expected_seq_name, segment_name, segment_name,
                        self.predefined_seq_list_order.index(segment_name)))
                raise Exception("seq_idx %i expected to be tag %r but got tag %r; tag %r not found" % (
                    seq_idx, expected_seq_name, segment_name, segment_name))

        self.next_seq_to_be_added += 1
        self._num_timesteps += num_frames
        self.cond.notify_all()

        if seq_idx > self.requested_load_seq_end - 1 + self.SprintCachedSeqsMax:
            print("%s add_new_data: seq=%i, len=%i. Cache filled, waiting to get loaded..." % (
                self, seq_idx, num_frames), file=log.v5)
            while seq_idx > self.requested_load_seq_end - 1 + self.SprintCachedSeqsMin:
                assert not self.reached_final_seq
                assert seq_idx + 1 == self.next_seq_to_be_added
                self.cond.wait()

        self.added_data += [DatasetSeq(seq_idx, features, targets, seq_tag=segment_name)]
        self.cond.notify_all()
        return seq_idx
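# The target_maps conversion above relies on numpy fancy indexing: indexing a
# lookup array with a per-frame integer sequence remaps every frame at once,
# and a 0-dim result gets broadcast over the time axis. Map contents are made up:
import numpy

target_map = numpy.array([10, 11, 12, 13])  # hypothetical mapping: old idx -> new idx
classes = numpy.array([0, 2, 2, 1])         # per-frame target indices
print(target_map[classes])  # [10 12 12 11]

v = numpy.asarray(target_map[0])  # a 0-dim result ...
if v.ndim == 0:
    v = numpy.zeros((4,), dtype=v.dtype) + v  # ... gets a time dimension added
print(v)  # [10 10 10 10]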