Code example #1
File: MetaDataset.py Project: fotonower/returnn
    def _generate_batches(self,
                          recurrent_net,
                          batch_size,
                          max_seqs=-1,
                          seq_drop=0.0,
                          max_seq_length=None,
                          used_data_keys=None):
        import sys
        if max_seq_length is None: max_seq_length = sys.maxsize
        if batch_size == 0: batch_size = sys.maxsize
        assert batch_size > 0
        if max_seqs == -1: max_seqs = float('inf')
        assert max_seqs > 0
        assert seq_drop <= 1.0
        chunk_size = self.chunk_size
        chunk_step = self.chunk_step
        from EngineBatch import Batch
        batch = Batch()
        last_seq_idx = None
        for seq_idx, t_start, t_end in self.iterate_seqs(
                chunk_size=chunk_size,
                chunk_step=chunk_step,
                used_data_keys=used_data_keys):
            if self.single_cluster:
                if last_seq_idx is not None and last_seq_idx != seq_idx:
                    last_seq_name = self.get_tag(last_seq_idx)
                    seq_name = self.get_tag(seq_idx)
                    if self.cluster_map[last_seq_name] != self.cluster_map[
                            seq_name]:
                        print("ClusteringDataset::_generate_batches",
                              last_seq_idx,
                              "is not",
                              seq_idx,
                              file=log.v5)
                        yield batch
                        batch = Batch()
            length = t_end - t_start
            if max_seq_length < 0 and length['classes'] > -max_seq_length:
                continue
            elif max_seq_length > 0 and length.max_value() > max_seq_length:
                continue
            if length.max_value() > batch_size:
                print("warning: sequence length (%i) larger than limit (%i)" %
                      (length.max_value(), batch_size),
                      file=log.v4)
            if self.rnd_seq_drop.random() < seq_drop:
                continue
            dt, ds = batch.try_sequence_as_slice(length)
            if ds > 1 and ((dt * ds).max_value() > batch_size
                           or ds > max_seqs):
                yield batch
                batch = Batch()
            print("batch add slice length", length, file=log.v5)
            batch.add_sequence_as_slice(seq_idx=seq_idx,
                                        seq_start_frame=t_start,
                                        length=length)
            last_seq_idx = seq_idx

        if batch.get_all_slices_num_frames() > 0:
            yield batch
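A note on usage: _generate_batches is a generator of Batch objects, so a caller can stream batches as they are produced. Below is a minimal consumption sketch (a hypothetical helper, not RETURNN code; it assumes a dataset instance exposing the method above and uses only Batch methods that appear elsewhere in these examples):

def count_batches(dataset, batch_size, max_seqs=-1):
    # Stream the generated batches and tally batches and sequences;
    # get_num_seqs() is the per-batch sequence count used in later examples.
    n_batches = n_seqs = 0
    for batch in dataset._generate_batches(recurrent_net=True,
                                           batch_size=batch_size,
                                           max_seqs=max_seqs):
        n_batches += 1
        n_seqs += batch.get_num_seqs()
    return n_batches, n_seqs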
Code example #2
    def forward_fill_queue(self):
        """
    Full sequence forwarding, no chunking (at the moment).
    """
        assert self.train_started
        if self.is_forwarding_finished: return

        # We will ignore max_seq_length.
        batch_size = self.config.int('batch_size', 1)
        max_seqs = self.config.int('max_seqs', -1)
        if max_seqs <= 0: max_seqs = float('inf')
        dataset = self.engine.train_data
        from EngineBatch import Batch

        # Collect all batches.
        forward_batches = []
        ":type: list[EngineBatch.Batch]"
        num_seqs = 0
        while self._device_exec("have_space_in_forward_data_queue",
                                num_seqs=num_seqs):
            # Load next sequence for forwarding, keep all which are still needed for training.
            if not dataset.is_less_than_num_seqs(self.forward_current_seq):
                self.is_forwarding_finished = True
                break
            dataset.load_seqs(self.train_start_seq,
                              self.forward_current_seq + 1)
            seq_len = dataset.get_seq_length(self.forward_current_seq)

            if not forward_batches:
                forward_batches.append(Batch())
            batch = forward_batches[-1]
            dt, ds = batch.try_sequence_as_slice(seq_len)
            if ds > 1 and ((dt * ds).max_value() > batch_size
                           or ds > max_seqs):
                batch = Batch()
                forward_batches.append(batch)
            batch.add_sequence_as_slice(seq_idx=self.forward_current_seq,
                                        seq_start_frame=0,
                                        length=seq_len)
            num_seqs += 1
            self.forward_current_seq += 1

        # Forward the batches.
        from EngineUtil import assign_dev_data
        for batch in forward_batches:
            print("SeqTrainParallelControl, forward %r" % batch, file=log.v4)
            success = assign_dev_data(self.train_device,
                                      dataset, [batch],
                                      load_seqs=False)
            assert success, "failed to allocate & assign data"
            self.train_device.update_data()
            self._device_exec("do_forward", batch=batch)
            self._device_exec("train_check_calc_loss")
Code example #3
def assign_dev_data_single_seq(device, dataset, seq, load_seqs=True):
    """
  :type device: Device.Device
  :type dataset: Dataset.Dataset
  :param int seq: sorted seq idx
  :return: whether we succeeded
  :rtype: bool
  """
    batch = Batch()
    batch.init_with_one_full_sequence(seq_idx=seq, dataset=dataset)
    success, _ = assign_dev_data(device, dataset, [batch], load_seqs=load_seqs)
    return success
Code example #4
File: EngineUtil.py Project: atuxhe/returnn
def assign_dev_data_single_seq(device, dataset, seq, load_seqs=True):
  """
  :type device: Device.Device
  :type dataset: Dataset.Dataset
  :param int seq: sorted seq idx
  :return: whether we succeeded
  :rtype: bool
  """
  batch = Batch()
  batch.add_frames(seq_idx=seq, seq_start_frame=0, length=dataset.get_seq_length(seq))
  success, _ = assign_dev_data(device, dataset, [batch], load_seqs=load_seqs)
  return success
Code example #5
File: EngineUtil.py Project: rwth-i6/returnn
def assign_dev_data_single_seq(device, dataset, seq, load_seqs=True):
  """
  :type device: Device.Device
  :type dataset: Dataset.Dataset
  :param int seq: sorted seq idx
  :param bool load_seqs:
  :return: whether we succeeded
  :rtype: bool
  """
  batch = Batch()
  batch.init_with_one_full_sequence(seq_idx=seq, dataset=dataset)
  success, _ = assign_dev_data(device, dataset, [batch], load_seqs=load_seqs)
  return success
Code example #6
def assign_dev_data_single_seq(device, dataset, seq, load_seqs=True):
    """
  :type device: Device.Device
  :type dataset: Dataset.Dataset
  :param int seq: sorted seq idx
  :return: whether we succeeded
  :rtype: bool
  """
    batch = Batch()
    batch.add_frames(seq_idx=seq,
                     seq_start_frame=0,
                     length=dataset.get_seq_length(seq))
    success, _ = assign_dev_data(device, dataset, [batch], load_seqs=load_seqs)
    return success
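A hedged usage sketch for the helper above (the wrapper is hypothetical; device and dataset are assumed to be an initialized Device.Device and Dataset.Dataset; update_data() mirrors the call visible in example #7 below):

from EngineUtil import assign_dev_data_single_seq

def load_seq_onto_device(device, dataset, seq_idx):
    # Build a one-sequence batch and copy it into device memory.
    ok = assign_dev_data_single_seq(device, dataset, seq_idx)
    assert ok, "failed to allocate & assign data"
    device.update_data()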
Code example #7
File: Engine.py Project: atuxhe/returnn
  def forward_fill_queue(self):
    """
    Full sequence forwarding, no chunking (at the moment).
    """
    assert self.train_started
    if self.is_forwarding_finished: return

    # We will ignore max_seq_length.
    batch_size = self.config.int('batch_size', 1)
    max_seqs = self.config.int('max_seqs', -1)
    if max_seqs <= 0: max_seqs = float('inf')
    dataset = self.engine.train_data
    from EngineBatch import Batch

    # Collect all batches.
    forward_batches = []; ":type: list[EngineBatch.Batch]"
    num_seqs = 0
    while self._device_exec("have_space_in_forward_data_queue", num_seqs=num_seqs):
      # Load next sequence for forwarding, keep all which are still needed for training.
      if not dataset.is_less_than_num_seqs(self.forward_current_seq):
        self.is_forwarding_finished = True
        break
      dataset.load_seqs(self.train_start_seq, self.forward_current_seq + 1)
      seq_len = dataset.get_seq_length(self.forward_current_seq)

      if not forward_batches:
        forward_batches.append(Batch())
      batch = forward_batches[-1]
      dt, ds = batch.try_sequence_as_slice(seq_len)
      if ds > 1 and ((dt * ds).max_value() > batch_size or ds > max_seqs):
        batch = Batch()
        forward_batches.append(batch)
      batch.add_sequence_as_slice(seq_idx=self.forward_current_seq, seq_start_frame=0, length=seq_len)
      num_seqs += 1
      self.forward_current_seq += 1

    # Forward the batches.
    from EngineUtil import assign_dev_data
    for batch in forward_batches:
      print("SeqTrainParallelControl, forward %r" % batch, file=log.v4)
      success = assign_dev_data(self.train_device, dataset, [batch], load_seqs=False)
      assert success, "failed to allocate & assign data"
      self.train_device.update_data()
      self._device_exec("do_forward", batch=batch)
      self._device_exec("train_check_calc_loss")
Code example #8
File: Engine.py Project: wbengine/returnn
  def forward_single(self, dataset, seq_idx, output_layer_name=None):
    """
    Forwards a single sequence.
    If you want to perform search, and get a number of hyps out, use :func:`search_single`.

    :param Dataset.Dataset dataset:
    :param int seq_idx:
    :param str|None output_layer_name: e.g. "output". if not set, will read from config "forward_output_layer"
    :return: numpy array, output in time major format (time,dim)
    :rtype: numpy.ndarray
    """

    from EngineBatch import Batch, BatchSetGenerator
    batch = Batch()
    batch.init_with_one_full_sequence(seq_idx=seq_idx, dataset=dataset)
    batch_generator = iter([batch])
    batches = BatchSetGenerator(dataset, generator=batch_generator)

    forwarder = ClassificationTaskThread(self.network, self.devices, dataset, batches)
    forwarder.join()
    assert forwarder.output.shape[1] == 1
    return forwarder.output[:, 0]
Code example #9
File: Engine.py Project: rwth-i6/returnn
  def forward_single(self, dataset, seq_idx, output_layer_name=None):
    """
    Forwards a single sequence.
    If you want to perform search, and get a number of hyps out, use :func:`search_single`.

    :param Dataset.Dataset dataset:
    :param int seq_idx:
    :param str|None output_layer_name: e.g. "output". if not set, will read from config "forward_output_layer"
    :return: numpy array, output in time major format (time,dim)
    :rtype: numpy.ndarray
    """
    from EngineBatch import Batch, BatchSetGenerator
    batch = Batch()
    batch.init_with_one_full_sequence(seq_idx=seq_idx, dataset=dataset)
    batch_generator = iter([batch])
    batches = BatchSetGenerator(dataset, generator=batch_generator)

    forwarder = ClassificationTaskThread(self.network, self.devices, dataset, batches)
    forwarder.join()
    assert forwarder.output.shape[1] == 1
    return forwarder.output[:, 0]
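A hedged usage sketch for forward_single (the helper below is hypothetical; it assumes an initialized Engine with network and devices, and a Dataset whose sequence order has been initialized):

import numpy

def framewise_argmax(engine, dataset, seq_idx=0):
    # forward_single returns the output time-major as (time, dim).
    posteriors = engine.forward_single(dataset, seq_idx)
    return numpy.argmax(posteriors, axis=1)  # best class index per frame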
Code example #10
File: Dataset.py Project: ZhangAustin/returnn
    def _generate_batches(self,
                          recurrent_net,
                          batch_size,
                          max_seqs=-1,
                          seq_drop=0.0,
                          max_seq_length=sys.maxsize,
                          used_data_keys=None):
        """
    :param bool recurrent_net: If True, the batch might have a batch seq dimension > 1.
      Otherwise, the batch seq dimension is always 1 and multiple seqs will be concatenated.
    :param int batch_size: Max number of frames in one batch.
    :param int max_seqs: Max number of seqs per batch.
    :param set(str)|None used_data_keys:
    """
        if batch_size == 0: batch_size = sys.maxsize
        assert batch_size > 0
        if max_seqs == -1: max_seqs = float('inf')
        assert max_seqs > 0
        assert seq_drop <= 1.0
        chunk_size = self.chunk_size
        chunk_step = self.chunk_step
        if not recurrent_net:
            if chunk_size != 0:
                print("Non-recurrent network, chunk size %i:%i ignored" %
                      (chunk_size, chunk_step),
                      file=log.v4)
                chunk_size = 0
        batch = Batch()
        for seq_idx, t_start, t_end in self._iterate_seqs(
                chunk_size=chunk_size,
                chunk_step=chunk_step,
                used_data_keys=used_data_keys):
            if recurrent_net:
                length = t_end - t_start
                if max_seq_length < 0 and length['classes'] > -max_seq_length:
                    continue
                elif max_seq_length > 0 and length.max_value() > max_seq_length:
                    continue
                if length.max_value() > batch_size:
                    print(
                        "warning: sequence length (%i) larger than limit (%i)"
                        % (length.max_value(), batch_size),
                        file=log.v4)
                if self.rnd_seq_drop.random() < seq_drop:
                    continue
                dt, ds = batch.try_sequence_as_slice(length)
                if ds > 1 and ((dt * ds).max_value() > batch_size
                               or ds > max_seqs):
                    yield batch
                    batch = Batch()
                batch.add_sequence_as_slice(seq_idx=seq_idx,
                                            seq_start_frame=t_start,
                                            length=length)
            else:  # Not recurrent.
                while t_start.max_value() < t_end.max_value():
                    length = t_end - t_start
                    num_frames = NumbersDict.min([
                        length, batch_size - batch.get_all_slices_num_frames()
                    ])
                    assert num_frames.max_value() > 0
                    batch.add_frames(seq_idx=seq_idx,
                                     seq_start_frame=t_start,
                                     length=num_frames)
                    if (batch.get_all_slices_num_frames() >= batch_size
                            or batch.get_num_seqs() > max_seqs):
                        yield batch
                        batch = Batch()
                    t_start += num_frames

        if batch.get_all_slices_num_frames() > 0:
            yield batch
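The non-recurrent branch above cuts a sequence into frame ranges so that no batch exceeds batch_size frames. A toy sketch of that while-loop with plain ints (hypothetical, not RETURNN code):

def split_frames(total_frames, batch_size):
    # Yield (start, length) ranges of at most batch_size frames each.
    out, t = [], 0
    while t < total_frames:
        n = min(total_frames - t, batch_size)
        out.append((t, n))
        t += n
    return out

# split_frames(10, 4) -> [(0, 4), (4, 4), (8, 2)]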
Code example #11
def generate_batch(seq_idx, dataset):
    batch = Batch()
    batch.add_frames(seq_idx=seq_idx,
                     seq_start_frame=0,
                     length=dataset.get_seq_length(seq_idx))
    return batch
Code example #12
File: test_EngineUtil.py Project: rwth-i6/returnn
def generate_batch(seq_idx, dataset):
  batch = Batch()
  batch.add_frames(seq_idx=seq_idx, seq_start_frame=0, length=dataset.get_seq_length(seq_idx))
  return batch
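Since example #12 comes from test_EngineUtil.py, a plausible test-style usage looks like the sketch below (the constructor arguments and the DummyDataset import from GeneratingDataset in the old flat-module layout are assumptions):

from GeneratingDataset import DummyDataset

dataset = DummyDataset(input_dim=2, output_dim=3, num_seqs=4, seq_len=5)
dataset.init_seq_order(epoch=1)
dataset.load_seqs(0, 1)  # generate_batch reads the seq length from the dataset
batch = generate_batch(seq_idx=0, dataset=dataset)
assert batch.get_num_seqs() == 1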
Code example #13
    def _generate_batches(self,
                          recurrent_net,
                          batch_size,
                          max_seqs=-1,
                          max_seq_length=sys.maxsize,
                          min_seq_length=0,
                          seq_drop=0.0,
                          max_total_num_seqs=-1,
                          used_data_keys=None):
        """
    :param bool recurrent_net: If True, the batch might have a batch seq dimension > 1.
      Otherwise, the batch seq dimension is always 1 and multiple seqs will be concatenated.
    :param int batch_size: Max number of frames in one batch.
    :param int max_seqs: Max number of seqs per batch.
    :param int max_total_num_seqs:
    :param int|dict[str,int]|NumbersDict max_seq_length:
    :param set(str)|None used_data_keys:
    """
        if batch_size == 0:
            batch_size = sys.maxsize
        assert batch_size > 0
        if max_seqs == -1:
            max_seqs = float('inf')
        if not max_seq_length:
            max_seq_length = sys.maxsize
        if isinstance(max_seq_length, int) and max_seq_length < 0:
            max_seq_length = {"classes": -max_seq_length}
        max_seq_length = NumbersDict(max_seq_length)
        min_seq_length = NumbersDict(min_seq_length)
        assert max_seqs > 0
        assert seq_drop <= 1.0
        if not max_total_num_seqs or max_total_num_seqs < 0:
            max_total_num_seqs = float("inf")
        chunk_size = self.chunk_size
        chunk_step = self.chunk_step
        if not recurrent_net:
            if chunk_size != 0:
                print("Non-recurrent network, chunk size %s:%s ignored" %
                      (chunk_size, chunk_step),
                      file=log.v4)
                chunk_size = 0
        batch = Batch()
        ctx_lr = self._get_context_window_left_right()
        total_num_seqs = 0
        last_seq_idx = -1
        for seq_idx, t_start, t_end in self.iterate_seqs(
                chunk_size=chunk_size,
                chunk_step=chunk_step,
                used_data_keys=used_data_keys):
            if seq_idx != last_seq_idx and total_num_seqs >= max_total_num_seqs:
                break
            if ctx_lr:
                t_start -= ctx_lr[0]
                t_end += ctx_lr[1]
            if recurrent_net:
                length = t_end - t_start
                if length.any_compare(max_seq_length, (lambda a, b: a > b)):
                    continue
                if length.any_compare(min_seq_length, (lambda a, b: a < b)):
                    continue
                if length.max_value() > batch_size:
                    print(
                        "warning: sequence length (%i) larger than limit (%i)"
                        % (length.max_value(), batch_size),
                        file=log.v4)
                if self.rnd_seq_drop.random() < seq_drop:
                    continue
                dt, ds = batch.try_sequence_as_slice(length)
                if ds > 1 and ((dt * ds).max_value() > batch_size
                               or ds > max_seqs):
                    yield batch
                    batch = Batch()
                batch.add_sequence_as_slice(seq_idx=seq_idx,
                                            seq_start_frame=t_start,
                                            length=length)
            else:  # Not recurrent.
                while t_start.max_value() < t_end.max_value():
                    length = t_end - t_start
                    num_frames = NumbersDict.min([
                        length, batch_size - batch.get_all_slices_num_frames()
                    ])
                    assert num_frames.max_value() > 0
                    batch.add_frames(seq_idx=seq_idx,
                                     seq_start_frame=t_start,
                                     length=num_frames)
                    if (batch.get_all_slices_num_frames() >= batch_size
                            or batch.get_num_seqs() > max_seqs):
                        yield batch
                        batch = Batch()
                    t_start += num_frames
            if seq_idx != last_seq_idx:
                last_seq_idx = seq_idx
                total_num_seqs += 1

        if batch.get_all_slices_num_frames() > 0:
            yield batch
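Example #13 above replaces the scalar length checks with NumbersDict.any_compare, which applies the comparison per data key and returns True if it holds for any key. A small sketch (assuming the old flat-module layout where NumbersDict lives in Util):

from Util import NumbersDict

length = NumbersDict({"data": 500, "classes": 40})
max_seq_length = NumbersDict({"data": 1000, "classes": 30})
# True: the "classes" length (40) exceeds its limit (30), even though
# the "data" length is within bounds, so the sequence would be skipped.
assert length.any_compare(max_seq_length, (lambda a, b: a > b))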
Code example #14
File: Dataset.py Project: atuxhe/returnn
  def _generate_batches(self, recurrent_net, batch_size, max_seqs=-1, batch_variance=0.0, max_seq_length=sys.maxsize):
    """
    :param bool recurrent_net: If True, the batch might have a batch seq dimension > 1.
      Otherwise, the batch seq dimension is always 1 and multiple seqs will be concatenated.
    :param int batch_size: Max number of frames in one batch.
    :param int max_seqs: Max number of seqs per batch.
    """
    if batch_size == 0: batch_size = sys.maxsize
    assert batch_size > 0
    ms = max_seqs
    bs = batch_size
    if max_seqs == -1: max_seqs = float('inf')
    assert max_seqs > 0
    chunk_size = self.chunk_size
    chunk_step = self.chunk_step
    if not recurrent_net:
      if chunk_size != 0:
        print("Non-recurrent network, chunk size %i:%i ignored" % (chunk_size, chunk_step), file=log.v4)
        chunk_size = 0

    assert batch_variance <= 1.0
    if batch_variance > 0.0:
      r = (1.0 - self.rnd_batch_variance.random() * batch_variance)
      if max_seqs > 0:
        max_seqs = max(int(r * ms), 1)
      #if batch_size > 0:
      #  batch_size = max(int(r * bs), 1)

    batch = Batch()
    for seq_idx, t_start, t_end in self._iterate_seqs(chunk_size=chunk_size, chunk_step=chunk_step):
      if recurrent_net:
        length = t_end - t_start
        if length.max_value() > max_seq_length:
          continue
        if length.max_value() > batch_size:
          print("warning: sequence length (%i) larger than limit (%i)" % (length.max_value(), batch_size), file=log.v4)
        dt, ds = batch.try_sequence_as_slice(length)
        if ds > 1 and ((dt * ds).max_value() > batch_size or ds > max_seqs):
          yield batch
          batch = Batch()
          if batch_variance > 0.0:
            r = (1.0 - self.rnd_batch_variance.random() * batch_variance)
            if max_seqs > 0:
              max_seqs = max(int(r * ms), 1)
            #if batch_size > 0:
            #  batch_size = max(int(r * bs), 1)
        batch.add_sequence_as_slice(seq_idx=seq_idx, seq_start_frame=t_start, length=length)
      else:  # Not recurrent.
        while t_start.max_value() < t_end.max_value():
          length = t_end - t_start
          num_frames = min(length.max_value(), batch_size - batch.get_all_slices_num_frames())
          assert num_frames > 0
          batch.add_frames(seq_idx=seq_idx, seq_start_frame=t_start, length=num_frames)
          if batch.get_all_slices_num_frames() >= batch_size or batch.get_num_seqs() > max_seqs:
            yield batch
            batch = Batch()
          t_start += num_frames

    if batch.get_all_slices_num_frames() > 0:
      yield batch
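The batch_variance logic in example #14 above randomly shrinks max_seqs for each new batch. In isolation (a toy sketch using the plain random module, not RETURNN code):

import random

def varied_max_seqs(base_max_seqs, batch_variance, rnd=random):
    # Draw r uniformly from (1 - batch_variance, 1] and scale down,
    # but never below one sequence per batch.
    r = 1.0 - rnd.random() * batch_variance
    return max(int(r * base_max_seqs), 1)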
Code example #15
File: Dataset.py Project: rwth-i6/returnn
  def _generate_batches(self, recurrent_net,
                        batch_size, max_seqs=-1, max_seq_length=sys.maxsize,
                        min_seq_length=0, pruning=0.0,
                        seq_drop=0.0, max_total_num_seqs=-1,
                        used_data_keys=None):
    """
    :param bool recurrent_net: If True, the batch might have a batch seq dimension > 1.
      Otherwise, the batch seq dimension is always 1 and multiple seqs will be concatenated.
    :param int|dict[str,int]|NumbersDict batch_size: Max number of frames in one batch.
    :param int max_seqs: Max number of seqs per batch.
    :param int max_total_num_seqs:
    :param int|dict[str,int]|NumbersDict max_seq_length:
    :param set(str)|None used_data_keys:
    """
    if not batch_size:
      batch_size = sys.maxsize
    batch_size = NumbersDict(batch_size)
    assert not batch_size.any_compare(NumbersDict(0), (lambda a, b: a <= b))
    if max_seqs == -1:
      max_seqs = float('inf')
    if not max_seq_length:
      max_seq_length = sys.maxsize
    if isinstance(max_seq_length, int) and max_seq_length < 0:
      max_seq_length = {"classes": -max_seq_length}
    max_seq_length = NumbersDict(max_seq_length)
    min_seq_length = NumbersDict(min_seq_length)
    assert max_seqs > 0
    assert seq_drop <= 1.0
    if not max_total_num_seqs or max_total_num_seqs < 0:
      max_total_num_seqs = float("inf")
    chunk_size = self.chunk_size
    chunk_step = self.chunk_step
    if not recurrent_net:
      if chunk_size != 0:
        print("Non-recurrent network, chunk size %s:%s ignored" % (chunk_size, chunk_step), file=log.v4)
        chunk_size = 0
    batch = Batch()
    ctx_lr = self._get_context_window_left_right()
    total_num_seqs = 0
    last_seq_idx = -1
    avg_weight = sum([v[0] for v in self.weights.values()]) / (len(self.weights.keys()) or 1)
    for idx in self.weights:
      self.weights[idx][1] = random() * avg_weight * pruning
      self.weights[idx][0] *= (1. + pruning)
    for seq_idx, t_start, t_end in self.iterate_seqs(
          chunk_size=chunk_size, chunk_step=chunk_step, used_data_keys=used_data_keys):
      if not self.sample(seq_idx):
        continue
      if total_num_seqs > max_total_num_seqs:
        break
      if ctx_lr:
        t_start -= ctx_lr[0]
        t_end += ctx_lr[1]
      if recurrent_net:
        length = t_end - t_start
        if length.any_compare(max_seq_length, (lambda a, b: a > b)):
          continue
        if length.any_compare(min_seq_length, (lambda a, b: a < b)):
          continue
        if length.any_compare(batch_size, (lambda a, b: a > b)):
          print("warning: sequence length (%r) larger than limit (%r)" % (length, batch_size), file=log.v4)
        if self.rnd_seq_drop.random() < seq_drop:
          continue
        dt, ds = batch.try_sequence_as_slice(length)
        if ds > 1 and ((dt * ds).any_compare(batch_size, (lambda a, b: a > b)) or ds > max_seqs):
          yield batch
          batch = Batch()
        batch.add_sequence_as_slice(seq_idx=seq_idx, seq_start_frame=t_start, length=length)
      else:  # Not recurrent.
        while t_start.max_value() < t_end.max_value():
          length = t_end - t_start
          num_frames = NumbersDict.min(
            [length, batch_size.copy_like(length) - batch.get_all_slices_num_frames().copy_like(length)])
          assert num_frames.max_value() > 0
          batch.add_frames(seq_idx=seq_idx, seq_start_frame=t_start, length=num_frames)
          if (batch.get_all_slices_num_frames().any_compare(batch_size, (lambda a, b: a >= b))
                  or batch.get_num_seqs() > max_seqs):
            yield batch
            batch = Batch()
          t_start += num_frames
      if seq_idx != last_seq_idx:
        last_seq_idx = seq_idx
        total_num_seqs += 1

    if batch.get_all_slices_num_frames().max_value() > 0:
      yield batch