예제 #1
0
def main(
    config: str = DEFAULT_YAML,
    h5: str = None,
    subwords: bool = False,
    sentence_piece: bool = False,
    output: str = None,
):
    """Convert trained RnnTransducer weights into a TFLite model.

    Args:
        config: Path of the config file, parsed with ``Config``.
        h5: Weights file loaded into the freshly built model with
            ``load_weights(..., by_name=True)``. Required.
        subwords: Forwarded to ``featurizer_helpers.prepare_featurizers``.
        sentence_piece: Forwarded to ``featurizer_helpers.prepare_featurizers``.
        output: Destination handed to ``exec_helpers.convert_tflite``. Required.

    Raises:
        ValueError: If ``h5`` or ``output`` is missing or empty.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let a missing path fail much later and less
    # clearly inside ``load_weights`` / ``convert_tflite``.
    if not h5 or not output:
        raise ValueError("both `h5` and `output` must be provided")

    tf.keras.backend.clear_session()
    tf.compat.v1.enable_control_flow_v2()

    config = Config(config)
    speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(
        config=config,
        subwords=subwords,
        sentence_piece=sentence_piece,
    )

    # ``make`` runs before ``load_weights`` (presumably to create the
    # variables the by-name load fills in).
    rnn_transducer = RnnTransducer(**config.model_config,
                                   vocabulary_size=text_featurizer.num_classes)
    rnn_transducer.make(speech_featurizer.shape)
    rnn_transducer.load_weights(h5, by_name=True)
    rnn_transducer.summary(line_length=100)
    # Featurizers are attached before conversion (presumably so the exported
    # function can featurize/decode internally -- confirm in RnnTransducer).
    rnn_transducer.add_featurizers(speech_featurizer, text_featurizer)

    exec_helpers.convert_tflite(model=rnn_transducer, output=output)
예제 #2
0
def main(
    config: str = DEFAULT_YAML,
    saved: str = None,
    mxp: bool = False,
    bs: int = None,
    sentence_piece: bool = False,
    subwords: bool = False,
    device: int = 0,
    cpu: bool = False,
    output: str = "test.tsv",
):
    """Evaluate trained RnnTransducer weights on the configured test datasets.

    Args:
        config: Path of the config file, parsed with ``Config``.
        saved: Weights file loaded with ``load_weights(..., by_name=True)``.
            Required.
        mxp: Value of TensorFlow's ``auto_mixed_precision`` optimizer option.
        bs: Test batch size; when falsy,
            ``config.learning_config.running_config.batch_size`` is used.
        sentence_piece: Forwarded to ``featurizer_helpers.prepare_featurizers``.
        subwords: Forwarded to ``featurizer_helpers.prepare_featurizers``.
        device: Device index, passed as a one-element list to
            ``env_util.setup_devices``.
        cpu: Forwarded to ``env_util.setup_devices``.
        output: Results destination passed to ``exec_helpers.run_testing``.

    Raises:
        ValueError: If ``saved`` or ``output`` is missing or empty.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let a missing path fail much later inside
    # ``load_weights`` / ``run_testing``.
    if not saved or not output:
        raise ValueError("both `saved` and `output` must be provided")

    tf.random.set_seed(0)  # fix TensorFlow's global random seed
    tf.keras.backend.clear_session()
    tf.config.optimizer.set_experimental_options({"auto_mixed_precision": mxp})
    env_util.setup_devices([device], cpu=cpu)

    config = Config(config)

    speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(
        config=config,
        subwords=subwords,
        sentence_piece=sentence_piece,
    )

    rnn_transducer = RnnTransducer(**config.model_config,
                                   vocabulary_size=text_featurizer.num_classes)
    rnn_transducer.make(speech_featurizer.shape)
    rnn_transducer.load_weights(saved, by_name=True)
    rnn_transducer.summary(line_length=100)
    rnn_transducer.add_featurizers(speech_featurizer, text_featurizer)

    test_dataset = dataset_helpers.prepare_testing_datasets(
        config=config,
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer)
    # A truthy ``bs`` overrides the batch size from the config.
    batch_size = bs or config.learning_config.running_config.batch_size
    test_data_loader = test_dataset.create(batch_size)

    exec_helpers.run_testing(model=rnn_transducer,
                             test_dataset=test_dataset,
                             test_data_loader=test_data_loader,
                             output=output)
예제 #3
0
speech_featurizer = TFSpeechFeaturizer(config.speech_config)

# Pick the text featurizer: SentencePiece takes precedence over subwords,
# and plain characters are the fallback.
if args.sentence_piece:
    logger.info("Loading SentencePiece model ...")
    text_featurizer = SentencePieceFeaturizer(config.decoder_config)
elif args.subwords:
    logger.info("Loading subwords ...")
    text_featurizer = SubwordFeaturizer(config.decoder_config)
else:
    text_featurizer = CharFeaturizer(config.decoder_config)
