def cpu_tf_graph_decode(tok_logits, padding, blank_index, input_generator):
  """Beam-search decodes CTC logits and converts the best path to text.

  Args:
    tok_logits: CTC logits, (T, B, F) per the original author's note
      (time, batch, classes) -- TODO confirm.
    padding: padding tensor; converted to per-sequence lengths with
      `py_utils.LengthsFromBitMask(padding, 0)`.
    blank_index: index of the CTC blank class. `tf.nn.ctc_beam_search_decoder`
      always treats the LAST class as blank, so only blank_index == 31 is
      supported here (presumably F == 32 -- TODO confirm).
    input_generator: object whose `IdsToStrings(ids, lengths)` converts id
      sequences to strings.

  Returns:
    NestedMap with `sparse_ids` (the top beam as a tf.SparseTensor) and
    `transcripts` (strings with runs of <unk>, <s> and </s> removed).
  """
  # tf.nn.ctc_beam_search_decoder uses num_classes - 1 as the blank label,
  # so the model's blank must already be the last class.
  assert blank_index == 31
  seq_lengths = py_utils.LengthsFromBitMask(padding, 0)
  # TODO: Consider making beam_width larger.
  (decoded,), _ = tf.nn.ctc_beam_search_decoder(
      tok_logits, seq_lengths, beam_width=100)

  invalid = tf.constant(-1, tf.int64)
  # decoded.dense_shape is [batch_size, max_decoded_length], so dense_dec
  # has one row per batch entry, padded with `invalid`.
  dense_dec = tf.sparse_to_dense(decoded.indices, decoded.dense_shape,
                                 decoded.values, default_value=invalid)
  # Each hypothesis occupies positions 0..len-1 of its row, so counting the
  # non-padding entries yields its length. Empty hypotheses correctly get 0,
  # wherever they sit in the batch (the former segment_max + 1 approach gave
  # a bogus length to empty hypotheses that were not at the end).
  decoded_seq_lengths = tf.reduce_sum(
      tf.cast(tf.not_equal(dense_dec, invalid), tf.int32), axis=1)

  hyp_str = py_utils.RunOnTpuHost(input_generator.IdsToStrings,
                                  tf.cast(dense_dec, tf.int32),
                                  decoded_seq_lengths)
  # Strip special tokens so they do not affect WER; order is preserved.
  for special_token_pattern in ('(<unk>)+', '(<s>)+', '(</s>)+'):
    hyp_str = tf.strings.regex_replace(hyp_str, special_token_pattern, '')
  return py_utils.NestedMap(sparse_ids=decoded, transcripts=hyp_str)
def _DecodeCTC(self, output_batch):
  """Beam-search decodes CTC logits into sparse token ids.

  `tf.nn.ctc_beam_search_decoder` always treats the LAST class
  (num_classes - 1) as the CTC blank. The model's blank sits at
  `self.params.blank_index`, so that class is swapped with the last one
  before decoding. The decoded ids are then mapped back through the same
  swap, so they index the model's original vocabulary.

  Args:
    output_batch: NestedMap with `encoded` logits of rank 3, (T, B, F) per the
      original author's note -- TODO confirm, and `padding`, converted to
      per-sequence lengths via `py_utils.LengthsFromBitMask(padding, 0)`.

  Returns:
    NestedMap with `sparse_ids`: the top beam as a tf.SparseTensor of ids in
    the model's vocabulary.
  """
  tok_logits = output_batch.encoded
  num_classes = int(tok_logits.shape[-1])
  blank_index = self.params.blank_index

  # A single swap is its own inverse, so `perm` both moves the blank to the
  # last class and maps decoded ids back to the original vocabulary.
  perm = list(range(num_classes))
  perm[blank_index], perm[-1] = perm[-1], perm[blank_index]
  tok_logits = tf.gather(tok_logits, perm, axis=2)
  seq_lengths = py_utils.LengthsFromBitMask(output_batch.padding, 0)

  def _HostDecode(logits, lengths):
    """Decodes on the host and undoes the class swap on the resulting ids."""
    # TODO(galvez): Make beam_width a tunable parameter.
    (decoded,), _ = tf.nn.ctc_beam_search_decoder(
        logits, lengths, beam_width=100)
    original_ids = tf.gather(
        tf.constant(perm, dtype=decoded.values.dtype), decoded.values)
    return tf.SparseTensor(decoded.indices, original_ids, decoded.dense_shape)

  decoded = py_utils.RunOnTpuHost(_HostDecode, tok_logits, seq_lengths)
  # TODO: also produce transcripts (as cpu_tf_graph_decode does); the previous
  # unreachable attempt at this was removed.
  return py_utils.NestedMap(sparse_ids=decoded)
def _Dense(sparse, default_value=0):
  """Returns a dense tensor built from a sparse tensor.

  Args:
    sparse: object exposing `indices`, `values` and `dense_shape` (e.g. a
      tf.SparseTensor).
    default_value: value written at every position not listed in
      `sparse.indices`.

  Returns:
    The result of `tf.sparse_to_dense` on those components.
  """
  indices = sparse.indices
  values = sparse.values
  shape = sparse.dense_shape
  return tf.sparse_to_dense(
      sparse_indices=indices,
      output_shape=shape,
      sparse_values=values,
      default_value=default_value)