def iterate_seqs(self, chunk_size=None, chunk_step=None, used_data_keys=None):
    """
    Iterates over all seqs of the dataset, taking chunking into consideration.

    :param int chunk_size: defaults to self.chunk_size; 0 disables chunking
    :param int chunk_step: defaults to self.chunk_step
    :param set(str)|None used_data_keys: if given, restrict the lengths to these data-keys
    :return: generator which yields tuples (seq index, seq start, seq end)
    :rtype: list[(int,NumbersDict,NumbersDict)]
    """
    if chunk_size is None:
        chunk_size = self.chunk_size
    if chunk_step is None:
        chunk_step = self.chunk_step
    seq_idx = 0
    while self.is_less_than_num_seqs(seq_idx):
        seq_len = self.get_seq_length(seq_idx)
        if chunk_size == 0:
            # No chunking: the whole seq is one chunk.
            yield (seq_idx, seq_len.constant_like(0), seq_len)
        else:
            if used_data_keys is not None:
                seq_len = NumbersDict({key: seq_len[key] for key in used_data_keys})
            default_key = "data"
            # There are usually the 'data' (input) and 'classes' (targets) data-keys in `seq_len`
            # but there can be others. We expect them all of the same length so that we can do chunking.
            # In case that some length is 0 or 1, we treat it special and always return
            # the full seq repeated for every chunk.
            full_seq_keys = []
            for key in seq_len.keys():
                if seq_len[key] == seq_len[default_key]:
                    continue  # ok
                if seq_len[key] <= 1:
                    full_seq_keys.append(key)
                    continue
                raise Exception("Chunking with multiple data-keys of different length: %r" % seq_len)
            offset = seq_len.constant_like(0)
            while seq_len[default_key] > offset[default_key]:
                start = NumbersDict(offset)
                end = NumbersDict.min([offset + chunk_size, seq_len])
                for key in full_seq_keys:
                    # Special-case keys (length 0/1): always cover the full seq.
                    start[key] = 0
                    end[key] = seq_len[key]
                if seq_len.value is None:
                    start.value = None
                    end.value = None
                yield (seq_idx, start, end)
                offset += chunk_step
                # Avoid emitting a trailing chunk shorter than min_chunk_size.
                if seq_len[default_key] - offset[default_key] <= self.min_chunk_size:
                    break
        seq_idx += 1
