def ComputeLoss(self, theta, predictions, input_batch):
  """Computes the CTC loss for one batch.

  Args:
    theta: Not read by this method; kept for the ComputeLoss interface.
    predictions: NestedMap with `encoded` logits and their `padding`. The
      logits are handed to tf.nn.ctc_loss with logits_time_major=True.
    input_batch: NestedMap whose `tgt` provides `labels` and `paddings`.

  Returns:
    A pair (metrics, per_sequence). `metrics` maps 'loss' to the tuple
    (mean CTC loss over the batch, 1.0); `per_sequence` maps 'loss' to the
    unreduced CTC loss of every sequence.
  """
  blank = self.params.blank_index
  # Target lengths are reduced over axis 1 of the target paddings, logit
  # lengths over axis 0 of the (time-major) encoder padding.
  label_lengths = py_utils.LengthsFromBitMask(input_batch.tgt.paddings, 1)
  logit_lengths = py_utils.LengthsFromBitMask(predictions.padding, 0)
  per_seq_ctc = tf.nn.ctc_loss(
      input_batch.tgt.labels,
      predictions.encoded,
      label_lengths,
      logit_lengths,
      logits_time_major=True,
      blank_index=blank)
  # per_seq_ctc holds one loss value per sequence: shape (B).
  metrics = {'loss': (tf.reduce_mean(per_seq_ctc), 1.0)}
  return metrics, {'loss': per_seq_ctc}
def DecodeWithTheta(self, theta, input_batch):
  """Constructs the inference graph.

  Runs the frontend and encoder on `input_batch.src`. There are two modes:

  * `params.inference_compute_only_log_softmax` is set: the global step is
    incremented and batch-major log-probabilities, their lengths and the
    utterance/document ids from `input_batch` are returned.
  * Otherwise: the encoder output is CTC-decoded and the dict built by
    `_CalculateErrorRates` is returned.
  """
  p = self.params
  with tf.name_scope('decode'), tf.name_scope(p.name):
    with tf.name_scope('encoder'):
      encoder_outputs = self._FrontendAndEncoderFProp(theta, input_batch.src)

    if not p.inference_compute_only_log_softmax:
      with tf.name_scope('decoder'):
        decoder_outs = self._DecodeCTC(encoder_outputs)
        return self._CalculateErrorRates(decoder_outs, input_batch)

    step = tf.train.get_global_step()
    bump_step = tf.assign(step, step + 1)
    # Ops created under this control dependency run only after the step
    # increment, so fetching the log-probabilities also advances the step.
    with tf.control_dependencies([bump_step]):
      # encoder_outputs.encoded is [T, B, F]; perm (1, 0, 2) makes it [B, T, F].
      log_probs = tf.transpose(
          tf.nn.log_softmax(encoder_outputs.encoded, axis=2),
          perm=(1, 0, 2))
    lengths = py_utils.LengthsFromBitMask(encoder_outputs.padding, 0)
    return {
        'log_probabilities': log_probs,
        'log_probabilities_lengths': lengths,
        'int64_uttid': input_batch.sample_ids,
        'int64_audio_document_id': input_batch.audio_document_ids,
        'num_utterances_in_audio_document':
            input_batch.num_utterances_in_audio_document,
    }
