def parse(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor):
    """Turn one raw example into a ((inputs), (labels)) pair for training.

    Runs either the graph-mode or eager preprocessing pipeline (chosen by
    ``self.use_tf``) and repacks its outputs into the structures produced by
    ``data_util.create_inputs`` / ``data_util.create_labels``.

    Returns:
        A 2-tuple of (inputs, labels) where inputs bundles features,
        input lengths and prediction-network inputs, and labels bundles
        the label sequence and its length.
    """
    # Select the preprocessing implementation once, then call it.
    preprocess_fn = self.tf_preprocess if self.use_tf else self.preprocess
    (
        _,  # path is not needed downstream
        features,
        input_length,
        label,
        label_length,
        prediction,
        prediction_length,
    ) = preprocess_fn(path, audio, indices)
    model_inputs = data_util.create_inputs(
        inputs=features,
        inputs_length=input_length,
        predictions=prediction,
        predictions_length=prediction_length,
    )
    model_labels = data_util.create_labels(labels=label, labels_length=label_length)
    return model_inputs, model_labels
def process(self, dataset, batch_size):
    """Assemble the full input pipeline: parse, cache, shuffle, repeat,
    padded-batch and prefetch the given ``tf.data`` dataset.

    Also updates ``self.total_steps`` to the number of batches per epoch.
    """
    dataset = dataset.map(self.parse, num_parallel_calls=AUTOTUNE)
    self.total_steps = math_util.get_num_batches(
        self.total_steps, batch_size, drop_remainders=self.drop_remainder
    )
    if self.cache:
        dataset = dataset.cache()
    if self.shuffle:
        dataset = dataset.shuffle(self.buffer_size, reshuffle_each_iteration=True)
    if self.indefinite and self.total_steps:
        dataset = dataset.repeat()
    # Pad every tensor to the longest element in the batch; scalars
    # (lengths) have empty shapes and need no padding.
    padded_shapes = (
        data_util.create_inputs(
            inputs=tf.TensorShape(self.speech_featurizer.shape),
            inputs_length=tf.TensorShape([]),
            predictions=tf.TensorShape(self.text_featurizer.prepand_shape),
            predictions_length=tf.TensorShape([]),
        ),
        data_util.create_labels(
            labels=tf.TensorShape(self.text_featurizer.shape),
            labels_length=tf.TensorShape([]),
        ),
    )
    # Text tensors are padded with the blank token, features with 0.0.
    padding_values = (
        data_util.create_inputs(
            inputs=0.0,
            inputs_length=0,
            predictions=self.text_featurizer.blank,
            predictions_length=0,
        ),
        data_util.create_labels(labels=self.text_featurizer.blank, labels_length=0),
    )
    dataset = dataset.padded_batch(
        batch_size=batch_size,
        padded_shapes=padded_shapes,
        padding_values=padding_values,
        drop_remainder=self.drop_remainder,
    )
    # Overlap preparation of the next batch with model execution.
    return dataset.prefetch(AUTOTUNE)
def make(
    self,
    input_shape,
    batch_size=None,
):
    """Build the model's variables by tracing one forward pass on
    symbolic Keras inputs of the given feature shape."""
    features = tf.keras.Input(input_shape, batch_size=batch_size, dtype=tf.float32)
    features_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32)
    # A single inference-mode call is enough to create all weights.
    self(
        data_util.create_inputs(
            inputs=features,
            inputs_length=features_length,
        ),
        training=False,
    )
# build model rnnt = RnnTransducer(**config.model_config, vocabulary_size=text_featurizer.num_classes) rnnt.make(speech_featurizer.shape) rnnt.load_weights(args.saved, by_name=True, skip_mismatch=True) rnnt.summary(line_length=120) rnnt.add_featurizers(speech_featurizer, text_featurizer) signal = read_raw_audio(args.filename) features = speech_featurizer.tf_extract(signal) input_length = math_util.get_reduced_length(tf.shape(features)[0], rnnt.time_reduction_factor) if args.beam_width: transcript = rnnt.recognize_beam( data_util.create_inputs( inputs=features[None, ...], inputs_length=input_length[None, ...] ) ) logger.info("Transcript:", transcript[0].numpy().decode("UTF-8")) elif args.timestamp: transcript, stime, etime, _, _, _ = rnnt.recognize_tflite_with_timestamp( signal=signal, predicted=tf.constant(text_featurizer.blank, dtype=tf.int32), encoder_states=rnnt.encoder.get_initial_state(), prediction_states=rnnt.predict_net.get_initial_state() ) logger.info("Transcript:", transcript) logger.info("Start time:", stime) logger.info("End time:", etime) else: transcript = rnnt.recognize(
# Character-level decoder; beam width is taken from the CLI arguments.
text_featurizer = CharFeaturizer(config.decoder_config)
text_featurizer.decoder_config.beam_width = args.beam_width

# build model
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
conformer.make(speech_featurizer.shape)
conformer.load_weights(args.saved, by_name=True, skip_mismatch=True)
conformer.summary(line_length=120)
conformer.add_featurizers(speech_featurizer, text_featurizer)

# Read the audio file and extract acoustic features.
# NOTE(review): assumes read_raw_audio returns a 1-D waveform whose
# sample rate matches the featurizer's config — confirm against the helper.
signal = read_raw_audio(args.filename)
features = speech_featurizer.tf_extract(signal)
# Number of feature frames (time dimension) of the single utterance.
input_length = tf.shape(features)[0]

if args.beam_width:
    # Beam-search decoding over a batch of one utterance ([None, ...] adds
    # the batch axis).
    inputs = create_inputs(features[None, ...], input_length[None, ...])
    transcript = conformer.recognize_beam(inputs)
    logger.info(f"Transcript: {transcript[0].numpy().decode('UTF-8')}")
elif args.timestamp:
    # Greedy TFLite-style decoding that also yields per-token start/end
    # times; decoding starts from the blank token and a zeroed
    # prediction-network state.
    transcript, stime, etime, _, _ = conformer.recognize_tflite_with_timestamp(
        signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state()
    )
    logger.info(f"Transcript: {transcript}")
    logger.info(f"Start time: {stime}")
    logger.info(f"End time: {etime}")
else:
    # Greedy TFLite-style decoding; the model returns Unicode code points,
    # which are encoded back into a UTF-8 string here.
    code_points, _, _ = conformer.recognize_tflite(
        signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state()
    )
    transcript = tf.strings.unicode_encode(code_points, "UTF-8").numpy().decode("UTF-8")
    logger.info(f"Transcript: {transcript}")