def _generate_batches(self, recurrent_net, batch_size, max_seqs=-1, seq_drop=0.0, max_seq_length=sys.maxsize, used_data_keys=None):
    """
    Generates batches from the seqs/chunks yielded by self._iterate_seqs().

    :param bool recurrent_net: If True, the batch might have a batch seq dimension > 1.
      Otherwise, the batch seq dimension is always 1 and multiple seqs will be concatenated.
    :param int batch_size: Max number of frames in one batch. 0 means no limit.
    :param int max_seqs: Max number of seqs per batch. -1 means no limit.
    :param float seq_drop: probability to randomly skip a seq (recurrent case only)
    :param int max_seq_length: skip seqs longer than this; a negative value limits
      only the 'classes' length (by its absolute value)
    :param set(str)|None used_data_keys:
    :return: generator yielding batches
    """
    if batch_size == 0:
        batch_size = sys.maxsize  # 0 -> effectively unlimited
    assert batch_size > 0
    if max_seqs == -1:
        max_seqs = float('inf')  # -1 -> effectively unlimited
    assert max_seqs > 0
    assert seq_drop <= 1.0
    chunk_size = self.chunk_size
    chunk_step = self.chunk_step
    if not recurrent_net:
        # Chunking only makes sense with a batch seq dimension (recurrent case).
        if chunk_size != 0:
            print("Non-recurrent network, chunk size %i:%i ignored" % (chunk_size, chunk_step), file=log.v4)
            chunk_size = 0
    batch = Batch()
    for seq_idx, t_start, t_end in self._iterate_seqs(chunk_size=chunk_size, chunk_step=chunk_step, used_data_keys=used_data_keys):
        if recurrent_net:
            length = t_end - t_start  # NumbersDict: number of frames per data-key
            # Negative max_seq_length is interpreted as a limit on 'classes' only.
            if max_seq_length < 0 and length['classes'] > -max_seq_length:
                continue
            elif max_seq_length > 0 and length.max_value() > max_seq_length:
                continue
            if length.max_value() > batch_size:
                print("warning: sequence length (%i) larger than limit (%i)" % (length.max_value(), batch_size), file=log.v4)
            if self.rnd_seq_drop.random() < seq_drop:
                continue  # randomly drop this seq
            # Would adding this seq as a new slice exceed the limits?
            # If so, yield the current batch first and start a new one.
            dt, ds = batch.try_sequence_as_slice(length)
            if ds > 1 and ((dt * ds).max_value() > batch_size or ds > max_seqs):
                yield batch
                batch = Batch()
            batch.add_sequence_as_slice(seq_idx=seq_idx, seq_start_frame=t_start, length=length)
        else:  # Not recurrent.
            # Concatenate frames of consecutive seqs into a single slice;
            # a seq may be split across multiple batches.
            while t_start.max_value() < t_end.max_value():
                length = t_end - t_start
                # Take as many frames as fit into the remaining batch capacity.
                num_frames = NumbersDict.min([length, batch_size - batch.get_all_slices_num_frames()])
                assert num_frames.max_value() > 0
                batch.add_frames(seq_idx=seq_idx, seq_start_frame=t_start, length=num_frames)
                if batch.get_all_slices_num_frames() >= batch_size or batch.get_num_seqs() > max_seqs:
                    yield batch
                    batch = Batch()
                t_start += num_frames
    if batch.get_all_slices_num_frames() > 0:
        yield batch  # final partial batch
