コード例 #1
0
ファイル: SprintErrorSignals.py プロジェクト: chagge/returnn
  def get_batch_loss_and_error_signal(self, target, log_posteriors, seq_lengths):
    """
    Computes loss and error signal for every sequence of a batch by handing
    the sequences over to the Sprint instances, at most
    `self.max_num_instances` sequences at a time.

    :param str target: e.g. "classes". only used to look up the current seq index,
      not yet passed over to Sprint.
    :param numpy.ndarray log_posteriors: 3d (time,batch,label)
    :param numpy.ndarray seq_lengths: 1d (batch)
    :rtype (numpy.ndarray, numpy.ndarray)
    :returns (loss, error_signal).
      loss is a 1d float32 array (batch).
      error_signal is a float32 array with the same shape as log_posteriors;
      frames beyond the length of a sequence are left at zero.

    Note that the current seq index and the current seq tags are fetched
    from the Device module, i.e. this depends on global state.
    """
    assert seq_lengths.ndim == 1
    assert log_posteriors.ndim == 3
    n_batch = seq_lengths.shape[0]
    assert n_batch == log_posteriors.shape[1]

    import Device
    seq_index = Device.get_current_seq_index(target)  # (time,batch)
    assert seq_index.ndim == 2
    assert seq_index.shape[1] == n_batch
    # Summing the index over time must reproduce the given sequence lengths.
    assert (numpy.sum(seq_index, axis=0) == seq_lengths).all()
    seq_tags = Device.get_current_seq_tags()
    assert len(seq_tags) == n_batch

    batch_loss = numpy.zeros((n_batch,), dtype="float32")
    batch_error_signal = numpy.zeros_like(log_posteriors, dtype="float32")

    # Very simple parallelism: per chunk, first dispatch one sequence to each
    # instance, then collect the results in the same order.
    # Any form of multi-threading is avoided on purpose because it can be
    # problematic with Theano.
    # See: https://groups.google.com/forum/#!msg/theano-users/Pu4YKlZKwm4/eNcAegzaNeYJ
    num_slots = self.max_num_instances
    for chunk_start in range(0, n_batch, num_slots):
      # Batch indices of this chunk; the last chunk may be shorter.
      chunk = range(chunk_start, min(chunk_start + num_slots, n_batch))

      # Send phase: instance number `slot` gets sequence `seq_idx`.
      for slot, seq_idx in enumerate(chunk):
        n_frames = seq_lengths[seq_idx]
        worker = self._get_instance(slot)
        worker.get_loss_and_error_signal__send(
          seg_name=seq_tags[seq_idx],
          seg_len=n_frames,
          posteriors=log_posteriors[:n_frames, seq_idx])

      # Read phase: fetch the answers from the very same instances.
      for slot, seq_idx in enumerate(chunk):
        worker = self._get_instance(slot)
        seg_name, loss, error_signal = worker.get_loss_and_error_signal__read()
        # The reply must belong to the sequence which was sent to this instance.
        assert seg_name == seq_tags[seq_idx]
        batch_loss[seq_idx] = loss
        batch_error_signal[:seq_lengths[seq_idx], seq_idx] = error_signal

    return batch_loss, batch_error_signal