예제 #1
0
  def handle_cmd_get_loss_and_error_signal(self, seg_name, seg_len, posteriors):
    """
    Forward one segment's posteriors to the registered callback and return the
    resulting (loss, error_signal), while publishing per-segment state to the
    control thread through the shared condition variable ``self.cond``.

    :param str seg_name: seg name
    :param int seg_len: the segment length in frames
    :param numpy.ndarray posteriors: 2d (time,label) float array
    :return: (loss, error_signal); error_signal is float32 with the same shape
      as posteriors. loss comes straight from the callback.

    See SprintErrorSignals.SprintSubprocessInstance.get_loss_and_error_signal().
    """
    # NOTE: `long` only exists on Python 2; this file targets Py2 here.
    assert isinstance(seg_len, (int, long))
    assert seg_len > 0
    assert posteriors.ndim == 2
    # Time axis must match the announced segment length.
    assert posteriors.shape[0] == seg_len
    if Verbose: print("CRNN SprintControl[pid %i] PythonControl handle_cmd_get_loss_and_error_signal: name=%r, len=%r" % (os.getpid(), seg_name, seg_len))
    # Publish the new segment. All per-segment state is reset under the lock so
    # the control thread never observes a partially-updated segment.
    with self.cond:
      self.control_thread__have_new_seg = True
      self.control_thread__have_new_error_signal = False
      self.seg_name = seg_name
      self.seg_len = seg_len
      self.posteriors = posteriors
      self.error_signal = None
      self.loss = None
      self.asked_for_posteriors = False
      self.notified_for_segment = False
      self.cond.notifyAll()
    # The callback is expected to compute loss + error signal for this segment
    # (possibly interacting with the control thread meanwhile).
    loss, error_signal = self.callback("get_loss_and_error_signal", seg_name, seg_len, posteriors)
    assert error_signal.shape == posteriors.shape
    # Signal completion and drop our posteriors reference under the lock.
    with self.cond:
      self.control_thread__have_new_error_signal = True
      self.posteriors = None
      self.cond.notifyAll()
    # Free the (possibly large) posteriors buffer early.
    numpy_set_unused(posteriors)
    error_signal = error_signal.astype('float32', copy=False)
    return loss, error_signal
예제 #2
0
  def handle_cmd_get_loss_and_error_signal(self, seg_name, seg_len, posteriors):
    """
    Forward one segment's posteriors to the registered callback and return the
    resulting (loss, error_signal), while publishing per-segment state to the
    control thread through the shared condition variable ``self.cond``.

    :param str seg_name: seg name
    :param int seg_len: the segment length in frames
    :param numpy.ndarray posteriors: 2d (time,label) float array
    :return: (loss, error_signal); error_signal is float32 with the same shape
      as posteriors. loss comes straight from the callback.

    See SprintErrorSignals.SprintSubprocessInstance.get_loss_and_error_signal().
    """
    # NOTE: `long` only exists on Python 2; this file targets Py2 here.
    assert isinstance(seg_len, (int, long))
    assert seg_len > 0
    assert posteriors.ndim == 2
    # Time axis must match the announced segment length.
    assert posteriors.shape[0] == seg_len
    if Verbose: print("CRNN SprintControl[pid %i] PythonControl handle_cmd_get_loss_and_error_signal: name=%r, len=%r" % (os.getpid(), seg_name, seg_len))
    # Publish the new segment. All per-segment state is reset under the lock so
    # the control thread never observes a partially-updated segment.
    with self.cond:
      self.control_thread__have_new_seg = True
      self.control_thread__have_new_error_signal = False
      self.seg_name = seg_name
      self.seg_len = seg_len
      self.posteriors = posteriors
      self.error_signal = None
      self.loss = None
      self.asked_for_posteriors = False
      self.notified_for_segment = False
      self.cond.notifyAll()
    # The callback is expected to compute loss + error signal for this segment
    # (possibly interacting with the control thread meanwhile).
    loss, error_signal = self.callback("get_loss_and_error_signal", seg_name, seg_len, posteriors)
    assert error_signal.shape == posteriors.shape
    # Signal completion and drop our posteriors reference under the lock.
    with self.cond:
      self.control_thread__have_new_error_signal = True
      self.posteriors = None
      self.cond.notifyAll()
    # Free the (possibly large) posteriors buffer early.
    numpy_set_unused(posteriors)
    error_signal = error_signal.astype('float32', copy=False)
    return loss, error_signal