def _generate_batches(self, recurrent_net, batch_size, max_seqs=-1, max_seq_length=sys.maxsize, min_seq_length=0, seq_drop=0.0, max_total_num_seqs=-1, used_data_keys=None):
    """
    Generates batches from the seqs/chunks yielded by self.iterate_seqs().

    :param bool recurrent_net: If True, the batch might have a batch seq dimension > 1.
      Otherwise, the batch seq dimension is always 1 and multiple seqs will be concatenated.
    :param int batch_size: Max number of frames in one batch. 0 means no limit.
    :param int max_seqs: Max number of seqs per batch. -1 means no limit.
    :param int max_total_num_seqs: stop after this many distinct seqs; <=0 means no limit
    :param int|dict[str,int]|NumbersDict max_seq_length: skip seqs with any length above this;
      a negative int limits only the 'classes' length (by its absolute value)
    :param int|dict[str,int]|NumbersDict min_seq_length: skip seqs with any length below this
    :param float seq_drop: probability to randomly skip a seq (recurrent case only)
    :param set(str)|None used_data_keys:
    :return: generator yielding batches
    """
    if batch_size == 0:
        batch_size = sys.maxsize  # 0 -> effectively unlimited
    assert batch_size > 0
    if max_seqs == -1:
        max_seqs = float('inf')  # -1 -> effectively unlimited
    if not max_seq_length:
        max_seq_length = sys.maxsize
    if isinstance(max_seq_length, int) and max_seq_length < 0:
        # Negative int: limit applies to the 'classes' data-key only.
        max_seq_length = {"classes": -max_seq_length}
    max_seq_length = NumbersDict(max_seq_length)
    min_seq_length = NumbersDict(min_seq_length)
    assert max_seqs > 0
    assert seq_drop <= 1.0
    if not max_total_num_seqs or max_total_num_seqs < 0:
        max_total_num_seqs = float("inf")
    chunk_size = self.chunk_size
    chunk_step = self.chunk_step
    if not recurrent_net:
        # Chunking only makes sense with a batch seq dimension (recurrent case).
        if chunk_size != 0:
            print("Non-recurrent network, chunk size %s:%s ignored" % (chunk_size, chunk_step), file=log.v4)
            chunk_size = 0
    batch = Batch()
    ctx_lr = self._get_context_window_left_right()  # (left, right) context frames, or falsy
    total_num_seqs = 0
    last_seq_idx = -1
    for seq_idx, t_start, t_end in self.iterate_seqs(chunk_size=chunk_size, chunk_step=chunk_step, used_data_keys=used_data_keys):
        # Only stop at a seq boundary, so we never cut a seq's chunks in half.
        if seq_idx != last_seq_idx and total_num_seqs >= max_total_num_seqs:
            break
        if ctx_lr:
            # Extend the chunk by the context window on both sides.
            t_start -= ctx_lr[0]
            t_end += ctx_lr[1]
        if recurrent_net:
            length = t_end - t_start  # NumbersDict: number of frames per data-key
            if length.any_compare(max_seq_length, (lambda a, b: a > b)):
                continue
            if length.any_compare(min_seq_length, (lambda a, b: a < b)):
                continue
            if length.max_value() > batch_size:
                print("warning: sequence length (%i) larger than limit (%i)" % (length.max_value(), batch_size), file=log.v4)
            if self.rnd_seq_drop.random() < seq_drop:
                continue  # randomly drop this seq
            # Would adding this seq as a new slice exceed the limits?
            # If so, yield the current batch first and start a new one.
            dt, ds = batch.try_sequence_as_slice(length)
            if ds > 1 and ((dt * ds).max_value() > batch_size or ds > max_seqs):
                yield batch
                batch = Batch()
            batch.add_sequence_as_slice(seq_idx=seq_idx, seq_start_frame=t_start, length=length)
        else:  # Not recurrent.
            # Concatenate frames of consecutive seqs into a single slice;
            # a seq may be split across multiple batches.
            while t_start.max_value() < t_end.max_value():
                length = t_end - t_start
                # Take as many frames as fit into the remaining batch capacity.
                num_frames = NumbersDict.min([length, batch_size - batch.get_all_slices_num_frames()])
                assert num_frames.max_value() > 0
                batch.add_frames(seq_idx=seq_idx, seq_start_frame=t_start, length=num_frames)
                if batch.get_all_slices_num_frames() >= batch_size or batch.get_num_seqs() > max_seqs:
                    yield batch
                    batch = Batch()
                t_start += num_frames
        if seq_idx != last_seq_idx:
            # Count each distinct seq once (a seq can yield multiple chunks).
            last_seq_idx = seq_idx
            total_num_seqs += 1
    if batch.get_all_slices_num_frames() > 0:
        yield batch  # final partial batch
