default=None, help="TFLite file path to be exported") args = parser.parse_args() assert args.saved and args.output config = Config(args.config, learning=True) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) # build model contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes) contextnet._build(speech_featurizer.shape) contextnet.load_weights(args.saved) contextnet.summary(line_length=150) contextnet.add_featurizers(speech_featurizer, text_featurizer) concrete_func = contextnet.make_tflite_function( greedy=True).get_concrete_function() converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) converter.optimizations = [tf.lite.Optimize.DEFAULT] converter.target_spec.supported_ops = [ tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS ] tflite_model = converter.convert() if not os.path.exists(os.path.dirname(args.output)): os.makedirs(os.path.dirname(args.output)) with open(args.output, "wb") as tflite_out:
config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) if args.subwords and os.path.exists(args.subwords): print("Loading subwords ...") text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords) else: raise ValueError("subwords must be set") # build model contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes) contextnet._build(speech_featurizer.shape) contextnet.load_weights(args.saved, by_name=True) contextnet.summary(line_length=150) contextnet.add_featurizers(speech_featurizer, text_featurizer) concrete_func = contextnet.make_tflite_function( greedy=True).get_concrete_function() converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) converter.optimizations = [tf.lite.Optimize.DEFAULT] converter.target_spec.supported_ops = [ tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS ] tflite_model = converter.convert() if not os.path.exists(os.path.dirname(args.output)): os.makedirs(os.path.dirname(args.output)) with open(args.output, "wb") as tflite_out: