Example #1
    def recognize_beam_tflite(
        self,
        signal,
    ):
        """
        Function to convert to tflite using beam search decoding
        Args:
            signal: tf.Tensor with shape [None] indicating a single audio signal

        Return:
            transcript: tf.Tensor of Unicode Code Points with shape [None] and dtype tf.int32
        """
        features = self.speech_featurizer.tf_extract(signal)
        features = tf.expand_dims(features, axis=0)  # add a batch dimension
        input_length = shape_util.shape_list(features)[1]
        input_length = math_util.get_reduced_length(input_length, self.time_reduction_factor)
        input_length = tf.expand_dims(input_length, axis=0)  # batch the scalar length as well
        logits = self.encoder(features, training=False)
        logits = self.decoder(logits, training=False)
        probs = tf.nn.softmax(logits)
        decoded = tf.keras.backend.ctc_decode(
            y_pred=probs,
            input_length=input_length,
            greedy=False,
            beam_width=self.text_featurizer.decoder_config.beam_width,
        )
        # ctc_decode returns (decoded_paths, log_probs); take the best path of the only batch item
        decoded = tf.cast(decoded[0][0][0], dtype=tf.int32)
        transcript = self.text_featurizer.indices2upoints(decoded)
        return transcript
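
Every example on this page funnels sequence lengths through math_util.get_reduced_length so they match the time axis after the model's striding/subsampling. A minimal sketch of what it presumably computes, i.e. a ceiling division by the reduction factor (the library's actual implementation may differ):

import tensorflow as tf

def get_reduced_length(length, reduction_factor):
    # ceil(length / reduction_factor), returned as an int32 tensor
    return tf.cast(
        tf.math.ceil(tf.divide(length, tf.cast(reduction_factor, dtype=length.dtype))),
        dtype=tf.int32,
    )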
Example #2
    def call(
        self,
        inputs,
        training=False,
        **kwargs,
    ):
        # encode padded features together with their true lengths
        enc = self.encoder([inputs["inputs"], inputs["inputs_length"]], training=training, **kwargs)
        # run the prediction network over the target token sequences
        pred = self.predict_net([inputs["predictions"], inputs["predictions_length"]], training=training, **kwargs)
        # the joint network combines both into per-frame, per-token logits
        logits = self.joint_net([enc, pred], training=training, **kwargs)
        return data_util.create_logits(
            logits=logits,
            logits_length=math_util.get_reduced_length(inputs["inputs_length"], self.time_reduction_factor),
        )
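
For context, this call expects a single dict of batched tensors. A hypothetical invocation (tensor names and shapes are illustrative, not taken from the snippet):

batch = {
    "inputs": features,                         # [B, T, F, C] padded features
    "inputs_length": features_length,           # [B] true frame counts before padding
    "predictions": predictions,                 # [B, U] blank-prepended target token ids
    "predictions_length": predictions_length,   # [B]
}
outputs = model(batch, training=True)           # {"logits": ..., "logits_length": ...}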
Example #3
    def call(
        self,
        inputs,
        training=False,
        **kwargs,
    ):
        logits = self.encoder(inputs["inputs"], training=training, **kwargs)
        logits = self.decoder(logits, training=training, **kwargs)
        return data_util.create_logits(
            logits=logits,
            logits_length=math_util.get_reduced_length(inputs["inputs_length"], self.time_reduction_factor),
        )
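
data_util.create_logits appears in both call implementations above; judging from its usage it presumably just bundles the logits with their valid lengths. A minimal stand-in (an assumption, not the library's actual code):

def create_logits(logits, logits_length):
    # pair the logits with per-example lengths, since the loss and the decoder both need them
    return {"logits": logits, "logits_length": logits_length}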
Example #4
    def recognize(
        self,
        inputs: Dict[str, tf.Tensor],
    ):
        """
        RNN Transducer Greedy decoding
        Args:
            features (tf.Tensor): a batch of padded extracted features

        Returns:
            tf.Tensor: a batch of decoded transcripts
        """
        encoded = self.encoder([inputs["inputs"], inputs["inputs_length"]], training=False)
        encoded_length = math_util.get_reduced_length(inputs["inputs_length"], self.time_reduction_factor)
        return self._perform_greedy_batch(encoded=encoded, encoded_length=encoded_length)
Example #5
    def recognize(
        self,
        inputs: Dict[str, tf.Tensor],
    ):
        """
        RNN Transducer Greedy decoding
        Args:
            features (tf.Tensor): a batch of padded extracted features

        Returns:
            tf.Tensor: a batch of decoded transcripts
        """
        batch_size, _, _, _ = shape_util.shape_list(inputs["inputs"])
        # run the stateful (streaming) encoder from freshly initialized states for this batch
        encoded, _ = self.encoder.recognize(inputs["inputs"], self.encoder.get_initial_state(batch_size))
        encoded_length = math_util.get_reduced_length(inputs["inputs_length"], self.time_reduction_factor)
        return self._perform_greedy_batch(encoded=encoded, encoded_length=encoded_length)
Example #6
    def recognize_beam(
        self,
        inputs: Dict[str, tf.Tensor],
        lm: bool = False,
    ):
        """
        RNN Transducer Beam Search
        Args:
            features (tf.Tensor): a batch of padded extracted features
            lm (bool, optional): whether to use language model. Defaults to False.

        Returns:
            tf.Tensor: a batch of decoded transcripts
        """
        encoded = self.encoder([inputs["inputs"], inputs["inputs_length"]], training=False)
        encoded_length = math_util.get_reduced_length(inputs["inputs_length"], self.time_reduction_factor)
        return self._perform_beam_search_batch(encoded=encoded, encoded_length=encoded_length, lm=lm)
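
A hypothetical invocation, reusing the input-dict convention from the earlier examples (the model variable and tensor names are illustrative):

transcripts = model.recognize_beam(
    {"inputs": features, "inputs_length": features_length},
    lm=False,  # set True only if a language model has been wired in
)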
Example #7
    def call(
        self,
        inputs,
        training=False,
        **kwargs,
    ):
        features, input_length = inputs
        outputs = features
        for conv in self.convs:
            outputs = conv(outputs, training=training)
        outputs = self.last_conv(outputs, training=training)
        # the strided last conv shortens the time axis, so shrink the lengths to match
        input_length = math_util.get_reduced_length(input_length, self.last_conv.strides)
        outputs = self.se([outputs, input_length], training=training)
        if self.residual is not None:
            res = self.residual(features, training=training)
            outputs = tf.add(outputs, res)
        outputs = self.activation(outputs)
        return outputs, input_length
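
The reduced input_length is handed to the se (squeeze-and-excitation) layer, presumably so its global pooling can ignore padded frames. A minimal sketch of such masked average pooling (a hypothetical helper, not from the snippet):

import tensorflow as tf

def masked_global_avg_pool(outputs, input_length):
    # outputs: [B, T, F] conv activations; input_length: [B] valid frame counts
    mask = tf.sequence_mask(input_length, maxlen=tf.shape(outputs)[1], dtype=outputs.dtype)
    mask = tf.expand_dims(mask, axis=-1)            # [B, T, 1]
    summed = tf.reduce_sum(outputs * mask, axis=1)  # sum over valid frames only
    return summed / tf.cast(tf.expand_dims(input_length, -1), outputs.dtype)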
Example #8
# (The snippet begins mid-statement; presumably a subword featurizer is loaded
# when args.subwords is provided, e.g.:)
if args.subwords and os.path.exists(args.subwords):
    text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config,
                                                       args.subwords)
else:
    text_featurizer = CharFeaturizer(config.decoder_config)
text_featurizer.decoder_config.beam_width = args.beam_width

# build model
conformer = Conformer(**config.model_config,
                      vocabulary_size=text_featurizer.num_classes)
conformer.make(speech_featurizer.shape)
conformer.load_weights(args.saved, by_name=True, skip_mismatch=True)
conformer.summary(line_length=120)
conformer.add_featurizers(speech_featurizer, text_featurizer)

signal = read_raw_audio(args.filename)
features = speech_featurizer.tf_extract(signal)
input_length = math_util.get_reduced_length(
    tf.shape(features)[0], conformer.time_reduction_factor)

if args.beam_width:
    transcript = conformer.recognize_beam(features[None, ...],
                                          input_length[None, ...])
    print("Transcript:", transcript[0].numpy().decode("UTF-8"))
elif args.timestamp:
    transcript, stime, etime, _, _ = conformer.recognize_tflite_with_timestamp(
        signal, tf.constant(text_featurizer.blank, dtype=tf.int32),
        conformer.predict_net.get_initial_state())
    print("Transcript:", transcript)
    print("Start time:", stime)
    print("End time:", etime)
else:
    transcript, _, _ = conformer.recognize_tflite(
        signal, tf.constant(text_featurizer.blank, dtype=tf.int32),
        conformer.predict_net.get_initial_state())  # final argument reconstructed to match the timestamp branch above
    print("Transcript:", transcript)
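
Since recognize_tflite exists for TFLite export, a natural follow-up (not part of the original, truncated script; make_tflite_function is assumed from the surrounding library) would be converting the model:

concrete_func = conformer.make_tflite_function().get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # regular TFLite kernels
    tf.lite.OpsSet.SELECT_TF_OPS,    # fall back to TF ops the converter cannot lower
]
tflite_model = converter.convert()
with open("conformer.tflite", "wb") as f:
    f.write(tflite_model)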