def iterate_seqs(self, chunk_size=None, chunk_step=None, used_data_keys=None):
    """
    Iterates over all seqs of the dataset, taking chunking into consideration.
    Chunk size/step can differ per data-key (NumbersDict).

    :param int|NumbersDict chunk_size: defaults to self.chunk_size; 0 disables chunking
    :param int|NumbersDict chunk_step: defaults to self.chunk_step
    :param set(str)|None used_data_keys: if given, restrict the lengths to these data-keys
    :return: generator which yields tuples (seq index, seq start, seq end)
    :rtype: list[(int,NumbersDict,NumbersDict)]
    """
    if chunk_size is None:
        chunk_size = self.chunk_size
    if chunk_step is None:
        chunk_step = self.chunk_step
    # Normalize to per-data-key dicts.
    chunk_size = NumbersDict(chunk_size)
    chunk_step = NumbersDict(chunk_step)
    s = 0
    while self.is_less_than_num_seqs(s):
        length = self.get_seq_length(s)
        if chunk_size == 0:
            # No chunking: the whole seq is one chunk.
            yield (s, NumbersDict.constant_like(0, numbers_dict=length), length)
        else:
            # default_key drives the chunking positions below.
            default_key = "data"
            if used_data_keys is not None:
                length = NumbersDict({k: length[k] for k in used_data_keys})
                if default_key not in used_data_keys:
                    default_key = sorted(used_data_keys)[0]
                    if chunk_step[default_key] == 0:  # allow some keys with zero chunk-step
                        assert chunk_step.max_value() > 0
                        # Pick the first key (in sorted order) with a positive chunk-step.
                        default_key = [key for key in sorted(used_data_keys) if chunk_step[key] > 0][0]
            assert chunk_step[default_key] > 0
            t = NumbersDict.constant_like(0, numbers_dict=length)
            # There are usually the 'data' (input) and 'classes' (targets) data-keys in `length` but there can be others.
            # We expect them all of the same length so that we can do chunking.
            # In case that some length is 0 or 1,
            # we treat it special and always return the full seq repeated for every chunk.
            keys_with_full_seqs = []
            for key in length.keys():
                if chunk_step[key] == chunk_step[default_key]:
                    if length[key] == length[default_key]:
                        continue  # ok
                    if length[key] <= 1:  # special case as explained above
                        keys_with_full_seqs.append(key)
                        continue
                if chunk_step[key] == chunk_step[default_key]:
                    # Same chunk-step but different (non-trivial) length: cannot chunk consistently.
                    raise Exception("Chunking with multiple data-keys of different length: %r" % length)
                else:
                    # Different chunk-step: both keys must still yield the same number of full chunks.
                    nr_of_full_chunks_key = (length[key] - chunk_size[key]) // chunk_step[key] + 1
                    nr_of_full_chunks_default_key = (
                        (length[default_key] - chunk_size[default_key]) // chunk_step[default_key] + 1)
                    assert nr_of_full_chunks_key == nr_of_full_chunks_default_key
            while length[default_key] > t[default_key]:
                chunk_start = NumbersDict(t)
                chunk_end = NumbersDict.min([t + chunk_size, length])
                for key in keys_with_full_seqs:
                    # Special-case keys (length 0/1): always cover the full seq.
                    chunk_start[key] = 0
                    chunk_end[key] = length[key]
                if length.value is None:
                    chunk_start.value = None
                    chunk_end.value = None
                yield (s, chunk_start, chunk_end)
                t += chunk_step
                # Avoid emitting a trailing chunk shorter than min_chunk_size.
                if length[default_key] - t[default_key] <= self.min_chunk_size:
                    break
        s += 1
def _generate_batches(self, recurrent_net, batch_size, max_seqs=-1, max_seq_length=sys.maxsize, min_seq_length=0, pruning=0.0, seq_drop=0.0, max_total_num_seqs=-1, used_data_keys=None):
    """
    Generates batches from the seqs/chunks yielded by self.iterate_seqs(),
    optionally pruning/sampling seqs via self.sample().

    :param bool recurrent_net: If True, the batch might have a batch seq dimension > 1.
      Otherwise, the batch seq dimension is always 1 and multiple seqs will be concatenated.
    :param int|dict[str,int]|NumbersDict batch_size: Max number of frames in one batch. 0 means no limit.
    :param int max_seqs: Max number of seqs per batch. -1 means no limit.
    :param int|dict[str,int]|NumbersDict max_seq_length: skip seqs with any length above this;
      a negative int limits only the 'classes' length (by its absolute value)
    :param int|dict[str,int]|NumbersDict min_seq_length: skip seqs with any length below this
    :param float pruning: sampling-based pruning factor applied via self.weights / self.sample()
    :param float seq_drop: probability to randomly skip a seq (recurrent case only)
    :param int max_total_num_seqs: stop after this many distinct seqs; <=0 means no limit
    :param set(str)|None used_data_keys:
    :return: generator yielding batches
    """
    if not batch_size:
        batch_size = sys.maxsize  # falsy -> effectively unlimited
    batch_size = NumbersDict(batch_size)
    assert not batch_size.any_compare(NumbersDict(0), (lambda a, b: a <= b))
    if max_seqs == -1:
        max_seqs = float('inf')  # -1 -> effectively unlimited
    if not max_seq_length:
        max_seq_length = sys.maxsize
    if isinstance(max_seq_length, int) and max_seq_length < 0:
        # Negative int: limit applies to the 'classes' data-key only.
        max_seq_length = {"classes": -max_seq_length}
    max_seq_length = NumbersDict(max_seq_length)
    min_seq_length = NumbersDict(min_seq_length)
    assert max_seqs > 0
    assert seq_drop <= 1.0
    if not max_total_num_seqs or max_total_num_seqs < 0:
        max_total_num_seqs = float("inf")
    chunk_size = self.chunk_size
    chunk_step = self.chunk_step
    if not recurrent_net:
        # Chunking only makes sense with a batch seq dimension (recurrent case).
        if chunk_size != 0:
            print("Non-recurrent network, chunk size %s:%s ignored" % (chunk_size, chunk_step), file=log.v4)
            chunk_size = 0
    batch = Batch()
    ctx_lr = self._get_context_window_left_right()  # (left, right) context frames, or falsy
    total_num_seqs = 0
    last_seq_idx = -1
    # Update per-seq sampling weights/thresholds used by self.sample() below.
    avg_weight = sum([v[0] for v in self.weights.values()]) / (len(self.weights.keys()) or 1)
    for idx in self.weights:
        self.weights[idx][1] = random() * avg_weight * pruning
        self.weights[idx][0] *= (1. + pruning)
    for seq_idx, t_start, t_end in self.iterate_seqs(chunk_size=chunk_size, chunk_step=chunk_step, used_data_keys=used_data_keys):
        if not self.sample(seq_idx):
            continue  # pruned out by sampling
        # Fix: only stop at a seq boundary (never between chunks of the same seq),
        # and stop after exactly max_total_num_seqs seqs (was `total_num_seqs > max_total_num_seqs`,
        # which emitted one seq too many and could cut a seq's chunks in half).
        # This matches the non-pruning variant of this method.
        if seq_idx != last_seq_idx and total_num_seqs >= max_total_num_seqs:
            break
        if ctx_lr:
            # Extend the chunk by the context window on both sides.
            t_start -= ctx_lr[0]
            t_end += ctx_lr[1]
        if recurrent_net:
            length = t_end - t_start  # NumbersDict: number of frames per data-key
            if length.any_compare(max_seq_length, (lambda a, b: a > b)):
                continue
            if length.any_compare(min_seq_length, (lambda a, b: a < b)):
                continue
            if length.any_compare(batch_size, (lambda a, b: a > b)):
                print("warning: sequence length (%r) larger than limit (%r)" % (length, batch_size), file=log.v4)
            if self.rnd_seq_drop.random() < seq_drop:
                continue  # randomly drop this seq
            # Would adding this seq as a new slice exceed the limits?
            # If so, yield the current batch first and start a new one.
            dt, ds = batch.try_sequence_as_slice(length)
            if ds > 1 and ((dt * ds).any_compare(batch_size, (lambda a, b: a > b)) or ds > max_seqs):
                yield batch
                batch = Batch()
            batch.add_sequence_as_slice(seq_idx=seq_idx, seq_start_frame=t_start, length=length)
        else:  # Not recurrent.
            # Concatenate frames of consecutive seqs into a single slice;
            # a seq may be split across multiple batches.
            while t_start.max_value() < t_end.max_value():
                length = t_end - t_start
                # Take as many frames as fit into the remaining batch capacity.
                num_frames = NumbersDict.min(
                    [length, batch_size.copy_like(length) - batch.get_all_slices_num_frames().copy_like(length)])
                assert num_frames.max_value() > 0
                batch.add_frames(seq_idx=seq_idx, seq_start_frame=t_start, length=num_frames)
                if (batch.get_all_slices_num_frames().any_compare(batch_size, (lambda a, b: a >= b))
                        or batch.get_num_seqs() > max_seqs):
                    yield batch
                    batch = Batch()
                t_start += num_frames
        if seq_idx != last_seq_idx:
            # Count each distinct seq once (a seq can yield multiple chunks).
            last_seq_idx = seq_idx
            total_num_seqs += 1
    if batch.get_all_slices_num_frames().max_value() > 0:
        yield batch  # final partial batch
def iterate_seqs(self, chunk_size=None, chunk_step=None, used_data_keys=None):
    """
    Iterates over all seqs of the dataset, taking chunking into consideration.
    Chunk size/step can differ per data-key (NumbersDict).

    :param int|NumbersDict chunk_size: defaults to self.chunk_size; 0 disables chunking
    :param int|NumbersDict chunk_step: defaults to self.chunk_step
    :param set(str)|None used_data_keys: if given, restrict the lengths to these data-keys
    :return: generator which yields tuples (seq index, seq start, seq end)
    :rtype: list[(int,NumbersDict,NumbersDict)]
    """
    if chunk_size is None:
        chunk_size = self.chunk_size
    if chunk_step is None:
        chunk_step = self.chunk_step
    # Normalize both to per-data-key dicts.
    chunk_size = NumbersDict(chunk_size)
    chunk_step = NumbersDict(chunk_step)
    seq_idx = 0
    while self.is_less_than_num_seqs(seq_idx):
        seq_len = self.get_seq_length(seq_idx)
        if chunk_size == 0:
            # No chunking: the whole seq is one chunk.
            yield (seq_idx, NumbersDict.constant_like(0, numbers_dict=seq_len), seq_len)
            seq_idx += 1
            continue
        # default_key drives the chunking positions below; it must have a positive step.
        default_key = "data"
        if used_data_keys is not None:
            seq_len = NumbersDict({key: seq_len[key] for key in used_data_keys})
            if default_key not in used_data_keys:
                default_key = sorted(used_data_keys)[0]
                if chunk_step[default_key] == 0:  # allow some keys with zero chunk-step
                    assert chunk_step.max_value() > 0
                    default_key = [key for key in sorted(used_data_keys) if chunk_step[key] > 0][0]
        assert chunk_step[default_key] > 0
        # There are usually the 'data' (input) and 'classes' (targets) data-keys in `seq_len`
        # but there can be others. We expect them all of the same length so that we can do chunking.
        # In case that some length is 0 or 1, we treat it special and always return
        # the full seq repeated for every chunk.
        full_seq_keys = []
        for key in seq_len.keys():
            if chunk_step[key] != chunk_step[default_key]:
                # Different chunk-step: both keys must still yield the same number of full chunks.
                full_chunks_key = (seq_len[key] - chunk_size[key]) // chunk_step[key] + 1
                full_chunks_default = (seq_len[default_key] - chunk_size[default_key]) // chunk_step[default_key] + 1
                assert full_chunks_key == full_chunks_default
                continue
            if seq_len[key] == seq_len[default_key]:
                continue  # ok
            if seq_len[key] <= 1:  # special case as explained above
                full_seq_keys.append(key)
                continue
            raise Exception("Chunking with multiple data-keys of different length: %r" % seq_len)
        offset = NumbersDict.constant_like(0, numbers_dict=seq_len)
        while seq_len[default_key] > offset[default_key]:
            start = NumbersDict(offset)
            end = NumbersDict.min([offset + chunk_size, seq_len])
            for key in full_seq_keys:
                # Special-case keys (length 0/1): always cover the full seq.
                start[key] = 0
                end[key] = seq_len[key]
            if seq_len.value is None:
                start.value = None
                end.value = None
            yield (seq_idx, start, end)
            offset += chunk_step
            # Avoid emitting a trailing chunk shorter than min_chunk_size.
            if seq_len[default_key] - offset[default_key] <= self.min_chunk_size:
                break
        seq_idx += 1