예제 #1
0
  def DecodeWithTheta(self, theta, input_batch):
    """Builds the beam-search decoding graph and returns decoder metrics.

    Args:
      theta: Layer variables. The whole structure is handed to the
        frontend/encoder and `theta.decoder` is handed to beam search.
      input_batch: Input batch; `input_batch.src` feeds the encoder and the
        full batch is forwarded to the metric computation.

    Returns:
      The result of `self._ComputeDecoderMetrics(decoder_outs, input_batch)`.
    """
    params = self.params
    with tf.name_scope('decode'):
      with tf.name_scope(params.name):
        with tf.name_scope('encoder'):
          enc_out = self._FrontendAndEncoderFProp(theta, input_batch.src)
        with tf.name_scope('beam_search'):
          beam_out = self.decoder.BeamSearchDecodeWithTheta(
              theta.decoder, enc_out)

        if not py_utils.use_tpu():
          return self._ComputeDecoderMetrics(beam_out, input_batch)

        # Metric computation contains arbitrary execution that may not run
        # on TPU, so it is dispatched to the TPU host instead.
        return py_utils.RunOnTpuHost(self._ComputeDecoderMetrics, beam_out,
                                     input_batch)
예제 #2
0
    def _CalculateErrorRates(self, dec_outs_dict, input_batch):
        """Computes character- and word-level error statistics for a batch.

        A leftover debugging short-circuit (`return {'stuff': ...}`) used to
        exit before any of the computation below ran, so no error rate was
        ever produced. It has been removed so the metrics are returned.

        Args:
          dec_outs_dict: Decoder outputs; only `.transcripts` (the decoded
            hypothesis strings) is read here.
          input_batch: Input batch; reads `.tgt.ids`, `.tgt.labels`,
            `.tgt.weights`, `.tgt.paddings` and, when not on TPU,
            `.sample_ids`.

        Returns:
          A dict with the target tensors (`target_ids`, `target_labels`,
          `target_weights`, `target_paddings`), the reference and decoded
          transcripts (`target_transcripts`, `decoded_transcripts`), the
          batch-level rates (`wer`, `cer`) and the raw counts they are built
          from (`num_wrong_words`, `num_ref_words`, `num_wrong_chars`,
          `num_ref_chars`). When not running on TPU, `utt_id` is included too.
        """
        # Reference transcripts are reconstructed from the label ids.
        # presumably LengthsFromBitMask(..., 1) yields per-utterance lengths
        # along the time axis -- TODO confirm against py_utils.
        gt_seq_lens = py_utils.LengthsFromBitMask(input_batch.tgt.paddings, 1)
        gt_transcripts = py_utils.RunOnTpuHost(
            self.input_generator.IdsToStrings, input_batch.tgt.labels,
            gt_seq_lens)

        # Token (character) error rate: both sides are split with an empty
        # separator and compared with an un-normalized edit distance, so
        # `char_dist` holds raw edit counts rather than ratios.
        char_dist = tf.edit_distance(tf.string_split(dec_outs_dict.transcripts,
                                                     sep=''),
                                     tf.string_split(gt_transcripts, sep=''),
                                     normalize=False)

        ref_chars = tf.strings.length(gt_transcripts)
        num_wrong_chars = tf.reduce_sum(char_dist)
        num_ref_chars = tf.cast(tf.reduce_sum(ref_chars), tf.float32)
        # NOTE(review): no guard for a batch whose references are all empty;
        # the division below is then by zero.
        cer = num_wrong_chars / num_ref_chars

        # Word error rate: column 0 is summed as the error count and column 1
        # as the reference word count.
        word_dist = decoder_utils.ComputeWer(dec_outs_dict.transcripts,
                                             gt_transcripts)  # (B, 2)
        num_wrong_words = tf.reduce_sum(word_dist[:, 0])
        num_ref_words = tf.reduce_sum(word_dist[:, 1])
        # NOTE(review): same zero-denominator caveat as `cer`.
        wer = num_wrong_words / num_ref_words

        ret_dict = {
            'target_ids': input_batch.tgt.ids,
            'target_labels': input_batch.tgt.labels,
            'target_weights': input_batch.tgt.weights,
            'target_paddings': input_batch.tgt.paddings,
            'target_transcripts': gt_transcripts,
            'decoded_transcripts': dec_outs_dict.transcripts,
            'wer': wer,
            'cer': cer,
            'num_wrong_words': num_wrong_words,
            'num_ref_words': num_ref_words,
            'num_wrong_chars': num_wrong_chars,
            'num_ref_chars': num_ref_chars,
        }
        # Utterance ids are only attached when not running on TPU.
        if not py_utils.use_tpu():
            ret_dict['utt_id'] = input_batch.sample_ids

        return ret_dict
예제 #3
0
    def DecodeWithTheta(self, theta, input_batch):
        """Builds the CTC decoding graph for `input_batch`.

        After running the frontend/encoder, one of two results is produced:

        * If `params.inference_compute_only_log_softmax` is set, the encoder
          log-softmax outputs, their lengths, utterance/document identifiers
          and the CTC transcripts are returned. The global step is
          incremented as a control dependency of the log-softmax computation.
        * Otherwise the CTC decoder output is passed to
          `self._CalculateErrorRates` via `py_utils.RunOnTpuHost`.

        Args:
          theta: Layer variables, forwarded to `_FrontendAndEncoderFProp`.
          input_batch: Input batch. `input_batch.src` feeds the encoder; in
            log-softmax mode `sample_ids`, `audio_document_ids` and
            `num_utterances_in_audio_document` are read as well.

        Returns:
          A dict of inference outputs in log-softmax mode, otherwise the
          result of `_CalculateErrorRates`.
        """
        params = self.params
        with tf.name_scope('decode'):
            with tf.name_scope(params.name):
                with tf.name_scope('encoder'):
                    enc_out = self._FrontendAndEncoderFProp(
                        theta, input_batch.src)

                if not params.inference_compute_only_log_softmax:
                    with tf.name_scope('decoder'):
                        ctc_out = self._DecodeCTC(enc_out)
                    return py_utils.RunOnTpuHost(self._CalculateErrorRates,
                                                 ctc_out, input_batch)

                # Evaluating the log-probabilities also bumps the global step.
                step = tf.train.get_global_step()
                bump_step = tf.assign(step, step + 1)
                with tf.control_dependencies([bump_step]):
                    # The original note states the encoder output is
                    # [T, B, F]; the transpose swaps the first two axes.
                    log_probs = tf.nn.log_softmax(enc_out.encoded, axis=2)
                    log_probs = tf.transpose(log_probs, perm=(1, 0, 2))

                with tf.name_scope('decoder'):
                    ctc_out = self._DecodeCTC(enc_out)

                log_prob_lens = py_utils.LengthsFromBitMask(
                    enc_out.padding, 0)
                return {
                    'log_probabilities': log_probs,
                    'log_probabilities_lengths': log_prob_lens,
                    'int64_uttid': input_batch.sample_ids,
                    'int64_audio_document_id': input_batch.audio_document_ids,
                    'num_utterances_in_audio_document':
                        input_batch.num_utterances_in_audio_document,
                    'transcripts': ctc_out.transcripts,
                }