def get_batch_loss_and_error_signal(self, target, log_posteriors, seq_lengths):
    """
    Computes the loss and the error signal for every sequence of the batch,
    by handing the sequences over to the instances from ``self._get_instance``.

    :param str target: e.g. "classes". Not yet passed over to Sprint;
      here it is only used to look up the current sequence index.
    :param numpy.ndarray log_posteriors: 3d (time,batch,label)
    :param numpy.ndarray seq_lengths: 1d (batch)
    :rtype: (numpy.ndarray, numpy.ndarray)
    :returns: (loss, error_signal).
      loss is a 1d-array (batch) of dtype float32.
      error_signal has the same shape as log_posteriors, dtype float32;
      frames at or beyond the length of a sequence are left at zero.

    Note that this accesses some global references (via the Device module),
    like the global current sequence index and sequence tags.
    """
    # Sanity checks on the inputs.
    assert seq_lengths.ndim == 1
    assert log_posteriors.ndim == 3
    num_seqs = seq_lengths.shape[0]
    assert num_seqs == log_posteriors.shape[1]

    # Cross-check the inputs against the global info about the current batch.
    import Device
    index = Device.get_current_seq_index(target)  # (time,batch)
    assert index.ndim == 2
    assert index.shape[1] == num_seqs
    frames_per_seq = numpy.sum(index, axis=0)
    assert (frames_per_seq == seq_lengths).all()
    tags = Device.get_current_seq_tags()
    assert len(tags) == num_seqs

    batch_loss = numpy.zeros(shape=(num_seqs,), dtype="float32")
    batch_error_signal = numpy.zeros_like(log_posteriors, dtype="float32")

    # Very simple parallelism. We must avoid any form of multi-threading
    # because this can be problematic with Theano.
    # See: https://groups.google.com/forum/#!msg/theano-users/Pu4YKlZKwm4/eNcAegzaNeYJ
    # We also try to keep it simple here:
    # The batch is processed in chunks of at most `max_parallel` sequences.
    # For each chunk, we first send out all requests (one per instance),
    # and only afterwards we collect all the results, in the same order.
    max_parallel = self.max_num_instances
    for chunk_start in range(0, num_seqs, max_parallel):
        chunk_end = min(chunk_start + max_parallel, num_seqs)
        chunk = range(chunk_start, chunk_end)

        # Phase 1: dispatch. Slot `slot` of the chunk goes to instance `slot`.
        for slot, seq_idx in enumerate(chunk):
            seq_len = seq_lengths[seq_idx]
            worker = self._get_instance(slot)
            worker.get_loss_and_error_signal__send(
                seg_name=tags[seq_idx],
                seg_len=seq_len,
                posteriors=log_posteriors[:seq_len, seq_idx])

        # Phase 2: collect. Read back from the same instances.
        for slot, seq_idx in enumerate(chunk):
            worker = self._get_instance(slot)
            seg_name, loss, error_signal = worker.get_loss_and_error_signal__read()
            # The answer must belong to the sequence we sent to this instance.
            assert seg_name == tags[seq_idx]
            batch_loss[seq_idx] = loss
            # Only the valid frames are written; the padding stays zero.
            batch_error_signal[:seq_lengths[seq_idx], seq_idx] = error_signal

    return batch_loss, batch_error_signal