# Override the decoder's beam width with the value from ``args``.
text_featurizer.decoder_config.beam_width = args.beam_width

# build model
rnnt = RnnTransducer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
rnnt.make(speech_featurizer.shape)
# skip_mismatch=True tolerates saved layers whose weights do not match
# (Keras ``load_weights`` option, only meaningful with by_name=True).
rnnt.load_weights(args.saved, by_name=True, skip_mismatch=True)
rnnt.summary(line_length=120)
rnnt.add_featurizers(speech_featurizer, text_featurizer)

# Featurize the requested audio file.
signal = read_raw_audio(args.filename)
features = speech_featurizer.tf_extract(signal)
# Size of the first axis of ``features`` (presumably time frames), reduced by
# the model's time-reduction factor.
input_length = math_util.get_reduced_length(tf.shape(features)[0], rnnt.time_reduction_factor)

if args.beam_width:
    # ``[None, ...]`` prepends a size-1 axis (presumably the batch axis).
    transcript = rnnt.recognize_beam(
        data_util.create_inputs(
            inputs=features[None, ...],
            inputs_length=input_length[None, ...]
        )
    )
    # logging renders ``msg % args``; the ``%s`` placeholder is required or
    # the transcript argument triggers a formatting error inside logging.
    logger.info("Transcript: %s", transcript[0].numpy().decode("UTF-8"))
elif args.timestamp:
예제 #4
0
    logger.info("Use characters ...")
    text_featurizer = CharFeaturizer(config.decoder_config)

# Fix TensorFlow's global random seed before the dataset and model are built.
tf.random.set_seed(0)

# The remaining dataset options come straight from the config object's fields.
test_dataset = ASRSliceDataset(
    **vars(config.learning_config.test_dataset_config),
    speech_featurizer=speech_featurizer,
    text_featurizer=text_featurizer,
)

# Build the model, restore weights by layer name, then attach the featurizers.
rnn_transducer = RnnTransducer(
    **config.model_config,
    vocabulary_size=text_featurizer.num_classes,
)
rnn_transducer.make(speech_featurizer.shape)
rnn_transducer.load_weights(args.saved, by_name=True)
rnn_transducer.summary(line_length=100)
rnn_transducer.add_featurizers(speech_featurizer, text_featurizer)

# A truthy ``args.bs`` wins; otherwise fall back to the configured batch size.
batch_size = args.bs
if not batch_size:
    batch_size = config.learning_config.running_config.batch_size
test_data_loader = test_dataset.create(batch_size)

with file_util.save_file(file_util.preprocess_paths(args.output)) as filepath:
    overwrite = True
    if tf.io.gfile.exists(filepath):
        overwrite = input(
            f"Overwrite existing result file {filepath} ? (y/n): ").lower(
            ) == "y"
    if overwrite:
        results = rnn_transducer.predict(test_data_loader, verbose=1)
        logger.info(f"Saving result to {args.output} ...")
        with open(filepath, "w") as openfile:
예제 #5
0
def _convert_to_tflite(model, timestamp):
    """Convert ``model.make_tflite_function(timestamp=...)`` to a TFLite model.

    Args:
        model: Model exposing ``make_tflite_function(timestamp=...)``.
        timestamp: Forwarded to ``make_tflite_function``.

    Returns:
        Whatever ``TFLiteConverter.convert()`` returns (the serialized model).
    """
    concrete_func = model.make_tflite_function(
        timestamp=timestamp).get_concrete_function()
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [concrete_func])
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.experimental_new_converter = True
    # Allow TensorFlow ops alongside the TFLite builtins.
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    return converter.convert()


def test_streaming_transducer():
    """Smoke-test TFLite conversion and one interpreter run of RnnTransducer.

    Builds the model from ``DEFAULT_YAML`` and converts it both without and
    with timestamps. It then runs the no-timestamp model once on random
    noise, with zero initial states, and prints the first output tensor.
    """
    config = Config(DEFAULT_YAML)

    text_featurizer = CharFeaturizer(config.decoder_config)

    speech_featurizer = TFSpeechFeaturizer(config.speech_config)

    model = RnnTransducer(vocabulary_size=text_featurizer.num_classes,
                          **config.model_config)

    model.make(speech_featurizer.shape)
    model.summary(line_length=150)

    model.add_featurizers(speech_featurizer=speech_featurizer,
                          text_featurizer=text_featurizer)

    tflite_model = _convert_to_tflite(model, timestamp=False)

    print("Converted successfully with no timestamp")

    # Only checks that the timestamp variant converts; it is not executed.
    _convert_to_tflite(model, timestamp=True)

    print("Converted successfully with timestamp")

    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    signal = tf.random.normal([4000])

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    # Input 0 is resized to the length of this particular signal.
    interpreter.resize_tensor_input(input_details[0]["index"], signal.shape)
    interpreter.allocate_tensors()

    model_config = config.model_config
    interpreter.set_tensor(input_details[0]["index"], signal)
    # Input 1 is seeded with the blank token id (presumably the "previous
    # prediction" slot).
    interpreter.set_tensor(input_details[1]["index"],
                           tf.constant(text_featurizer.blank, dtype=tf.int32))
    # Inputs 2 and 3 are zero initial states for the encoder and prediction
    # networks, shaped [num_layers, 2, 1, units]. NOTE(review): the 2 is
    # presumably the LSTM (h, c) pair and the 1 the batch -- confirm.
    interpreter.set_tensor(
        input_details[2]["index"],
        tf.zeros([
            model_config["encoder_nlayers"], 2, 1,
            model_config["encoder_rnn_units"]
        ],
                 dtype=tf.float32))
    interpreter.set_tensor(
        input_details[3]["index"],
        tf.zeros([
            model_config["prediction_num_rnns"], 2, 1,
            model_config["prediction_rnn_units"]
        ],
                 dtype=tf.float32))
    interpreter.invoke()
    hyp = interpreter.get_tensor(output_details[0]["index"])

    print(hyp)