def cpu_tf_graph_decode(tok_logits, padding, blank_index, input_generator):
  """CTC beam-search decodes logits and converts the best paths to strings.

  Args:
    tok_logits: Logits for tf.nn.ctc_beam_search_decoder, laid out (T, B, F)
      (time-major).
    padding: Padding mask; per-sequence lengths are
      py_utils.LengthsFromBitMask(padding, 0).
    blank_index: Index of the CTC blank class. tf.nn.ctc_beam_search_decoder
      always uses the last class (num_classes - 1) as the blank, so this must
      equal num_classes - 1.
    input_generator: Object whose IdsToStrings(ids, lengths) turns rows of
      ids into strings.

  Returns:
    NestedMap with `sparse_ids` (the decoder's top-path SparseTensor) and
    `transcripts` (decoded strings with runs of <unk>, <s> and </s> removed).

  Raises:
    ValueError: if the number of classes is statically known and
      `blank_index` is not its last index.
  """
  # The TF beam decoder has no blank_index argument; it treats the last
  # class as blank. Validate against the real class count instead of a
  # hard-coded vocabulary size.
  num_classes = tok_logits.shape.as_list()[-1]
  if num_classes is not None and blank_index != num_classes - 1:
    raise ValueError(
        'tf.nn.ctc_beam_search_decoder uses class %d (num_classes - 1) as the '
        'blank, but blank_index=%d.' % (num_classes - 1, blank_index))

  seq_lengths = py_utils.LengthsFromBitMask(padding, 0)
  # TODO: Consider making beam_width larger (or a parameter).
  (decoded,), _ = tf.nn.ctc_beam_search_decoder(
      tok_logits, seq_lengths, beam_width=100)

  # -1 is never a decoded class id, so it marks unfilled positions.
  invalid = tf.constant(-1, tf.int64)
  dense_dec = tf.sparse_to_dense(decoded.indices, decoded.dense_shape,
                                 decoded.values, default_value=invalid)
  # The decoder's dense_shape is [batch_size, max_decoded_length], so counting
  # the real ids per row yields one length per batch element. Hypotheses that
  # decoded to nothing get length 0 wherever they sit in the batch. (The
  # previous segment_max-based count gave such rows length 1 unless they were
  # at the end of the batch.)
  decoded_seq_lengths = tf.reduce_sum(
      tf.cast(tf.not_equal(dense_dec, invalid), tf.int32), axis=1)

  hyp_str = py_utils.RunOnTpuHost(input_generator.IdsToStrings,
                                  tf.cast(dense_dec, tf.int32),
                                  decoded_seq_lengths)
  # Strip special tokens the model may emit; applied one after another.
  hyp_str = tf.strings.regex_replace(hyp_str, '(<unk>)+', '')
  hyp_str = tf.strings.regex_replace(hyp_str, '(<s>)+', '')
  hyp_str = tf.strings.regex_replace(hyp_str, '(</s>)+', '')
  return py_utils.NestedMap(sparse_ids=decoded, transcripts=hyp_str)
def _DecodeCTC(self, output_batch, beam_width=100, compute_transcripts=False):
  """CTC beam-search decodes the encoder output.

  tf.nn.ctc_beam_search_decoder always treats the last class
  (num_classes - 1) as the blank. When `params.blank_index` names another
  class, its logit column is swapped with the last column before decoding.
  The decoded ids are then mapped back, so `sparse_ids` always uses the
  model's own vocabulary. (Previously the blank was swapped to column 0,
  which made the decoder use a real class as the blank.)

  Args:
    output_batch: NestedMap with `encoded` logits (time-major, (T, B, F)) and
      `padding`; sequence lengths are LengthsFromBitMask(padding, 0).
    beam_width: Beam width for the decoder. The default 100 is the previously
      hard-coded value.
    compute_transcripts: If True, also convert the decoded ids to strings via
      `self.input_generator.IdsToStrings` and return them as `transcripts`.
      The default False returns only `sparse_ids`, as before.

  Returns:
    NestedMap with `sparse_ids` (top-path SparseTensor of vocabulary ids) and,
    if `compute_transcripts`, `transcripts`.
  """
  blank = self.params.blank_index
  tok_logits = output_batch.encoded  # (T, B, F)
  num_classes = tok_logits.shape.as_list()[-1]
  last = num_classes - 1
  # perm[j] is the original class stored in column j after the swap. A single
  # swap is its own inverse, so the same table maps decoded ids back.
  perm = list(range(num_classes))
  perm[blank], perm[last] = perm[last], perm[blank]

  def _BeamSearch(logits, seq_lengths):
    """Decodes with the blank moved last, then remaps ids to the vocabulary."""
    if blank != last:
      logits = tf.gather(logits, perm, axis=2)
    (best,), log_probs = tf.nn.ctc_beam_search_decoder(
        logits, seq_lengths, beam_width=beam_width)
    if blank != last:
      best = tf.SparseTensor(
          best.indices,
          tf.gather(tf.constant(perm, tf.int64), best.values),
          best.dense_shape)
    return [best], log_probs

  # The remapping runs inside the host function, so the dynamically shaped
  # sparse values are handled on the host together with the decoder.
  (decoded,), _ = py_utils.RunOnTpuHost(
      _BeamSearch, tok_logits,
      py_utils.LengthsFromBitMask(output_batch.padding, 0))
  if not compute_transcripts:
    return py_utils.NestedMap(sparse_ids=decoded)

  # After remapping, the blank id never appears among decoded values, so it
  # is a safe fill value for unused positions.
  fill = tf.constant(blank, tf.int64)
  dense_dec = tf.sparse_to_dense(decoded.indices, decoded.dense_shape,
                                 decoded.values, default_value=fill)
  # One row per batch element; count the real ids in each row.
  decoded_seq_lengths = tf.reduce_sum(
      tf.cast(tf.not_equal(dense_dec, fill), tf.int32), axis=1)
  hyp_str = py_utils.RunOnTpuHost(self.input_generator.IdsToStrings,
                                  tf.cast(dense_dec, tf.int32),
                                  decoded_seq_lengths)
  # Some predictions contain start/stop tokens; we don't want those included
  # in the WER calculation.
  hyp_str = tf.strings.regex_replace(hyp_str, '(<unk>)+', '')
  hyp_str = tf.strings.regex_replace(hyp_str, '(<s>)+', '')
  hyp_str = tf.strings.regex_replace(hyp_str, '(</s>)+', '')
  return py_utils.NestedMap(sparse_ids=decoded, transcripts=hyp_str)
