def DecodeWithTheta(self, theta, input_batch):
  """Constructs the inference graph.

  Args:
    theta: A `.NestedMap` of variable values for this task.
    input_batch: An input batch with a `src` field for the encoder and the
      fields consumed by `_ComputeDecoderMetrics`.

  Returns:
    The decoder metrics produced by `_ComputeDecoderMetrics`.
  """
  p = self.params
  with tf.name_scope('decode'), tf.name_scope(p.name):
    with tf.name_scope('encoder'):
      encoder_outputs = self._FrontendAndEncoderFProp(theta, input_batch.src)
    with tf.name_scope('beam_search'):
      decoder_outs = self.decoder.BeamSearchDecodeWithTheta(
          theta.decoder, encoder_outputs)
    if not py_utils.use_tpu():
      return self._ComputeDecoderMetrics(decoder_outs, input_batch)
    # Metric computation contains arbitrary execution that may not run on
    # TPU, so route it to the TPU host instead.
    return py_utils.RunOnTpuHost(self._ComputeDecoderMetrics, decoder_outs,
                                 input_batch)
def _CalculateErrorRates(self, dec_outs_dict, input_batch):
  """Computes character and word error rates for a decoded batch.

  Fix: removed a leftover debug statement
  (`return {'stuff': dec_outs_dict.sparse_ids.values}`) that returned
  immediately and made the entire metric computation below dead code.

  Args:
    dec_outs_dict: Decoder outputs; must provide `transcripts` (decoded
      hypothesis strings, one per utterance).
    input_batch: The input batch; `tgt` must provide `ids`, `labels`,
      `weights`, and `paddings`; `sample_ids` is read off-TPU.

  Returns:
    A dict with ground-truth/decoded transcripts, WER/CER scalars, and the
    raw error/reference counts; includes `utt_id` when not running on TPU.
  """
  # Ground-truth sequence lengths from the padding bitmask, then convert
  # label ids back to strings on the TPU host (string ops are host-only).
  gt_seq_lens = py_utils.LengthsFromBitMask(input_batch.tgt.paddings, 1)
  gt_transcripts = py_utils.RunOnTpuHost(self.input_generator.IdsToStrings,
                                         input_batch.tgt.labels, gt_seq_lens)

  # Token (character) error rate: unnormalized edit distance over
  # per-character splits, normalized by total reference characters.
  char_dist = tf.edit_distance(
      tf.string_split(dec_outs_dict.transcripts, sep=''),
      tf.string_split(gt_transcripts, sep=''),
      normalize=False)
  ref_chars = tf.strings.length(gt_transcripts)
  num_wrong_chars = tf.reduce_sum(char_dist)
  num_ref_chars = tf.cast(tf.reduce_sum(ref_chars), tf.float32)
  cer = num_wrong_chars / num_ref_chars

  # Word error rate: ComputeWer returns a (B, 2) tensor of
  # (num_wrong_words, num_ref_words) per utterance.
  word_dist = decoder_utils.ComputeWer(dec_outs_dict.transcripts,
                                       gt_transcripts)  # (B, 2)
  num_wrong_words = tf.reduce_sum(word_dist[:, 0])
  num_ref_words = tf.reduce_sum(word_dist[:, 1])
  wer = num_wrong_words / num_ref_words

  ret_dict = {
      'target_ids': input_batch.tgt.ids,
      'target_labels': input_batch.tgt.labels,
      'target_weights': input_batch.tgt.weights,
      'target_paddings': input_batch.tgt.paddings,
      'target_transcripts': gt_transcripts,
      'decoded_transcripts': dec_outs_dict.transcripts,
      'wer': wer,
      'cer': cer,
      'num_wrong_words': num_wrong_words,
      'num_ref_words': num_ref_words,
      'num_wrong_chars': num_wrong_chars,
      'num_ref_chars': num_ref_chars
  }
  if not py_utils.use_tpu():
    # Utterance ids are host-side strings; only available off-TPU.
    ret_dict['utt_id'] = input_batch.sample_ids
  return ret_dict
def DecodeWithTheta(self, theta, input_batch):
  """Constructs the inference graph (CTC decode variant).

  Fix: removed a leftover commented-out debugger invocation
  (`from IPython import embed; embed()`). Graph-construction order is
  otherwise preserved exactly.

  Args:
    theta: A `.NestedMap` of variable values for this task.
    input_batch: An input batch with a `src` field; when
      `p.inference_compute_only_log_softmax` is set, also read are
      `sample_ids`, `audio_document_ids`, and
      `num_utterances_in_audio_document`.

  Returns:
    When `p.inference_compute_only_log_softmax` is true, a dict of
    per-frame log-probabilities, their lengths, utterance/document ids,
    and greedy CTC transcripts; otherwise, the error-rate metrics from
    `_CalculateErrorRates` computed on the TPU host.
  """
  p = self.params
  with tf.name_scope('decode'), tf.name_scope(p.name):
    with tf.name_scope('encoder'):
      encoder_outputs = self._FrontendAndEncoderFProp(theta, input_batch.src)
    if p.inference_compute_only_log_softmax:
      # Bump the global step so inference progress is observable; the
      # control dependency forces the increment to run with the fetch.
      global_step = tf.train.get_global_step()
      increment_global_step = tf.assign(global_step, global_step + 1)
      with tf.control_dependencies([increment_global_step]):
        # encoder_outputs.encoded is time-major [T, B, F]; transpose to
        # batch-major [B, T, F] after the per-frame log-softmax.
        log_probabilities = tf.transpose(
            tf.nn.log_softmax(encoder_outputs.encoded, axis=2),
            perm=(1, 0, 2))
        with tf.name_scope('decoder'):
          decoder_outs = self._DecodeCTC(encoder_outputs)
        return {
            'log_probabilities': log_probabilities,
            'log_probabilities_lengths':
                py_utils.LengthsFromBitMask(encoder_outputs.padding, 0),
            'int64_uttid': input_batch.sample_ids,
            'int64_audio_document_id': input_batch.audio_document_ids,
            'num_utterances_in_audio_document':
                input_batch.num_utterances_in_audio_document,
            'transcripts': decoder_outs.transcripts,
        }
    with tf.name_scope('decoder'):
      decoder_outs = self._DecodeCTC(encoder_outputs)
    # Error-rate computation uses string ops, which must run on the host.
    decoder_metrics = py_utils.RunOnTpuHost(self._CalculateErrorRates,
                                            decoder_outs, input_batch)
    return decoder_metrics