def main(
    config: str = DEFAULT_YAML,
    h5: str = None,
    subwords: bool = False,
    sentence_piece: bool = False,
    output: str = None,
):
    """Build a Conformer, load its weights and convert it to TFLite.

    Args:
        config: Value handed to ``Config`` (presumably a YAML path, given
            the ``DEFAULT_YAML`` default).
        h5: Weights file handed to ``Conformer.load_weights``. Required.
        subwords: Forwarded to ``featurizer_helpers.prepare_featurizers``.
        sentence_piece: Forwarded to
            ``featurizer_helpers.prepare_featurizers``.
        output: Destination handed to ``exec_helpers.convert_tflite``.
            Required.

    Raises:
        ValueError: If ``h5`` or ``output`` is missing or empty.
    """
    # Validate explicitly rather than with ``assert``: assertions are removed
    # under ``python -O``, which would let ``None`` reach ``load_weights``.
    if not (h5 and output):
        raise ValueError("both 'h5' and 'output' must be provided")

    tf.keras.backend.clear_session()
    tf.compat.v1.enable_control_flow_v2()

    config = Config(config)
    speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(
        config=config,
        subwords=subwords,
        sentence_piece=sentence_piece,
    )

    # The vocabulary size comes from the text featurizer chosen above.
    conformer = Conformer(**config.model_config,
                          vocabulary_size=text_featurizer.num_classes)
    conformer.make(speech_featurizer.shape)
    # Weights are matched by layer name.
    conformer.load_weights(h5, by_name=True)
    conformer.summary(line_length=100)
    conformer.add_featurizers(speech_featurizer, text_featurizer)

    exec_helpers.convert_tflite(model=conformer, output=output)
