def test_jasper():
    """Check that a Jasper model can be converted to TFLite in both decode modes.

    A model is created from ``DEFAULT_YAML``, built with the speech
    featurizer's shape and given both featurizers. Its TFLite function is
    then converted twice, first with ``greedy=False`` (reported as
    "beam search") and then with ``greedy=True``. Nothing is asserted
    explicitly: the test fails only if one of the calls raises.
    """
    cfg = Config(DEFAULT_YAML)

    chars = CharFeaturizer(cfg.decoder_config)
    speech = TFSpeechFeaturizer(cfg.speech_config)

    model = Jasper(vocabulary_size=chars.num_classes, **cfg.model_config)
    model._build(speech.shape)
    model.summary(line_length=150)
    model.add_featurizers(speech_featurizer=speech, text_featurizer=chars)

    # (greedy flag, message printed once that conversion has succeeded)
    conversion_runs = (
        (False, "Converted successfully with beam search"),
        (True, "Converted successfully with greedy"),
    )
    for use_greedy, done_message in conversion_runs:
        tflite_fn = model.make_tflite_function(greedy=use_greedy)
        concrete_fn = tflite_fn.get_concrete_function()

        converter = tf.lite.TFLiteConverter.from_concrete_functions(
            [concrete_fn])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.experimental_new_converter = True
        # Allow select TensorFlow ops in addition to the TFLite builtins.
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS,
            tf.lite.OpsSet.SELECT_TF_OPS,
        ]
        converter.convert()

        print(done_message)
Example #2
0
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        **vars(config.learning_config.eval_dataset_config))
else:
    train_dataset = ASRSliceDataset(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        **vars(config.learning_config.train_dataset_config))
    eval_dataset = ASRSliceDataset(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        **vars(config.learning_config.eval_dataset_config))

# Create the CTC trainer from the text featurizer and the running
# configuration taken from the learning config.
ctc_trainer = CTCTrainer(text_featurizer,
                         config.learning_config.running_config)
# Build the Jasper model inside the trainer's strategy scope
# (presumably a tf.distribute strategy -- confirm against CTCTrainer).
with ctc_trainer.strategy.scope():
    # The output vocabulary size is the text featurizer's number of classes;
    # every other constructor argument comes from config.model_config.
    jasper = Jasper(**config.model_config,
                    vocabulary_size=text_featurizer.num_classes)
    # Build with the speech featurizer's shape, then emit the model summary.
    jasper._build(speech_featurizer.shape)
    jasper.summary(line_length=120)
# Compile: hand the model and the optimizer config to the trainer;
# max_to_keep is taken from args.max_ckpts.
ctc_trainer.compile(jasper,
                    config.learning_config.optimizer_config,
                    max_to_keep=args.max_ckpts)

# Run training with the train/eval datasets prepared earlier in the script,
# using the batch sizes given by args.tbs (train) and args.ebs (eval).
ctc_trainer.fit(train_dataset,
                eval_dataset,
                train_bs=args.tbs,
                eval_bs=args.ebs)