Example #1
    def ComputeLoss(self, theta, predictions, input_batch):
        output_batch = predictions
        ctc_loss = tf.nn.ctc_loss(
            input_batch.tgt.labels,
            output_batch.encoded,
            # Label lengths (padding time axis 1), then logit lengths (axis 0).
            py_utils.LengthsFromBitMask(input_batch.tgt.paddings, 1),
            py_utils.LengthsFromBitMask(output_batch.padding, 0),
            logits_time_major=True,
            blank_index=self.params.blank_index)

        # ctc_loss has shape (B,): one loss value per sequence in the batch.
        total_loss = tf.reduce_mean(ctc_loss)
        per_sequence_loss = {'loss': ctc_loss}
        return dict(loss=(total_loss, 1.0)), per_sequence_loss
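For orientation, here is a self-contained sketch of the same tf.nn.ctc_loss call with dummy tensors. The lengths_from_padding helper is a hypothetical stand-in for py_utils.LengthsFromBitMask, and every shape below is an illustrative assumption:

import tensorflow as tf

def lengths_from_padding(padding, time_axis):
    # Assumed LengthsFromBitMask semantics: padding is 1.0 at padded steps,
    # so a sequence's length is its count of non-padded steps.
    return tf.cast(tf.reduce_sum(1.0 - padding, axis=time_axis), tf.int32)

T, B, F = 50, 4, 32                    # frames, batch, classes (incl. blank)
logits = tf.random.normal([T, B, F])   # time-major, as in the code above
labels = tf.random.uniform([B, 10], 0, F - 1, dtype=tf.int32)
logit_padding = tf.zeros([T, B])       # time is axis 0
label_padding = tf.zeros([B, 10])      # time is axis 1

ctc = tf.nn.ctc_loss(labels, logits,
                     lengths_from_padding(label_padding, 1),
                     lengths_from_padding(logit_padding, 0),
                     logits_time_major=True,
                     blank_index=F - 1)
loss = tf.reduce_mean(ctc)             # scalar: mean over the (B,) losses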
Example #2
    def DecodeWithTheta(self, theta, input_batch):
        """Constructs the inference graph."""
        p = self.params
        with tf.name_scope('decode'), tf.name_scope(p.name):
            with tf.name_scope('encoder'):
                encoder_outputs = self._FrontendAndEncoderFProp(
                    theta, input_batch.src)
            if p.inference_compute_only_log_softmax:
                global_step = tf.train.get_global_step()
                increment_global_step = tf.assign(global_step, global_step + 1)
                with tf.control_dependencies([increment_global_step]):
                    # encoder_outputs.encoded is [T, B, F]; transpose the
                    # log-softmax output to batch-major [B, T, F].
                    log_probabilities = tf.transpose(
                        tf.nn.log_softmax(encoder_outputs.encoded, axis=2),
                        perm=(1, 0, 2))
                return {
                    'log_probabilities': log_probabilities,
                    'log_probabilities_lengths':
                        py_utils.LengthsFromBitMask(encoder_outputs.padding, 0),
                    'int64_uttid': input_batch.sample_ids,
                    'int64_audio_document_id': input_batch.audio_document_ids,
                    'num_utterances_in_audio_document':
                        input_batch.num_utterances_in_audio_document,
                }
            with tf.name_scope('decoder'):
                decoder_outs = self._DecodeCTC(encoder_outputs)

            decoder_metrics = self._CalculateErrorRates(
                decoder_outs, input_batch)
            return decoder_metrics
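The inference-only branch above boils down to a log-softmax followed by a transpose from time-major to batch-major. A minimal sketch under the same [T, B, F] convention; the length computation is an assumed equivalent of py_utils.LengthsFromBitMask:

import tensorflow as tf

T, B, F = 50, 4, 32
encoded = tf.random.normal([T, B, F])  # time-major encoder output
padding = tf.zeros([T, B])             # 1.0 would mark padded frames

# Per-frame log-probabilities, flipped to batch-major [B, T, F].
log_probabilities = tf.transpose(
    tf.nn.log_softmax(encoded, axis=2), perm=(1, 0, 2))
lengths = tf.cast(tf.reduce_sum(1.0 - padding, axis=0), tf.int32)
outputs = {'log_probabilities': log_probabilities,
           'log_probabilities_lengths': lengths}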
Example #3
def cpu_tf_graph_decode(tok_logits, padding, blank_index, input_generator):
    # tok_logits: (T, B, F)
    # tf.nn.ctc_beam_search_decoder treats the last class (num_classes - 1)
    # as the blank, so the model's blank must already be the final index.
    assert blank_index == 31

    # TODO: Consider making beam_width larger
    (decoded, ), _ = tf.nn.ctc_beam_search_decoder(tok_logits,
                                                   py_utils.LengthsFromBitMask(
                                                       padding, 0),
                                                   beam_width=100)
    # -1 cannot collide with a real token id, so it marks padding in the
    # dense output (blank_index would work equally well).
    invalid = tf.constant(-1, tf.int64)
    dense_dec = tf.sparse_to_dense(decoded.indices,
                                   decoded.dense_shape,
                                   decoded.values,
                                   default_value=invalid)

    # decoded.indices holds (batch, position) pairs, one row per token.
    batch_segments = decoded.indices[:, 0]
    times_in_each_batch = decoded.indices[:, 1]

    decoded_seq_lengths = tf.cast(
        tf.math.segment_max(times_in_each_batch, batch_segments) + 1, tf.int32)
    # Sequences that decode to empty strings are missing from segment_max's
    # output, so pad the lengths back out to the full batch size.
    decoded_seq_lengths = py_utils.PadBatchDimension(decoded_seq_lengths,
                                                     tf.shape(tok_logits)[1],
                                                     0)

    hyp_str = py_utils.RunOnTpuHost(input_generator.IdsToStrings,
                                    tf.cast(dense_dec, tf.int32),
                                    decoded_seq_lengths)

    # Strip predicted sentence/unknown-token markers before scoring.
    hyp_str = tf.strings.regex_replace(hyp_str, '(<unk>)+', '')
    hyp_str = tf.strings.regex_replace(hyp_str, '(<s>)+', '')
    hyp_str = tf.strings.regex_replace(hyp_str, '(</s>)+', '')
    return py_utils.NestedMap(sparse_ids=decoded, transcripts=hyp_str)
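The key step above is recovering per-sequence lengths from the sparse beam-search output. A standalone sketch of that trick; the zero-padding comment mirrors what py_utils.PadBatchDimension is assumed to do:

import tensorflow as tf

T, B, F = 50, 4, 32
logits = tf.random.normal([T, B, F])
(decoded,), _ = tf.nn.ctc_beam_search_decoder(logits, tf.fill([B], T),
                                              beam_width=100)

# decoded.indices holds (batch, position) pairs, one row per emitted token.
batch_segments = decoded.indices[:, 0]
positions = decoded.indices[:, 1]
# Each sequence's length is its highest emitted position plus one. Trailing
# batch entries that decode to nothing are absent from segment_max's output,
# which is why the original pads the lengths back out to B.
lengths = tf.cast(tf.math.segment_max(positions, batch_segments) + 1,
                  tf.int32)
dense = tf.sparse.to_dense(decoded, default_value=tf.constant(-1, tf.int64))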
Example #4
    def _DecodeCTC(self, output_batch):
        tok_logits = output_batch.encoded  # (T, B, F)
        # Swap logit channels 0 and blank_index so the blank logits land at
        # channel 0.
        idxs = list(range(tok_logits.shape[-1]))
        idxs[0] = self.params.blank_index
        idxs[self.params.blank_index] = 0
        tok_logits = tf.stack([tok_logits[:, :, idx] for idx in idxs], axis=-1)

        # TODO: make beam_width a tunable parameter.
        (decoded, ), _ = py_utils.RunOnTpuHost(tf.nn.ctc_beam_search_decoder,
                                               tok_logits,
                                               py_utils.LengthsFromBitMask(
                                                   output_batch.padding, 0),
                                               beam_width=100)

        dense_dec = tf.sparse_to_dense(decoded.indices,
                                       decoded.dense_shape,
                                       decoded.values,
                                       default_value=self.params.blank_index)

        # Positions equal to blank_index are padding in the dense output;
        # use them to recover per-sequence lengths.
        invalid = tf.constant(self.params.blank_index, tf.int64)
        bit_mask = tf.cast(tf.math.equal(dense_dec, invalid),
                           tf.float32)  # (B, T)

        decoded_seq_lengths = py_utils.LengthsFromBitMask(
            tf.transpose(bit_mask), 0)

        hyp_str = py_utils.RunOnTpuHost(self.input_generator.IdsToStrings,
                                        tf.cast(dense_dec, tf.int32),
                                        decoded_seq_lengths)

        # Some predictions contain start and stop tokens; we don't want to
        # include those in the WER calculation.
        hyp_str = tf.strings.regex_replace(hyp_str, '(<unk>)+', '')
        hyp_str = tf.strings.regex_replace(hyp_str, '(<s>)+', '')
        hyp_str = tf.strings.regex_replace(hyp_str, '(</s>)+', '')
        return py_utils.NestedMap(sparse_ids=decoded, transcripts=hyp_str)
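A note on the channel swap at the top of _DecodeCTC: it relocates the blank logits (here to channel 0), and the per-channel stack can be collapsed into a single gather. A sketch; the helper name is mine:

import tensorflow as tf

def swap_blank_channel(tok_logits, blank_index):
    # tok_logits: [T, B, F]. Swap channels 0 and blank_index with one
    # tf.gather instead of stacking F slices.
    perm = list(range(tok_logits.shape[-1]))
    perm[0], perm[blank_index] = perm[blank_index], perm[0]
    return tf.gather(tok_logits, perm, axis=-1)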
Example #5
    def _CalculateErrorRates(self, dec_outs_dict, input_batch):
        gt_seq_lens = py_utils.LengthsFromBitMask(input_batch.tgt.paddings, 1)
        gt_transcripts = py_utils.RunOnTpuHost(
            self.input_generator.IdsToStrings, input_batch.tgt.labels,
            gt_seq_lens)

        # character error rate
        char_dist = tf.edit_distance(tf.string_split(dec_outs_dict.transcripts,
                                                     sep=''),
                                     tf.string_split(gt_transcripts, sep=''),
                                     normalize=False)

        ref_chars = tf.strings.length(gt_transcripts)
        num_wrong_chars = tf.reduce_sum(char_dist)
        num_ref_chars = tf.cast(tf.reduce_sum(ref_chars), tf.float32)
        cer = num_wrong_chars / num_ref_chars

        # word error rate
        word_dist = decoder_utils.ComputeWer(dec_outs_dict.transcripts,
                                             gt_transcripts)  # (B, 2)
        num_wrong_words = tf.reduce_sum(word_dist[:, 0])
        num_ref_words = tf.reduce_sum(word_dist[:, 1])
        wer = num_wrong_words / num_ref_words

        ret_dict = {
            'target_ids': input_batch.tgt.ids,
            'target_labels': input_batch.tgt.labels,
            'target_weights': input_batch.tgt.weights,
            'target_paddings': input_batch.tgt.paddings,
            'target_transcripts': gt_transcripts,
            'decoded_transcripts': dec_outs_dict.transcripts,
            'wer': wer,
            'cer': cer,
            'num_wrong_words': num_wrong_words,
            'num_ref_words': num_ref_words,
            'num_wrong_chars': num_wrong_chars,
            'num_ref_chars': num_ref_chars
        }
        if not py_utils.use_tpu():
            ret_dict['utt_id'] = input_batch.sample_ids

        return ret_dict
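The CER/WER arithmetic reduces to summed edit distances over summed reference sizes. A toy standalone version; the word-level edit distance stands in for decoder_utils.ComputeWer, which is an assumption about what that helper computes:

import tensorflow as tf

hyp = tf.constant(['the cat sat', 'helo world'])
ref = tf.constant(['the cat sat', 'hello world'])

# Character error rate: character-level edit distance, normalized by the
# total number of reference characters in the batch.
char_dist = tf.edit_distance(
    tf.strings.unicode_split(hyp, 'UTF-8').to_sparse(),
    tf.strings.unicode_split(ref, 'UTF-8').to_sparse(),
    normalize=False)
cer = tf.reduce_sum(char_dist) / tf.cast(
    tf.reduce_sum(tf.strings.length(ref)), tf.float32)

# Word error rate: the same primitive at word granularity.
ref_words = tf.strings.split(ref)
word_dist = tf.edit_distance(
    tf.strings.split(hyp).to_sparse(), ref_words.to_sparse(),
    normalize=False)
wer = tf.reduce_sum(word_dist) / tf.cast(
    tf.reduce_sum(ref_words.row_lengths()), tf.float32)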