예제 #3
0
    def get_batch_loss_and_error_signal(self, log_posteriors, seq_lengths, tags=None):
        """
        Compute loss and error signal for a whole batch via the Sprint instances.

        :param numpy.ndarray log_posteriors: 3d (time,batch,label)
        :param numpy.ndarray seq_lengths: 1d (batch)
        :param list[str] tags: seq names, length = batch
        :rtype (numpy.ndarray, numpy.ndarray)
        :returns (loss, error_signal). error_signal has the same shape as posteriors.
        loss is a 1d-array (batch).

        Note that this accesses some global references, like global current seg info,
        via the current Device instance.
        Thus this is expected to be run from the Device host proc,
          inside from SprintErrorSigOp.perform.
        This also expects that we don't have chunked seqs.
        """
        assert log_posteriors.ndim == 3
        assert seq_lengths.ndim == 1
        n_batch = seq_lengths.shape[0]
        assert log_posteriors.shape[1] == n_batch

        if tags is None:
            # Fall back to the seq tags registered on the current Device.
            import Device
            assert Device.is_device_host_proc()
            tags = Device.get_current_seq_tags()
        assert len(tags) == n_batch

        batch_loss = numpy.zeros((n_batch,), dtype="float32")
        batch_error_signal = numpy.zeros_like(log_posteriors, dtype="float32")
        # Round-robin over the instances, at most max_num_instances in flight.
        # Deliberately no multi-threading -- that can be problematic with Theano.
        # See: https://groups.google.com/forum/#!msg/theano-users/Pu4YKlZKwm4/eNcAegzaNeYJ
        for chunk_start in range(0, n_batch, self.max_num_instances):
            chunk = range(chunk_start, min(chunk_start + self.max_num_instances, n_batch))
            # Phase 1: dispatch one segment per instance.
            for inst_idx, b in enumerate(chunk):
                seq_len = seq_lengths[b]
                self._get_instance(inst_idx).get_loss_and_error_signal__send(
                    seg_name=tags[b],
                    seg_len=seq_len,
                    log_posteriors=log_posteriors[:seq_len, b])
            # Phase 2: collect the results back, same order as dispatched.
            for inst_idx, b in enumerate(chunk):
                seg_name, loss, error_signal = \
                    self._get_instance(inst_idx).get_loss_and_error_signal__read()
                assert seg_name == tags[b]
                batch_loss[b] = loss
                batch_error_signal[:seq_lengths[b], b] = error_signal
                numpy_set_unused(error_signal)
        return batch_loss, batch_error_signal
예제 #4
0
 def run(self):
   """
   Worker loop: for each assigned batch index, request the loss / error signal
   for that sequence from the Sprint instance and write the results into the
   shared output arrays. Any exception is stored in self.exception instead of
   propagating out of the thread.
   """
   try:
     for batch_idx in self.batch_idxs:
       seq_len = self.seq_lengths[batch_idx]
       self.instance.get_loss_and_error_signal__send(
         seg_name=self.tags[batch_idx], seg_len=seq_len,
         log_posteriors=self.log_posteriors[:seq_len, batch_idx])
       got_name, got_loss, got_error_signal = self.instance.get_loss_and_error_signal__read()
       assert got_name == self.tags[batch_idx]
       self.batch_loss[batch_idx] = got_loss
       self.batch_error_signal[:seq_len, batch_idx] = got_error_signal
       numpy_set_unused(got_error_signal)
   except Exception as exc:
     self.exception = exc
예제 #5
0
 def run(self):
     """
     Thread body. Processes every batch index assigned to this reader: sends
     the sequence's log posteriors to the Sprint instance, reads back the
     (loss, error_signal) answer and stores it into the shared batch arrays.
     Exceptions are captured in self.exception rather than escaping the thread.
     """
     def process(b):
         # Handle one batch entry: request, read back, store.
         length = self.seq_lengths[b]
         self.instance.get_loss_and_error_signal__send(
             seg_name=self.tags[b],
             seg_len=length,
             log_posteriors=self.log_posteriors[:length, b])
         name, loss, error_signal = self.instance.get_loss_and_error_signal__read()
         assert name == self.tags[b]
         self.batch_loss[b] = loss
         self.batch_error_signal[:length, b] = error_signal
         numpy_set_unused(error_signal)

     try:
         for b in self.batch_idxs:
             process(b)
     except Exception as exc:
         self.exception = exc
