Esempio n. 1
0
def cpu_tf_graph_decode(tok_logits, padding, blank_index, input_generator):
    """Beam-search decodes CTC logits and converts the top beam to strings.

    Args:
      tok_logits: logits tensor laid out as (T, B, F), i.e. the layout that
        `tf.nn.ctc_beam_search_decoder` consumes.
      padding: padding mask; turned into per-utterance input lengths with
        `py_utils.LengthsFromBitMask(padding, 0)`.
      blank_index: index of the CTC blank class. Must be 31.
      input_generator: object whose `IdsToStrings(ids, lengths)` converts a
        dense id matrix plus row lengths into strings.

    Returns:
      A `py_utils.NestedMap` with two fields: `sparse_ids`, the top beam
      exactly as returned by the decoder, and `transcripts`, the decoded
      strings with runs of `<unk>`, `<s>` and `</s>` removed.

    Raises:
      AssertionError: if `blank_index` is not 31.
    """
    # NOTE(review): the guard pins the blank to class 31. TensorFlow's CTC ops
    # are understood to reserve the last class (num_classes - 1) for blank, so
    # this presumably corresponds to a 32-class output layer -- confirm.
    assert blank_index == 31, (
        'cpu_tf_graph_decode only supports blank_index == 31, got %r' %
        (blank_index, ))

    # TODO: Consider making beam_width larger
    input_lengths = py_utils.LengthsFromBitMask(padding, 0)
    (decoded, ), _ = tf.nn.ctc_beam_search_decoder(tok_logits,
                                                   input_lengths,
                                                   beam_width=100)

    # Positions without a decoded token are filled with -1, which is never a
    # valid token id. (blank_index could serve the same purpose.)
    invalid = tf.constant(-1, tf.int64)
    dense_dec = tf.sparse_to_dense(decoded.indices,
                                   decoded.dense_shape,
                                   decoded.values,
                                   default_value=invalid)

    batch_segments = decoded.indices[:, 0]
    times_in_each_batch = decoded.indices[:, 1]

    # Hypothesis length per batch row = (largest decoded position) + 1.
    # Passing the batch size as num_segments yields one entry for every row,
    # including rows whose hypothesis is empty. Those rows come back as the
    # lowest int64 value, so clamp at zero after the +1. A plain segment_max
    # would leave that sentinel in place for any empty row that precedes a
    # non-empty one, producing a bogus length.
    batch_size = tf.shape(tok_logits)[1]
    last_position = tf.math.unsorted_segment_max(times_in_each_batch,
                                                 batch_segments,
                                                 num_segments=batch_size)
    decoded_seq_lengths = tf.cast(
        tf.maximum(last_position + 1, tf.constant(0, tf.int64)), tf.int32)

    hyp_str = py_utils.RunOnTpuHost(input_generator.IdsToStrings,
                                    tf.cast(dense_dec, tf.int32),
                                    decoded_seq_lengths)

    # Strip special tokens from the transcripts.
    for special_token_pattern in ('(<unk>)+', '(<s>)+', '(</s>)+'):
        hyp_str = tf.strings.regex_replace(hyp_str, special_token_pattern, '')
    return py_utils.NestedMap(sparse_ids=decoded, transcripts=hyp_str)
Esempio n. 2
0
    def _DecodeCTC(self, output_batch):
        """Runs CTC beam search over the encoder logits on the TPU host.

        Before decoding, the class axis of the logits is permuted so that the
        class at `self.params.blank_index` and class 0 trade places.

        NOTE(review): TensorFlow's CTC ops are understood to reserve the last
        class index (num_classes - 1) for blank rather than index 0. Confirm
        that this swap matches what the decoder expects, and whether the
        returned ids need to be mapped back through the same permutation.

        Args:
          output_batch: object exposing `encoded`, the logits laid out as
            (T, B, F), and `padding`, which is converted to input lengths
            with `py_utils.LengthsFromBitMask(padding, 0)`.

        Returns:
          A `py_utils.NestedMap` with a single field, `sparse_ids`: the top
          beam as returned by the decoder. No transcripts are produced.
        """
        tok_logits = output_batch.encoded  # (T, B, F)
        blank = self.params.blank_index

        # Class permutation that swaps index 0 with the blank index (a no-op
        # when the blank already sits at 0).
        idxs = list(range(tok_logits.shape[-1]))
        idxs[0], idxs[blank] = blank, 0
        tok_logits = tf.stack([tok_logits[:, :, idx] for idx in idxs], axis=-1)

        # GALVEZ: Make beam_width a tunable parameter!
        input_lengths = py_utils.LengthsFromBitMask(output_batch.padding, 0)
        (decoded, ), _ = py_utils.RunOnTpuHost(tf.nn.ctc_beam_search_decoder,
                                               tok_logits,
                                               input_lengths,
                                               beam_width=100)

        return py_utils.NestedMap(sparse_ids=decoded)
Esempio n. 3
0
def _Dense(sparse, default_value=0):
    """Converts a sparse tensor into its dense equivalent.

    Args:
      sparse: object providing `indices`, `dense_shape` and `values`.
      default_value: value written to every position that `sparse` does not
        list. Defaults to 0.

    Returns:
      The result of `tf.sparse_to_dense` for the given sparse components.
    """
    indices = sparse.indices
    shape = sparse.dense_shape
    values = sparse.values
    return tf.sparse_to_dense(indices,
                              shape,
                              values,
                              default_value=default_value)