Beispiel #2
0
def main(
    config: str = DEFAULT_YAML,
    h5: str = None,
    sentence_piece: bool = False,
    subwords: bool = False,
    output_dir: str = None,
):
    """Build a Conformer, load its weights and export it as a SavedModel.

    The model is wrapped in a ``tf.Module`` whose ``pred`` function is saved
    as the serving signature.

    Args:
        config: Value handed to ``Config`` (presumably a YAML path, given
            the ``DEFAULT_YAML`` default).
        h5: Weights file handed to ``Conformer.load_weights``. Required.
        sentence_piece: Forwarded to
            ``featurizer_helpers.prepare_featurizers``.
        subwords: Forwarded to ``featurizer_helpers.prepare_featurizers``.
        output_dir: Export directory handed to ``tf.saved_model.save``.
            Required.

    Raises:
        ValueError: If ``h5`` or ``output_dir`` is missing or empty.
    """
    # Validate explicitly rather than with ``assert``: assertions are removed
    # under ``python -O``, which would let ``None`` reach ``load_weights``.
    if not (h5 and output_dir):
        raise ValueError("both 'h5' and 'output_dir' must be provided")

    config = Config(config)
    tf.random.set_seed(0)
    tf.keras.backend.clear_session()

    speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(
        config=config,
        subwords=subwords,
        sentence_piece=sentence_piece,
    )

    # build model
    conformer = Conformer(**config.model_config,
                          vocabulary_size=text_featurizer.num_classes)
    conformer.make(speech_featurizer.shape)
    conformer.load_weights(h5, by_name=True)
    conformer.summary(line_length=100)
    conformer.add_featurizers(speech_featurizer, text_featurizer)

    class ConformerModule(tf.Module):
        """``tf.Module`` wrapper exposing greedy decoding as ``pred``."""

        def __init__(self, model: Conformer, name=None):
            super().__init__(name=name)
            self.model = model
            # Prediction-network dimensions, read from the enclosing config,
            # are used to size the zero state built in ``pred``.
            self.num_rnns = config.model_config["prediction_num_rnns"]
            self.rnn_units = config.model_config["prediction_rnn_units"]
            # Two state tensors per layer for "lstm", one for any other type.
            self.rnn_nstates = 2 if config.model_config[
                "prediction_rnn_type"] == "lstm" else 1

        @tf.function(
            input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
        def pred(self, signal):
            """Decode a rank-1 float32 signal and return the transcript.

            The return value is whatever
            ``text_featurizer.indices2upoints`` produces for the greedy
            hypothesis.
            """
            # Decoding starts from token 0 and all-zero states.
            # NOTE(review): 0 is presumably the blank index -- confirm.
            predicted = tf.constant(0, dtype=tf.int32)
            states = tf.zeros(
                [self.num_rnns, self.rnn_nstates, 1, self.rnn_units],
                dtype=tf.float32)
            features = self.model.speech_featurizer.tf_extract(signal)
            encoded = self.model.encoder_inference(features)
            hypothesis = self.model._perform_greedy(encoded,
                                                    tf.shape(encoded)[0],
                                                    predicted,
                                                    states,
                                                    tflite=False)
            transcript = self.model.text_featurizer.indices2upoints(
                hypothesis.prediction)
            return transcript

    module = ConformerModule(model=conformer)
    # ``pred`` has a fixed input_signature, so its concrete function can be
    # obtained without example inputs.
    tf.saved_model.save(module,
                        export_dir=output_dir,
                        signatures=module.pred.get_concrete_function())
Beispiel #3
0
def main(
    config: str = DEFAULT_YAML,
    saved: str = None,
    mxp: bool = False,
    bs: int = None,
    sentence_piece: bool = False,
    subwords: bool = False,
    device: int = 0,
    cpu: bool = False,
    output: str = "test.tsv",
):
    """Build a Conformer, load its weights and run it on the test datasets.

    Args:
        config: Value handed to ``Config`` (presumably a YAML path, given
            the ``DEFAULT_YAML`` default).
        saved: Weights file handed to ``Conformer.load_weights``. Required.
        mxp: Value of TensorFlow's ``auto_mixed_precision`` optimizer option.
        bs: Batch size; when falsy, the batch size from
            ``config.learning_config.running_config`` is used.
        sentence_piece: Forwarded to
            ``featurizer_helpers.prepare_featurizers``.
        subwords: Forwarded to ``featurizer_helpers.prepare_featurizers``.
        device: Device index handed to ``env_util.setup_devices``.
        cpu: Forwarded to ``env_util.setup_devices``.
        output: Output handed to ``exec_helpers.run_testing``.

    Raises:
        ValueError: If ``saved`` or ``output`` is missing or empty.
    """
    # Validate explicitly rather than with ``assert``: assertions are removed
    # under ``python -O``, which would let ``None`` reach ``load_weights``.
    if not (saved and output):
        raise ValueError("both 'saved' and 'output' must be provided")

    tf.random.set_seed(0)
    tf.keras.backend.clear_session()
    tf.config.optimizer.set_experimental_options({"auto_mixed_precision": mxp})
    env_util.setup_devices([device], cpu=cpu)

    config = Config(config)

    speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(
        config=config,
        subwords=subwords,
        sentence_piece=sentence_piece,
    )

    conformer = Conformer(**config.model_config,
                          vocabulary_size=text_featurizer.num_classes)
    conformer.make(speech_featurizer.shape)
    conformer.load_weights(saved, by_name=True)
    conformer.summary(line_length=100)
    conformer.add_featurizers(speech_featurizer, text_featurizer)

    test_dataset = dataset_helpers.prepare_testing_datasets(
        config=config,
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer)
    # An explicit ``bs`` takes priority; ``None`` or 0 falls back to the
    # configured batch size.
    batch_size = bs or config.learning_config.running_config.batch_size
    test_data_loader = test_dataset.create(batch_size)

    exec_helpers.run_testing(model=conformer,
                             test_dataset=test_dataset,
                             test_data_loader=test_data_loader,
                             output=output)
Beispiel #4
0
        config.decoder_config, args.subwords)
elif args.subwords and os.path.exists(args.subwords):
    print("Loading subwords ...")
    text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config,
                                                       args.subwords)
else:
    text_featurizer = CharFeaturizer(config.decoder_config)
text_featurizer.decoder_config.beam_width = args.beam_width

# build model
# The vocabulary size comes from whichever text featurizer was selected
# above; weights are then loaded from ``args.saved`` by layer name.
conformer = Conformer(**config.model_config,
                      vocabulary_size=text_featurizer.num_classes)