예제 #6
0
  def get_batch_loss_and_error_signal(self, log_posteriors, seq_lengths, tags=None):
    """
    Compute loss and error signal for a whole batch via the Sprint instances.

    :param numpy.ndarray log_posteriors: 3d (time,batch,label)
    :param numpy.ndarray seq_lengths: 1d (batch)
    :param list[str] tags: seq names, length = batch. If None (the old,
      backward-compatible behavior), they are taken from the current Device.
    :rtype (numpy.ndarray, numpy.ndarray)
    :returns (loss, error_signal). error_signal has the same shape as posteriors.
    loss is a 1d-array (batch).

    Note that this accesses some global references, like global current seg info,
    via the current Device instance.
    Thus this is expected to be run from the Device host proc,
      inside from SprintErrorSigOp.perform.
    This also expects that we don't have chunked seqs.
    """
    assert seq_lengths.ndim == 1
    assert log_posteriors.ndim == 3
    n_batch = seq_lengths.shape[0]
    assert n_batch == log_posteriors.shape[1]

    if tags is None:
      # Generalized (consistent with the other variants of this method):
      # only require the Device host proc when we must look the tags up globally.
      import Device
      assert Device.is_device_host_proc()
      tags = Device.get_current_seq_tags()
    assert len(tags) == n_batch

    batch_loss = numpy.zeros((n_batch,), dtype="float32")
    batch_error_signal = numpy.zeros_like(log_posteriors, dtype="float32")
    # Very simple parallelism. We must avoid any form of multi-threading
    # because this can be problematic with Theano.
    # See: https://groups.google.com/forum/#!msg/theano-users/Pu4YKlZKwm4/eNcAegzaNeYJ
    # We also try to keep it simple here.
    for bb in range(0, n_batch, self.max_num_instances):
      # Phase 1: dispatch up to max_num_instances segments, one per instance.
      for i in range(self.max_num_instances):
        b = bb + i
        if b >= n_batch:
          break
        instance = self._get_instance(i)
        instance.get_loss_and_error_signal__send(
          seg_name=tags[b], seg_len=seq_lengths[b], log_posteriors=log_posteriors[:seq_lengths[b], b])
      # Phase 2: collect the answers in the same order.
      for i in range(self.max_num_instances):
        b = bb + i
        if b >= n_batch:
          break
        instance = self._get_instance(i)
        seg_name, loss, error_signal = instance.get_loss_and_error_signal__read()
        assert seg_name == tags[b]
        batch_loss[b] = loss
        batch_error_signal[:seq_lengths[b], b] = error_signal
        numpy_set_unused(error_signal)
    return batch_loss, batch_error_signal
예제 #7
0
  def get_batch_loss_and_error_signal(self, log_posteriors, seq_lengths, tags=None):
    """
    Compute loss and error signal for a whole batch via the Sprint instances,
    using reader threads when the backend allows it.

    :param numpy.ndarray log_posteriors: 3d (time,batch,label)
    :param numpy.ndarray seq_lengths: 1d (batch)
    :param list[str] tags: seq names, length = batch
    :rtype (numpy.ndarray, numpy.ndarray)
    :returns (loss, error_signal). error_signal has the same shape as posteriors.
    loss is a 1d-array (batch).

    Note that this accesses some global references, like global current seg info,
    via the current Device instance.
    Thus this is expected to be run from the Device host proc,
      inside from SprintErrorSigOp.perform.
    This also expects that we don't have chunked seqs.
    """
    assert seq_lengths.ndim == 1
    assert log_posteriors.ndim == 3
    n_batch = seq_lengths.shape[0]
    assert n_batch == log_posteriors.shape[1]

    if tags is None:
      import Device
      assert Device.is_device_host_proc()
      tags = Device.get_current_seq_tags()
    assert len(tags) == n_batch

    batch_loss = numpy.zeros((n_batch,), dtype="float32")
    batch_error_signal = numpy.zeros_like(log_posteriors, dtype="float32")

    # Greedy solution to the scheduling problem: always hand the longest
    # remaining seq to the currently least-loaded instance (LPT heuristic).
    sorted_length = sorted(enumerate(seq_lengths), key=lambda x: x[1], reverse=True)
    jobs = [[] for _ in range(self.max_num_instances)]
    joblen = [0] * self.max_num_instances
    for seq_idx, seq_len in sorted_length:
      j = min(enumerate(joblen), key=lambda x: x[1])[0]
      jobs[j].append(seq_idx)
      joblen[j] += seq_len

    if not BackendEngine.is_theano_selected() and self.max_num_instances > 1:
      # NOTE(review): there is no explicit thread.start() here -- this assumes
      # ReaderThread starts itself in its constructor; confirm against its definition.
      threads = [
        ReaderThread(
          self._get_instance(i), i, jobs[i], tags, seq_lengths, log_posteriors,
          batch_loss, batch_error_signal)
        for i in range(self.max_num_instances)]
      for thread in threads:
        thread.join()
        if thread.exception:
          raise thread.exception
    else:
      # Very simple parallelism. We must avoid any form of multi-threading
      # because this can be problematic with Theano.
      # See: https://groups.google.com/forum/#!msg/theano-users/Pu4YKlZKwm4/eNcAegzaNeYJ
      # We also try to keep it simple here.
      for bb in range(0, n_batch, self.max_num_instances):
        for i in range(self.max_num_instances):
          b = bb + i
          if b >= n_batch:
            break
          instance = self._get_instance(i)
          instance.get_loss_and_error_signal__send(
            seg_name=tags[b], seg_len=seq_lengths[b], log_posteriors=log_posteriors[:seq_lengths[b], b])
        for i in range(self.max_num_instances):
          b = bb + i
          if b >= n_batch:
            break
          instance = self._get_instance(i)
          seg_name, loss, error_signal = instance.get_loss_and_error_signal__read()
          assert seg_name == tags[b]
          batch_loss[b] = loss
          batch_error_signal[:seq_lengths[b], b] = error_signal
          numpy_set_unused(error_signal)
    return batch_loss, batch_error_signal