예제 #6
0
def main(
    config: str = DEFAULT_YAML,
    tfrecords: bool = False,
    sentence_piece: bool = False,
    subwords: bool = True,
    bs: int = None,
    spx: int = 1,
    metadata: str = None,
    static_length: bool = False,
    devices: list = None,
    mxp: bool = False,
    pretrained: str = None,
):
    """Train an RnnTransducer under a distribution strategy.

    Args:
        config: Path of the config file, parsed with ``Config``.
        tfrecords: Forwarded to ``dataset_helpers.prepare_training_datasets``.
        sentence_piece: Forwarded to ``featurizer_helpers.prepare_featurizers``.
        subwords: Forwarded to ``featurizer_helpers.prepare_featurizers``.
        bs: Batch size forwarded to
            ``dataset_helpers.prepare_training_data_loaders``.
        spx: Passed to ``compile`` as ``experimental_steps_per_execution``.
        metadata: Forwarded to ``dataset_helpers.prepare_training_datasets``.
        static_length: When False, both featurizers' ``reset_length()`` is
            called after the datasets are prepared.
        devices: Device list passed to ``env_util.setup_strategy``; ``None``
            (the default) means ``[0]``.
        mxp: Value of TensorFlow's ``auto_mixed_precision`` optimizer option.
        pretrained: Optional weights file, loaded with ``by_name=True`` and
            ``skip_mismatch=True`` before compiling.
    """
    # A list literal as the default would be a single object shared by every
    # call; build a fresh one per call instead.
    if devices is None:
        devices = [0]

    tf.keras.backend.clear_session()
    tf.config.optimizer.set_experimental_options({"auto_mixed_precision": mxp})
    strategy = env_util.setup_strategy(devices)

    config = Config(config)

    speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(
        config=config,
        subwords=subwords,
        sentence_piece=sentence_piece,
    )

    train_dataset, eval_dataset = dataset_helpers.prepare_training_datasets(
        config=config,
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        tfrecords=tfrecords,
        metadata=metadata,
    )

    if not static_length:
        speech_featurizer.reset_length()
        text_featurizer.reset_length()

    train_data_loader, eval_data_loader, global_batch_size = dataset_helpers.prepare_training_data_loaders(
        config=config,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        strategy=strategy,
        batch_size=bs,
    )

    # The model is built, optionally warm-started, and compiled inside the
    # strategy scope.
    with strategy.scope():
        rnn_transducer = RnnTransducer(
            **config.model_config, vocabulary_size=text_featurizer.num_classes)
        rnn_transducer.make(speech_featurizer.shape,
                            prediction_shape=text_featurizer.prepand_shape,
                            batch_size=global_batch_size)
        if pretrained:
            rnn_transducer.load_weights(pretrained,
                                        by_name=True,
                                        skip_mismatch=True)
        rnn_transducer.summary(line_length=100)
        rnn_transducer.compile(
            optimizer=config.learning_config.optimizer_config,
            experimental_steps_per_execution=spx,
            global_batch_size=global_batch_size,
            blank=text_featurizer.blank,
        )

    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            **config.learning_config.running_config.checkpoint),
        tf.keras.callbacks.experimental.BackupAndRestore(
            config.learning_config.running_config.states_dir),
        tf.keras.callbacks.TensorBoard(
            **config.learning_config.running_config.tensorboard),
    ]

    rnn_transducer.fit(
        train_data_loader,
        epochs=config.learning_config.running_config.num_epochs,
        validation_data=eval_data_loader,
        callbacks=callbacks,
        steps_per_epoch=train_dataset.total_steps,
        # Validation steps are only meaningful when there is an eval loader.
        validation_steps=eval_dataset.total_steps
        if eval_data_loader else None,
    )