Exemplo n.º 1
0
    def run(args):
        """Evaluate a saved Conformer model on the configured test split.

        Loads the YAML config (``DEFAULT_YAML`` merged with ``args.config``),
        builds the speech/text featurizers, creates the test dataset
        (TFRecord-backed when ``args.tfrecords`` is set, otherwise slice-backed),
        builds the Conformer model and runs ``BaseTester`` over the dataset.

        Args:
            args: Parsed command-line namespace. Reads ``config``,
                ``saved_model``, ``tfrecords`` and ``from_weights``.

        Raises:
            ValueError: If ``args.saved_model`` is empty or unset.
        """
        # Explicit check instead of ``assert`` (asserts are stripped under
        # ``python -O``); done first so we fail before loading anything.
        if not args.saved_model:
            raise ValueError("args.saved_model must be set to the model to test")

        config = UserConfig(DEFAULT_YAML, args.config, learning=True)
        speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
        text_featurizer = TextFeaturizer(config["decoder_config"])

        tf.random.set_seed(0)

        learning_config = config["learning_config"]
        dataset_config = learning_config["dataset_config"]
        batch_size = learning_config["running_config"]["batch_size"]

        if args.tfrecords:
            # NOTE(review): augmentations are passed for the test split here but
            # not in the slice branch; confirm the dataset ignores them outside
            # training.
            test_dataset = ASRTFRecordDataset(
                dataset_config["test_paths"],
                dataset_config["tfrecords_dir"],
                speech_featurizer,
                text_featurizer,
                "test",
                augmentations=learning_config["augmentations"],
                shuffle=False).create(batch_size)
        else:
            # Read the same "test_paths" split as the TFRecord branch so both
            # dataset modes evaluate the same data.
            test_dataset = ASRSliceDataset(
                stage="test",
                speech_featurizer=speech_featurizer,
                text_featurizer=text_featurizer,
                data_paths=dataset_config["test_paths"],
                shuffle=False).create(batch_size)

        # Build the model's variables from a placeholder input shape; the layout
        # is assumed to be [batch, time, feature_dim, channels] with an
        # arbitrary time length of 50 (TODO confirm against Conformer._build).
        f, c = speech_featurizer.compute_feature_dim()
        conformer = Conformer(vocabulary_size=text_featurizer.num_classes,
                              **config["model_config"])
        conformer._build([1, 50, f, c])
        conformer.summary(line_length=100)

        conformer_tester = BaseTester(
            config=learning_config["running_config"],
            saved_path=args.saved_model,
            from_weights=args.from_weights)
        conformer_tester.compile(conformer, speech_featurizer, text_featurizer)
        conformer_tester.run(test_dataset)
Exemplo n.º 2
0
                    help="TFLite file path to be exported")

args = parser.parse_args()

# Explicit validation instead of ``assert`` (asserts are stripped under
# ``python -O``); ``parser.error`` prints the usage message and exits.
if not (args.saved and args.output):
    parser.error("the 'saved' and 'output' arguments are required")

config = UserConfig(DEFAULT_YAML, args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
text_featurizer = CharFeaturizer(config["decoder_config"])

# Build the model, restore the trained weights, and attach the featurizers
# before creating the function to export.
conformer = Conformer(**config["model_config"],
                      vocabulary_size=text_featurizer.num_classes)
conformer._build(speech_featurizer.shape)
conformer.load_weights(args.saved)
conformer.summary(line_length=150)
conformer.add_featurizers(speech_featurizer, text_featurizer)

# Convert the greedy-decoding function to TFLite. SELECT_TF_OPS lets ops that
# have no TFLite builtin kernel fall back to TensorFlow kernels.
concrete_func = conformer.make_tflite_function(
    greedy=True).get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
]
tflite_model = converter.convert()

# os.path.dirname() is "" for a bare filename, and os.makedirs("") raises, so
# only create a directory when the path has one. exist_ok=True avoids the
# check-then-create race of testing os.path.exists first.
output_dir = os.path.dirname(args.output)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
with open(args.output, "wb") as tflite_out:
    tflite_out.write(tflite_model)