def _CalculateErrorRates(self, dec_outs_dict, input_batch):
  """Computes character and word error rates for decoded hypotheses.

  If `dec_outs_dict` has no `transcripts` entry (e.g. the decoder returned
  only `sparse_ids`), this returns {'stuff': dec_outs_dict.sparse_ids.values}.
  That is the same debug output the previous unconditional early return
  produced. When transcripts are present, the full metric dict is computed;
  that code was previously unreachable.

  Args:
    dec_outs_dict: NestedMap from the decoder with `sparse_ids` and,
      optionally, `transcripts`.
    input_batch: NestedMap whose `tgt` provides labels, paddings, ids and
      weights; `sample_ids` is read when not running on TPU.

  Returns:
    The debug dict described above, or a dict of targets, transcripts, WER,
    CER and the counts they are computed from.
  """
  if 'transcripts' not in dec_outs_dict:
    return {'stuff': dec_outs_dict.sparse_ids.values}

  hyps = dec_outs_dict.transcripts
  gt_seq_lens = py_utils.LengthsFromBitMask(input_batch.tgt.paddings, 1)
  gt_transcripts = py_utils.RunOnTpuHost(
      self.input_generator.IdsToStrings, input_batch.tgt.labels, gt_seq_lens)

  # Character (token) error rate: unnormalized edit distance between the
  # single-byte splits of hypothesis and reference.
  char_dist = tf.edit_distance(tf.string_split(hyps, sep=''),
                               tf.string_split(gt_transcripts, sep=''),
                               normalize=False)
  ref_chars = tf.strings.length(gt_transcripts)
  num_wrong_chars = tf.reduce_sum(char_dist)
  num_ref_chars = tf.cast(tf.reduce_sum(ref_chars), tf.float32)
  cer = num_wrong_chars / num_ref_chars

  # Word error rate. word_dist is (B, 2): column 0 is summed as the word
  # errors and column 1 as the reference word count.
  word_dist = decoder_utils.ComputeWer(hyps, gt_transcripts)
  num_wrong_words = tf.reduce_sum(word_dist[:, 0])
  num_ref_words = tf.reduce_sum(word_dist[:, 1])
  wer = num_wrong_words / num_ref_words

  ret_dict = {
      'target_ids': input_batch.tgt.ids,
      'target_labels': input_batch.tgt.labels,
      'target_weights': input_batch.tgt.weights,
      'target_paddings': input_batch.tgt.paddings,
      'target_transcripts': gt_transcripts,
      'decoded_transcripts': hyps,
      'wer': wer,
      'cer': cer,
      'num_wrong_words': num_wrong_words,
      'num_ref_words': num_ref_words,
      'num_wrong_chars': num_wrong_chars,
      'num_ref_chars': num_ref_chars,
  }
  if not py_utils.use_tpu():
    ret_dict['utt_id'] = input_batch.sample_ids
  return ret_dict