conformer.make(speech_featurizer.shape)
# NOTE(review): skip_mismatch=True presumably lets layers whose weights do
# not match be skipped instead of raising (Keras semantics) -- confirm that
# silently skipping is intended here.
conformer.load_weights(args.saved, by_name=True, skip_mismatch=True)
conformer.summary(line_length=120)
conformer.add_featurizers(speech_featurizer, text_featurizer)

# Read the audio file and extract features from the raw signal.
signal = read_raw_audio(args.filename)
features = speech_featurizer.tf_extract(signal)
# Length derived from the first dimension of ``features`` and the model's
# time reduction factor.
input_length = math_util.get_reduced_length(
    tf.shape(features)[0], conformer.time_reduction_factor)

if args.beam_width:
    # ``[None, ...]`` prepends an axis of size 1 to both tensors
    # (presumably a batch dimension of one); the first result is decoded
    # from UTF-8 bytes for printing.
    transcript = conformer.recognize_beam(features[None, ...],
                                          input_length[None, ...])
    print("Transcript:", transcript[0].numpy().decode("UTF-8"))
elif args.timestamp:
    # This path takes the raw signal (not the extracted features), the blank
    # index as an int32 constant, and the prediction network's initial
    # state. ``stime``/``etime`` are unpacked but not used in this span.
    transcript, stime, etime, _, _ = conformer.recognize_tflite_with_timestamp(
        signal, tf.constant(text_featurizer.blank, dtype=tf.int32),
        conformer.predict_net.get_initial_state())
    print("Transcript:", transcript)
Beispiel #5
0
def test_conformer():
    """Smoke-test TFLite conversion and one interpreter run of a Conformer.

    The model is built from ``DEFAULT_YAML`` with a character featurizer and
    converted twice (without and with timestamps). The non-timestamp
    flatbuffer is then run once on a random 4000-sample signal and the
    first output tensor is printed.
    """
    config = Config(DEFAULT_YAML)
    text_featurizer = CharFeaturizer(config.decoder_config)
    speech_featurizer = TFSpeechFeaturizer(config.speech_config)

    model = Conformer(vocabulary_size=text_featurizer.num_classes,
                      **config.model_config)
    model.make(speech_featurizer.shape)
    model.summary(line_length=150)
    model.add_featurizers(speech_featurizer=speech_featurizer,
                          text_featurizer=text_featurizer)

    def _to_tflite(with_timestamp):
        # Shared converter setup for both the plain and timestamp variants.
        func = model.make_tflite_function(
            timestamp=with_timestamp).get_concrete_function()
        conv = tf.lite.TFLiteConverter.from_concrete_functions([func])
        conv.optimizations = [tf.lite.Optimize.DEFAULT]
        conv.experimental_new_converter = True
        conv.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS,
            tf.lite.OpsSet.SELECT_TF_OPS,
        ]
        return conv.convert()

    tflite = _to_tflite(False)
    print("Converted successfully with no timestamp")

    # The timestamp variant is only checked for convertibility; its result
    # is discarded.
    _to_tflite(True)
    print("Converted successfully with timestamp")

    interpreter = tf.lite.Interpreter(model_content=tflite)
    signal = tf.random.normal([4000])

    inputs = interpreter.get_input_details()
    outputs = interpreter.get_output_details()
    signal_index = inputs[0]["index"]
    token_index = inputs[1]["index"]
    states_index = inputs[2]["index"]

    # Resize the signal input before allocating tensors.
    interpreter.resize_tensor_input(signal_index, [4000])
    interpreter.allocate_tensors()

    blank_token = tf.constant(text_featurizer.blank, dtype=tf.int32)
    state_shape = [
        config.model_config["prediction_num_rnns"],
        2,
        1,
        config.model_config["prediction_rnn_units"],
    ]
    zero_states = tf.zeros(state_shape, dtype=tf.float32)

    interpreter.set_tensor(signal_index, signal)
    interpreter.set_tensor(token_index, blank_token)
    interpreter.set_tensor(states_index, zero_states)
    interpreter.invoke()
    hyp = interpreter.get_tensor(outputs[0]["index"])

    print(hyp)