Example #1
def iterate_seqs(self, chunk_size=None, chunk_step=None, used_data_keys=None):
    """
    Takes chunking into consideration.
    :param int chunk_size:
    :param int chunk_step:
    :param set(str)|None used_data_keys:
    :return: generator which yields tuples (seq index, seq start, seq end)
    :rtype: list[(int,NumbersDict,NumbersDict)]
    """
    if chunk_size is None:
        chunk_size = self.chunk_size
    if chunk_step is None:
        chunk_step = self.chunk_step
    s = 0
    while self.is_less_than_num_seqs(s):
        length = self.get_seq_length(s)
        if chunk_size == 0:
            # No chunking: yield the whole sequence as a single window.
            yield (s, length.constant_like(0), length)
        else:
            if used_data_keys is not None:
                length = NumbersDict({k: length[k] for k in used_data_keys})
            t = length.constant_like(0)
            default_key = "data"
            # There are usually the 'data' (input) and 'classes' (targets) data-keys
            # in `length` but there can be others. We expect them all to have the same
            # length so that we can do chunking. If some length is 0 or 1, we treat it
            # specially and always return the full seq repeated for every chunk.
            keys_with_full_seqs = []
            for key in length.keys():
                if length[key] == length[default_key]:
                    continue  # ok
                if length[key] <= 1:
                    keys_with_full_seqs.append(key)
                    continue
                raise Exception("Chunking with multiple data-keys of different length: %r" % length)
            while length[default_key] > t[default_key]:
                chunk_start = NumbersDict(t)
                chunk_end = NumbersDict.min([t + chunk_size, length])
                for key in keys_with_full_seqs:
                    chunk_start[key] = 0
                    chunk_end[key] = length[key]
                if length.value is None:
                    chunk_start.value = None
                    chunk_end.value = None
                yield (s, chunk_start, chunk_end)
                t += chunk_step
                # Stop once the remaining tail would be too small to be a useful chunk.
                if length[default_key] - t[default_key] <= self.min_chunk_size:
                    break
        s += 1
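
The chunking loop above slides a window of chunk_size frames forward by chunk_step frames and stops once the remaining tail is at most self.min_chunk_size. A minimal, self-contained sketch with plain ints (the real code uses NumbersDict so that each data-key gets its own window) shows which windows get yielded:

def chunk_windows(seq_len, chunk_size, chunk_step, min_chunk_size=0):
    # Scalar version of the inner while-loop in iterate_seqs above.
    t = 0
    while seq_len > t:
        yield (t, min(t + chunk_size, seq_len))
        t += chunk_step
        # Stop once the remaining tail would be too small.
        if seq_len - t <= min_chunk_size:
            break

# A 10-frame seq, chunk size 4, step 3: the 1-frame tail at t=9 is suppressed.
print(list(chunk_windows(10, 4, 3, min_chunk_size=2)))
# -> [(0, 4), (3, 7), (6, 10)]
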
Example #2
def _generate_batches(self,
                      recurrent_net,
                      batch_size,
                      max_seqs=-1,
                      seq_drop=0.0,
                      max_seq_length=sys.maxsize,
                      used_data_keys=None):
    """
    :param bool recurrent_net: If True, the batch might have a batch seq dimension > 1.
      Otherwise, the batch seq dimension is always 1 and multiple seqs will be concatenated.
    :param int batch_size: Max number of frames in one batch.
    :param int max_seqs: Max number of seqs per batch.
    :param float seq_drop: Probability to randomly drop a seq.
    :param int max_seq_length: Max seq length; a negative value limits the 'classes' length instead.
    :param set(str)|None used_data_keys:
    """
    if batch_size == 0:
        batch_size = sys.maxsize
    assert batch_size > 0
    if max_seqs == -1:
        max_seqs = float('inf')
    assert max_seqs > 0
    assert seq_drop <= 1.0
    chunk_size = self.chunk_size
    chunk_step = self.chunk_step
    if not recurrent_net:
        if chunk_size != 0:
            print("Non-recurrent network, chunk size %i:%i ignored" % (chunk_size, chunk_step), file=log.v4)
            chunk_size = 0
    batch = Batch()
    for seq_idx, t_start, t_end in self._iterate_seqs(
            chunk_size=chunk_size, chunk_step=chunk_step, used_data_keys=used_data_keys):
        if recurrent_net:
            length = t_end - t_start
            if max_seq_length < 0 and length['classes'] > -max_seq_length:
                continue
            elif max_seq_length > 0 and length.max_value() > max_seq_length:
                continue
            if length.max_value() > batch_size:
                print("warning: sequence length (%i) larger than limit (%i)"
                      % (length.max_value(), batch_size), file=log.v4)
            if self.rnd_seq_drop.random() < seq_drop:
                continue
            dt, ds = batch.try_sequence_as_slice(length)
            if ds > 1 and ((dt * ds).max_value() > batch_size or ds > max_seqs):
                yield batch
                batch = Batch()
            batch.add_sequence_as_slice(seq_idx=seq_idx, seq_start_frame=t_start, length=length)
        else:  # Not recurrent.
            while t_start.max_value() < t_end.max_value():
                length = t_end - t_start
                num_frames = NumbersDict.min([length, batch_size - batch.get_all_slices_num_frames()])
                assert num_frames.max_value() > 0
                batch.add_frames(seq_idx=seq_idx, seq_start_frame=t_start, length=num_frames)
                if batch.get_all_slices_num_frames() >= batch_size or batch.get_num_seqs() > max_seqs:
                    yield batch
                    batch = Batch()
                t_start += num_frames

    if batch.get_all_slices_num_frames() > 0:
        yield batch
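
In the non-recurrent branch, each chunk is cut into pieces of at most batch_size frames and the pieces are concatenated into one batch until it is full. A hedged scalar sketch of that packing (ignoring max_seqs and the NumbersDict per-key bookkeeping):

def pack_frames(seq_lens, batch_size):
    # Greedy frame packing like the non-recurrent branch above:
    # returns lists of (seq_idx, start, n_frames) slices,
    # each list holding at most batch_size frames.
    batches, cur, cur_frames = [], [], 0
    for seq_idx, seq_len in enumerate(seq_lens):
        t = 0
        while t < seq_len:
            n = min(seq_len - t, batch_size - cur_frames)
            cur.append((seq_idx, t, n))
            cur_frames += n
            if cur_frames >= batch_size:
                batches.append(cur)
                cur, cur_frames = [], 0
            t += n
    if cur_frames > 0:
        batches.append(cur)
    return batches

# Three seqs of 5, 3 and 4 frames, at most 6 frames per batch:
print(pack_frames([5, 3, 4], 6))
# -> [[(0, 0, 5), (1, 0, 1)], [(1, 1, 2), (2, 0, 4)]]
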
Example #3
def _generate_batches(self,
                      recurrent_net,
                      batch_size,
                      max_seqs=-1,
                      max_seq_length=sys.maxsize,
                      min_seq_length=0,
                      seq_drop=0.0,
                      max_total_num_seqs=-1,
                      used_data_keys=None):
    """
    :param bool recurrent_net: If True, the batch might have a batch seq dimension > 1.
      Otherwise, the batch seq dimension is always 1 and multiple seqs will be concatenated.
    :param int batch_size: Max number of frames in one batch.
    :param int max_seqs: Max number of seqs per batch.
    :param int|dict[str,int]|NumbersDict max_seq_length:
    :param int|dict[str,int]|NumbersDict min_seq_length:
    :param float seq_drop: Probability to randomly drop a seq.
    :param int max_total_num_seqs:
    :param set(str)|None used_data_keys:
    """
    if batch_size == 0:
        batch_size = sys.maxsize
    assert batch_size > 0
    if max_seqs == -1:
        max_seqs = float('inf')
    if not max_seq_length:
        max_seq_length = sys.maxsize
    if isinstance(max_seq_length, int) and max_seq_length < 0:
        max_seq_length = {"classes": -max_seq_length}
    max_seq_length = NumbersDict(max_seq_length)
    min_seq_length = NumbersDict(min_seq_length)
    assert max_seqs > 0
    assert seq_drop <= 1.0
    if not max_total_num_seqs or max_total_num_seqs < 0:
        max_total_num_seqs = float("inf")
    chunk_size = self.chunk_size
    chunk_step = self.chunk_step
    if not recurrent_net:
        if chunk_size != 0:
            print("Non-recurrent network, chunk size %s:%s ignored" % (chunk_size, chunk_step), file=log.v4)
            chunk_size = 0
    batch = Batch()
    ctx_lr = self._get_context_window_left_right()
    total_num_seqs = 0
    last_seq_idx = -1
    for seq_idx, t_start, t_end in self.iterate_seqs(
            chunk_size=chunk_size, chunk_step=chunk_step, used_data_keys=used_data_keys):
        if seq_idx != last_seq_idx and total_num_seqs >= max_total_num_seqs:
            break
        if ctx_lr:
            # Widen each chunk by the left/right context window.
            t_start -= ctx_lr[0]
            t_end += ctx_lr[1]
        if recurrent_net:
            length = t_end - t_start
            if length.any_compare(max_seq_length, (lambda a, b: a > b)):
                continue
            if length.any_compare(min_seq_length, (lambda a, b: a < b)):
                continue
            if length.max_value() > batch_size:
                print("warning: sequence length (%i) larger than limit (%i)"
                      % (length.max_value(), batch_size), file=log.v4)
            if self.rnd_seq_drop.random() < seq_drop:
                continue
            dt, ds = batch.try_sequence_as_slice(length)
            if ds > 1 and ((dt * ds).max_value() > batch_size or ds > max_seqs):
                yield batch
                batch = Batch()
            batch.add_sequence_as_slice(seq_idx=seq_idx, seq_start_frame=t_start, length=length)
        else:  # Not recurrent.
            while t_start.max_value() < t_end.max_value():
                length = t_end - t_start
                num_frames = NumbersDict.min([length, batch_size - batch.get_all_slices_num_frames()])
                assert num_frames.max_value() > 0
                batch.add_frames(seq_idx=seq_idx, seq_start_frame=t_start, length=num_frames)
                if batch.get_all_slices_num_frames() >= batch_size or batch.get_num_seqs() > max_seqs:
                    yield batch
                    batch = Batch()
                t_start += num_frames
        if seq_idx != last_seq_idx:
            last_seq_idx = seq_idx
            total_num_seqs += 1

    if batch.get_all_slices_num_frames() > 0:
        yield batch
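
Example #3 replaces the scalar length checks with NumbersDict.any_compare; judging from its use here, it returns True as soon as the comparison holds for any data-key, so a sequence is skipped if any stream violates its per-key bound. A plain-dict stand-in (the real NumbersDict also broadcasts a scalar default value, which is omitted here):

def any_compare(a, b, cmp):
    # Assumed semantics of NumbersDict.any_compare as used above:
    # True if cmp(a[k], b[k]) holds for at least one key.
    return any(cmp(a[k], b[k]) for k in a)

length = {"data": 120, "classes": 30}
max_seq_length = {"data": 100, "classes": 50}
# 'data' exceeds its bound, so this seq would be filtered out:
print(any_compare(length, max_seq_length, lambda x, y: x > y))  # True
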
Example #4
def iterate_seqs(self, chunk_size=None, chunk_step=None, used_data_keys=None):
    """
    Takes chunking into consideration.
    :param int|NumbersDict chunk_size:
    :param int|NumbersDict chunk_step:
    :param set(str)|None used_data_keys:
    :return: generator which yields tuples (seq index, seq start, seq end)
    :rtype: list[(int,NumbersDict,NumbersDict)]
    """
    if chunk_size is None:
        chunk_size = self.chunk_size
    if chunk_step is None:
        chunk_step = self.chunk_step
    chunk_size = NumbersDict(chunk_size)
    chunk_step = NumbersDict(chunk_step)
    s = 0
    while self.is_less_than_num_seqs(s):
        length = self.get_seq_length(s)
        if chunk_size == 0:
            yield (s, NumbersDict.constant_like(0, numbers_dict=length), length)
        else:
            default_key = "data"
            if used_data_keys is not None:
                length = NumbersDict({k: length[k] for k in used_data_keys})
                if default_key not in used_data_keys:
                    default_key = sorted(used_data_keys)[0]
                if chunk_step[default_key] == 0:  # allow some keys with zero chunk-step
                    assert chunk_step.max_value() > 0
                    default_key = [key for key in sorted(used_data_keys) if chunk_step[key] > 0][0]
            assert chunk_step[default_key] > 0
            t = NumbersDict.constant_like(0, numbers_dict=length)
            # There are usually the 'data' (input) and 'classes' (targets) data-keys
            # in `length` but there can be others. We expect them all to have the same
            # length so that we can do chunking. If some length is 0 or 1, we treat it
            # specially and always return the full seq repeated for every chunk.
            keys_with_full_seqs = []
            for key in length.keys():
                if chunk_step[key] == chunk_step[default_key]:
                    if length[key] == length[default_key]:
                        continue  # ok
                if length[key] <= 1:  # special case as explained above
                    keys_with_full_seqs.append(key)
                    continue
                if chunk_step[key] == chunk_step[default_key]:
                    raise Exception("Chunking with multiple data-keys of different length: %r" % length)
                else:
                    # Different chunk steps are allowed as long as each key
                    # yields the same number of full chunks.
                    nr_of_full_chunks_key = (length[key] - chunk_size[key]) // chunk_step[key] + 1
                    nr_of_full_chunks_default_key = (
                        (length[default_key] - chunk_size[default_key]) // chunk_step[default_key] + 1)
                    assert nr_of_full_chunks_key == nr_of_full_chunks_default_key
            while length[default_key] > t[default_key]:
                chunk_start = NumbersDict(t)
                chunk_end = NumbersDict.min([t + chunk_size, length])
                for key in keys_with_full_seqs:
                    chunk_start[key] = 0
                    chunk_end[key] = length[key]
                if length.value is None:
                    chunk_start.value = None
                    chunk_end.value = None
                yield (s, chunk_start, chunk_end)
                t += chunk_step
                if length[default_key] - t[default_key] <= self.min_chunk_size:
                    break
        s += 1
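
Example #4 additionally allows per-key chunk sizes and steps, as long as every key produces the same number of full chunks; the assertion in the loop checks exactly that. With concrete numbers:

def nr_of_full_chunks(length, chunk_size, chunk_step):
    # Same formula as asserted in iterate_seqs above.
    return (length - chunk_size) // chunk_step + 1

# E.g. 'data' with 100 frames chunked 40:20, next to a 2x-downsampled key
# with 50 frames chunked 20:10 -- both yield the same number of full chunks.
assert nr_of_full_chunks(100, 40, 20) == nr_of_full_chunks(50, 20, 10) == 4
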
Example #5
  def _generate_batches(self, recurrent_net,
                        batch_size, max_seqs=-1, max_seq_length=sys.maxsize,
                        min_seq_length=0, pruning=0.0,
                        seq_drop=0.0, max_total_num_seqs=-1,
                        used_data_keys=None):
    """
    :param bool recurrent_net: If True, the batch might have a batch seq dimension > 1.
      Otherwise, the batch seq dimension is always 1 and multiple seqs will be concatenated.
    :param int|dict[str,int]|NumbersDict batch_size: Max number of frames in one batch.
    :param int max_seqs: Max number of seqs per batch.
    :param int max_total_num_seqs:
    :param int|dict[str,int]|NumbersDict max_seq_length:
    :param set(str)|None used_data_keys:
    """
    if not batch_size:
      batch_size = sys.maxsize
    batch_size = NumbersDict(batch_size)
    assert not batch_size.any_compare(NumbersDict(0), (lambda a, b: a <= b))
    if max_seqs == -1:
      max_seqs = float('inf')
    if not max_seq_length:
      max_seq_length = sys.maxsize
    if isinstance(max_seq_length, int) and max_seq_length < 0:
      max_seq_length = {"classes": -max_seq_length}
    max_seq_length = NumbersDict(max_seq_length)
    min_seq_length = NumbersDict(min_seq_length)
    assert max_seqs > 0
    assert seq_drop <= 1.0
    if not max_total_num_seqs or max_total_num_seqs < 0:
      max_total_num_seqs = float("inf")
    chunk_size = self.chunk_size
    chunk_step = self.chunk_step
    if not recurrent_net:
      if chunk_size != 0:
        print("Non-recurrent network, chunk size %s:%s ignored" % (chunk_size, chunk_step), file=log.v4)
        chunk_size = 0
    batch = Batch()
    ctx_lr = self._get_context_window_left_right()
    total_num_seqs = 0
    last_seq_idx = -1
    avg_weight = sum([v[0] for v in self.weights.values()]) / (len(self.weights.keys()) or 1)
    for idx in self.weights:
      self.weights[idx][1] = random() * avg_weight * pruning
      self.weights[idx][0] *= (1. + pruning)
    for seq_idx, t_start, t_end in self.iterate_seqs(
          chunk_size=chunk_size, chunk_step=chunk_step, used_data_keys=used_data_keys):
      if not self.sample(seq_idx):
        continue
      if total_num_seqs > max_total_num_seqs:
        break
      if ctx_lr:
        t_start -= ctx_lr[0]
        t_end += ctx_lr[1]
      if recurrent_net:
        length = t_end - t_start
        if length.any_compare(max_seq_length, (lambda a, b: a > b)):
          continue
        if length.any_compare(min_seq_length, (lambda a, b: a < b)):
          continue
        if length.any_compare(batch_size, (lambda a, b: a > b)):
          print("warning: sequence length (%r) larger than limit (%r)" % (length, batch_size), file=log.v4)
        if self.rnd_seq_drop.random() < seq_drop:
          continue
        dt, ds = batch.try_sequence_as_slice(length)
        if ds > 1 and ((dt * ds).any_compare(batch_size, (lambda a, b: a > b)) or ds > max_seqs):
          yield batch
          batch = Batch()
        batch.add_sequence_as_slice(seq_idx=seq_idx, seq_start_frame=t_start, length=length)
      else:  # Not recurrent.
        while t_start.max_value() < t_end.max_value():
          length = t_end - t_start
          num_frames = NumbersDict.min(
            [length, batch_size.copy_like(length) - batch.get_all_slices_num_frames().copy_like(length)])
          assert num_frames.max_value() > 0
          batch.add_frames(seq_idx=seq_idx, seq_start_frame=t_start, length=num_frames)
          if (batch.get_all_slices_num_frames().any_compare(batch_size, (lambda a, b: a >= b))
                  or batch.get_num_seqs() > max_seqs):
            yield batch
            batch = Batch()
          t_start += num_frames
      if seq_idx != last_seq_idx:
        last_seq_idx = seq_idx
        total_num_seqs += 1

    if batch.get_all_slices_num_frames().max_value() > 0:
      yield batch
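
Example #5 generalizes batch_size itself to a NumbersDict, so the flush condition becomes per-key: the batch is emitted as soon as any data stream reaches its own frame limit. A plain-dict sketch of that check (a hypothetical helper, not RETURNN API):

def batch_full(frames_per_key, limit_per_key):
    # Mirrors batch.get_all_slices_num_frames().any_compare(
    #     batch_size, lambda a, b: a >= b) from the code above.
    return any(frames_per_key[k] >= limit_per_key[k] for k in frames_per_key)

frames = {"data": 4000, "classes": 90}
limits = {"data": 5000, "classes": 80}
print(batch_full(frames, limits))  # True: 'classes' already hit its limit
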
Example #6
 def iterate_seqs(self, chunk_size=None, chunk_step=None, used_data_keys=None):
   """
   Takes chunking into consideration.
   :param int|NumbersDict chunk_size:
   :param int|NumbersDict chunk_step:
   :param set(str)|None used_data_keys:
   :return: generator which yields tuples (seq index, seq start, seq end)
   :rtype: list[(int,NumbersDict,NumbersDict)]
   """
   if chunk_size is None:
     chunk_size = self.chunk_size
   if chunk_step is None:
     chunk_step = self.chunk_step
   chunk_size = NumbersDict(chunk_size)
   chunk_step = NumbersDict(chunk_step)
   s = 0
   while self.is_less_than_num_seqs(s):
     length = self.get_seq_length(s)
     if chunk_size == 0:
       yield (s, NumbersDict.constant_like(0, numbers_dict=length), length)
     else:
       default_key = "data"
       if used_data_keys is not None:
         length = NumbersDict({k: length[k] for k in used_data_keys})
         if default_key not in used_data_keys:
           default_key = sorted(used_data_keys)[0]
         if chunk_step[default_key] == 0:  # allow some keys with zero chunk-step
           assert chunk_step.max_value() > 0
           default_key = [key for key in sorted(used_data_keys) if chunk_step[key] > 0][0]
       assert chunk_step[default_key] > 0
       t = NumbersDict.constant_like(0, numbers_dict=length)
        # There are usually the 'data' (input) and 'classes' (targets) data-keys in `length` but there can be others.
        # We expect them all to have the same length so that we can do chunking.
        # If some length is 0 or 1, we treat it specially
        # and always return the full seq repeated for every chunk.
       keys_with_full_seqs = []
       for key in length.keys():
         if chunk_step[key] == chunk_step[default_key]:
           if length[key] == length[default_key]:
             continue  # ok
         if length[key] <= 1:  # special case as explained above
           keys_with_full_seqs.append(key)
           continue
         if chunk_step[key] == chunk_step[default_key]:
           raise Exception("Chunking with multiple data-keys of different length: %r" % length)
         else:
           nr_of_full_chunks_key = (length[key] - chunk_size[key]) // chunk_step[key] + 1
           nr_of_full_chunks_default_key = (
             (length[default_key] - chunk_size[default_key]) // chunk_step[default_key] + 1)
           assert nr_of_full_chunks_key == nr_of_full_chunks_default_key
       while length[default_key] > t[default_key]:
         chunk_start = NumbersDict(t)
         chunk_end = NumbersDict.min([t + chunk_size, length])
         for key in keys_with_full_seqs:
           chunk_start[key] = 0
           chunk_end[key] = length[key]
         if length.value is None:
           chunk_start.value = None
           chunk_end.value = None
         yield (s, chunk_start, chunk_end)
         t += chunk_step
         if length[default_key] - t[default_key] <= self.min_chunk_size:
           break
     s += 1
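
All of these examples lean on RETURNN's NumbersDict for per-data-key arithmetic. As a rough mental model only (a toy stand-in, not the real class, which also carries a broadcast default `.value` that several snippets above set to None):

class MiniNumbersDict(dict):
    # Toy stand-in for the per-key arithmetic the examples assume.
    # Not the real RETURNN NumbersDict; its broadcast `.value` is omitted.

    def _binop(self, other, op):
        if not isinstance(other, dict):
            other = {k: other for k in self}  # broadcast a scalar
        return MiniNumbersDict({k: op(self[k], other[k]) for k in self})

    def __add__(self, other):
        return self._binop(other, lambda a, b: a + b)

    def __sub__(self, other):
        return self._binop(other, lambda a, b: a - b)

    def max_value(self):
        return max(self.values())

    def any_compare(self, other, cmp):
        return any(cmp(self[k], other[k]) for k in self)

    @classmethod
    def min(cls, dicts):
        keys = list(dicts[0].keys())
        return cls({k: min(d[k] if isinstance(d, dict) else d for d in dicts)
                    for k in keys})

a = MiniNumbersDict({"data": 7, "classes": 3})
print((a + 2).max_value())  # 9
print(MiniNumbersDict.min([a, 5]))  # {'data': 5, 'classes': 3}
print(a.any_compare({"data": 6, "classes": 9}, lambda x, y: x > y))  # True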