# Evaluate a trained StreamingTransducer on the configured test set.
#
# Validate explicitly instead of with `assert`: asserts are stripped under
# `python -O`, which would let a missing weights path slip through to the
# later model-loading step.
if not args.saved:
    raise ValueError("a path to the saved model weights ('saved') is required")

# Keyword arguments shared by both test-dataset implementations; only the
# TFRecord variant additionally needs the TFRecord directory.
test_dataset_config = config.learning_config.dataset_config
test_dataset_kwargs = dict(
    data_paths=test_dataset_config.test_paths,
    speech_featurizer=speech_featurizer,
    text_featurizer=text_featurizer,
    stage="test",
    shuffle=False,
)
if args.tfrecords:
    test_dataset = ASRTFRecordTestDataset(
        tfrecords_dir=test_dataset_config.tfrecords_dir,
        **test_dataset_kwargs,
    )
else:
    test_dataset = ASRSliceTestDataset(**test_dataset_kwargs)

# Build the model, then restore the trained weights. _build is called with
# the featurizer's input shape before load_weights (presumably so the
# variables exist to receive the weights -- confirm in StreamingTransducer).
streaming_transducer = StreamingTransducer(
    vocabulary_size=text_featurizer.num_classes,
    **config.model_config,
)
streaming_transducer._build(speech_featurizer.shape)
streaming_transducer.load_weights(args.saved, by_name=True)
streaming_transducer.summary(line_length=150)
streaming_transducer.add_featurizers(speech_featurizer, text_featurizer)

# Run the evaluation; results are named by the user-supplied output_name.
streaming_transducer_tester = BaseTester(
    config=config.learning_config.running_config,
    output_name=args.output_name,
)
streaming_transducer_tester.compile(streaming_transducer)
streaming_transducer_tester.run(test_dataset)
# Convert a trained StreamingTransducer into a TFLite flatbuffer file.
args = parser.parse_args()

# Validate explicitly instead of with `assert` (stripped under `python -O`);
# parser.error prints the usage message and exits with a non-zero status.
if not args.saved:
    parser.error("a path to the saved model weights ('saved') is required")
if not args.output:
    parser.error("an output path for the TFLite model ('output') is required")

config = UserConfig(DEFAULT_YAML, args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
text_featurizer = CharFeaturizer(config["decoder_config"])

# Build the model, then restore the trained weights. _build is called with
# the featurizer's input shape before load_weights (presumably so the
# variables exist to receive the weights -- confirm in StreamingTransducer).
streaming_transducer = StreamingTransducer(
    **config["model_config"],
    vocabulary_size=text_featurizer.num_classes,
)
streaming_transducer._build(speech_featurizer.shape)
streaming_transducer.load_weights(args.saved)
streaming_transducer.summary(line_length=150)
streaming_transducer.add_featurizers(speech_featurizer, text_featurizer)

# Trace the greedy-decoding TFLite entry point into a concrete function.
concrete_func = streaming_transducer.make_tflite_function(greedy=True).get_concrete_function()

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
# Enable the converter's default optimizations (e.g. post-training quantization).
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# SELECT_TF_OPS lets ops without a TFLite builtin kernel fall back to TF kernels.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model = converter.convert()

# os.path.dirname() returns "" for a bare filename, and os.makedirs("") raises,
# so only create the parent directory when there is one. exist_ok=True also
# removes the race between an existence check and the creation.
output_dir = os.path.dirname(args.output)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
with open(args.output, "wb") as tflite_out:
    tflite_out.write(tflite_model)