예제 #8
0
    def get_batch_loss_and_error_signal(self,
                                        log_posteriors,
                                        seq_lengths,
                                        tags=None):
        """
        Compute loss and error signal for a whole batch via the Sprint
        instances, using reader threads when the backend allows it.

        :param numpy.ndarray log_posteriors: 3d (time,batch,label)
        :param numpy.ndarray seq_lengths: 1d (batch)
        :param list[str] tags: seq names, length = batch
        :rtype (numpy.ndarray, numpy.ndarray)
        :returns (loss, error_signal). error_signal has the same shape as
        posteriors. loss is a 1d-array (batch).

        Note that this accesses some global references, like global current
        seg info, via the current Device instance.
        Thus this is expected to be run from the Device host proc,
        inside from SprintErrorSigOp.perform.
        This also expects that we don't have chunked seqs.
        """
        assert seq_lengths.ndim == 1
        assert log_posteriors.ndim == 3
        n_batch = seq_lengths.shape[0]
        assert n_batch == log_posteriors.shape[1]

        if tags is None:
            import Device
            assert Device.is_device_host_proc()
            tags = Device.get_current_seq_tags()
        assert len(tags) == n_batch

        batch_loss = numpy.zeros((n_batch, ), dtype="float32")
        batch_error_signal = numpy.zeros_like(log_posteriors, dtype="float32")

        # Greedy solution to the scheduling problem: always hand the longest
        # remaining seq to the least-loaded instance (LPT heuristic).
        sorted_length = sorted(enumerate(seq_lengths),
                               key=lambda x: x[1],
                               reverse=True)
        jobs = [[] for _ in range(self.max_num_instances)]
        joblen = [0] * self.max_num_instances
        for seq_idx, seq_len in sorted_length:
            j = min(enumerate(joblen), key=lambda x: x[1])[0]
            jobs[j].append(seq_idx)
            joblen[j] += seq_len

        if not BackendEngine.is_theano_selected(
        ) and self.max_num_instances > 1:
            # NOTE(review): no explicit thread.start() here -- assumes
            # ReaderThread starts itself in its constructor; confirm.
            threads = [
                ReaderThread(self._get_instance(i), i, jobs[i], tags,
                             seq_lengths, log_posteriors, batch_loss,
                             batch_error_signal)
                for i in range(self.max_num_instances)
            ]
            for thread in threads:
                thread.join()
                if thread.exception:
                    raise thread.exception
        else:
            # Very simple parallelism. We must avoid any form of multi-threading
            # because this can be problematic with Theano.
            # See: https://groups.google.com/forum/#!msg/theano-users/Pu4YKlZKwm4/eNcAegzaNeYJ
            # We also try to keep it simple here.
            for bb in range(0, n_batch, self.max_num_instances):
                for i in range(self.max_num_instances):
                    b = bb + i
                    if b >= n_batch:
                        break
                    instance = self._get_instance(i)
                    instance.get_loss_and_error_signal__send(
                        seg_name=tags[b],
                        seg_len=seq_lengths[b],
                        log_posteriors=log_posteriors[:seq_lengths[b], b])
                for i in range(self.max_num_instances):
                    b = bb + i
                    if b >= n_batch:
                        break
                    instance = self._get_instance(i)
                    seg_name, loss, error_signal = instance.get_loss_and_error_signal__read(
                    )
                    assert seg_name == tags[b]
                    batch_loss[b] = loss
                    batch_error_signal[:seq_lengths[b], b] = error_signal
                    numpy_set_unused(error_signal)
        return batch_loss, batch_error_signal