Code Example #1
def parse(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor):
    """
    Returns:
        A pair of dicts: the inputs built by data_util.create_inputs and
        the labels built by data_util.create_labels.
    """
    data = self.tf_preprocess(path, audio, indices) if self.use_tf else self.preprocess(path, audio, indices)
    _, features, input_length, label, label_length, prediction, prediction_length = data
    return (
        data_util.create_inputs(
            inputs=features, inputs_length=input_length, predictions=prediction, predictions_length=prediction_length
        ),
        data_util.create_labels(labels=label, labels_length=label_length),
    )
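The data_util.create_inputs / create_labels helpers used here pack the tensors into keyword-addressable structures that the Keras model consumes. A minimal sketch of what they plausibly look like, inferred only from how they are called in these examples (the real TensorFlowASR implementation may differ):

def create_inputs(inputs, inputs_length, predictions=None, predictions_length=None):
    # Dict keyed the way the model's call() appears to expect (assumed).
    data = {"inputs": inputs, "inputs_length": inputs_length}
    if predictions is not None:
        data["predictions"] = predictions
        data["predictions_length"] = predictions_length
    return data

def create_labels(labels, labels_length):
    # Matching structure for the label side (assumed).
    return {"labels": labels, "labels_length": labels_length}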
Code Example #2
    def process(self, dataset, batch_size):
        dataset = dataset.map(self.parse, num_parallel_calls=AUTOTUNE)
        self.total_steps = math_util.get_num_batches(self.total_steps, batch_size, drop_remainders=self.drop_remainder)

        if self.cache:
            dataset = dataset.cache()

        if self.shuffle:
            dataset = dataset.shuffle(self.buffer_size, reshuffle_each_iteration=True)

        if self.indefinite and self.total_steps:
            dataset = dataset.repeat()

        # PADDED BATCH the dataset
        dataset = dataset.padded_batch(
            batch_size=batch_size,
            padded_shapes=(
                data_util.create_inputs(
                    inputs=tf.TensorShape(self.speech_featurizer.shape),
                    inputs_length=tf.TensorShape([]),
                    predictions=tf.TensorShape(self.text_featurizer.prepand_shape),
                    predictions_length=tf.TensorShape([]),
                ),
                data_util.create_labels(labels=tf.TensorShape(self.text_featurizer.shape), labels_length=tf.TensorShape([])),
            ),
            padding_values=(
                data_util.create_inputs(
                    inputs=0.0, inputs_length=0, predictions=self.text_featurizer.blank, predictions_length=0
                ),
                data_util.create_labels(labels=self.text_featurizer.blank, labels_length=0),
            ),
            drop_remainder=self.drop_remainder,
        )

        # PREFETCH to improve input pipeline throughput
        dataset = dataset.prefetch(AUTOTUNE)
        return dataset
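A hypothetical usage sketch for process(); the loader and raw_dataset names are assumptions, not part of the original snippet, but the call shape follows from the method signature above:

# Hypothetical wiring: `loader` is an instance of the dataset class above,
# `raw_dataset` yields (path, audio, indices) tuples ready for parse().
train_dataset = loader.process(raw_dataset, batch_size=4)
for inputs, labels in train_dataset.take(1):
    print(inputs["inputs"].shape)   # padded feature batch
    print(labels["labels"].shape)   # padded label batch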
Code Example #3
def make(
    self,
    input_shape,
    batch_size=None,
):
    inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, dtype=tf.float32)
    inputs_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32)
    # Dummy forward pass on symbolic inputs so every layer builds its weights
    self(
        data_util.create_inputs(
            inputs=inputs,
            inputs_length=inputs_length,
        ),
        training=False,
    )
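Calling the model once on symbolic tf.keras.Input tensors, as make() does here, builds every layer's weights; that is why Examples #4 and #5 below can call load_weights and summary immediately after make(). A hedged illustration, assuming an 80-bin log-mel featurizer so that speech_featurizer.shape is [None, 80, 1]:

# Assumed feature shape [time, feature_bins, channels]; None keeps the time
# axis dynamic. `model` stands in for any of the models below (rnnt, conformer).
model.make(input_shape=[None, 80, 1], batch_size=None)
model.summary(line_length=120)  # weights exist now, so summary can print them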
Code Example #4
File: rnn_transducer.py  Project: ck196/TensorFlowASR
# build model
rnnt = RnnTransducer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
rnnt.make(speech_featurizer.shape)
rnnt.load_weights(args.saved, by_name=True, skip_mismatch=True)
rnnt.summary(line_length=120)
rnnt.add_featurizers(speech_featurizer, text_featurizer)

signal = read_raw_audio(args.filename)
features = speech_featurizer.tf_extract(signal)
input_length = math_util.get_reduced_length(tf.shape(features)[0], rnnt.time_reduction_factor)

if args.beam_width:
    transcript = rnnt.recognize_beam(
        data_util.create_inputs(
            inputs=features[None, ...],
            inputs_length=input_length[None, ...]
        )
    )
    logger.info("Transcript:", transcript[0].numpy().decode("UTF-8"))
elif args.timestamp:
    transcript, stime, etime, _, _, _ = rnnt.recognize_tflite_with_timestamp(
        signal=signal,
        predicted=tf.constant(text_featurizer.blank, dtype=tf.int32),
        encoder_states=rnnt.encoder.get_initial_state(),
        prediction_states=rnnt.predict_net.get_initial_state()
    )
    logger.info("Transcript:", transcript)
    logger.info("Start time:", stime)
    logger.info("End time:", etime)
else:
    transcript = rnnt.recognize(
        data_util.create_inputs(
            inputs=features[None, ...],
            inputs_length=input_length[None, ...],
        )
    )
    logger.info(f"Transcript: {transcript[0].numpy().decode('UTF-8')}")
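math_util.get_reduced_length above compensates for the temporal downsampling the encoder applies (time_reduction_factor). A plausible stand-in, assuming ceiling division; the library's exact rounding may differ:

import tensorflow as tf

# Hypothetical equivalent of math_util.get_reduced_length: the encoder
# shrinks the time axis by `factor`, so the valid length shrinks the same way.
def get_reduced_length(length, factor):
    return tf.cast(tf.math.ceil(tf.cast(length, tf.float32) / factor), tf.int32)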
Code Example #5
text_featurizer = CharFeaturizer(config.decoder_config)
text_featurizer.decoder_config.beam_width = args.beam_width

# build model
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
conformer.make(speech_featurizer.shape)
conformer.load_weights(args.saved, by_name=True, skip_mismatch=True)
conformer.summary(line_length=120)
conformer.add_featurizers(speech_featurizer, text_featurizer)

signal = read_raw_audio(args.filename)
features = speech_featurizer.tf_extract(signal)
input_length = tf.shape(features)[0]

if args.beam_width:
    inputs = create_inputs(features[None, ...], input_length[None, ...])
    transcript = conformer.recognize_beam(inputs)
    logger.info(f"Transcript: {transcript[0].numpy().decode('UTF-8')}")
elif args.timestamp:
    transcript, stime, etime, _, _ = conformer.recognize_tflite_with_timestamp(
        signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state()
    )
    logger.info(f"Transcript: {transcript}")
    logger.info(f"Start time: {stime}")
    logger.info(f"End time: {etime}")
else:
    code_points, _, _ = conformer.recognize_tflite(
        signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state()
    )
    transcript = tf.strings.unicode_encode(code_points, "UTF-8").numpy().decode("UTF-8")
    logger.info(f"Transcript: {transcript}")