def test_contextnet():
    config = Config(DEFAULT_YAML, learning=False)
    text_featurizer = CharFeaturizer(config.decoder_config)
    speech_featurizer = TFSpeechFeaturizer(config.speech_config)

    model = ContextNet(vocabulary_size=text_featurizer.num_classes, **config.model_config)
    model._build(speech_featurizer.shape)
    model.summary(line_length=150)
    model.add_featurizers(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer
    )

    concrete_func = model.make_tflite_function(timestamp=False).get_concrete_function()
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.experimental_new_converter = True
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS
    ]
    converter.convert()
    print("Converted successfully with no timestamp")

    concrete_func = model.make_tflite_function(timestamp=True).get_concrete_function()
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.experimental_new_converter = True
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS
    ]
    converter.convert()
    print("Converted successfully with timestamp")
import os

import tensorflow as tf

# Config, TFSpeechFeaturizer, CharFeaturizer and ContextNet are assumed to be imported
# from the project's own modules (e.g. the tensorflow_asr package), as in the original file.
# `parser` is assumed to be an argparse.ArgumentParser defined earlier; a sketch follows below.
args = parser.parse_args()
assert args.saved and args.output

config = Config(args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
text_featurizer = CharFeaturizer(config.decoder_config)

# build model
contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes)
contextnet._build(speech_featurizer.shape)
contextnet.load_weights(args.saved)
contextnet.summary(line_length=150)
contextnet.add_featurizers(speech_featurizer, text_featurizer)

# export the greedy-decoding function and convert it to a TFLite flatbuffer
concrete_func = contextnet.make_tflite_function(greedy=True).get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS
]
tflite_model = converter.convert()

# write the converted flatbuffer to the requested output path
if not os.path.exists(os.path.dirname(args.output)):
    os.makedirs(os.path.dirname(args.output))
with open(args.output, "wb") as tflite_out:
    tflite_out.write(tflite_model)
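# The script above uses `parser` without showing its definition. Below is a minimal sketch
# of how it could be set up, with flag names inferred from the `args.config`, `args.saved`
# and `args.output` accesses above (hypothetical, not necessarily the project's actual
# argument definitions):
import argparse

parser = argparse.ArgumentParser(description="Convert a trained ContextNet model to TFLite")
parser.add_argument("--config", type=str, default=None, help="Path to the model YAML config")
parser.add_argument("--saved", type=str, default=None, help="Path to the saved model weights")
parser.add_argument("--output", type=str, default=None, help="Path to write the .tflite file")

# Example invocation (script name and paths are placeholders):
#   python convert_contextnet_tflite.py --config config.yml --saved latest.h5 --output contextnet.tflite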
def test_contextnet():
    config = Config(DEFAULT_YAML, learning=False)
    text_featurizer = CharFeaturizer(config.decoder_config)
    speech_featurizer = TFSpeechFeaturizer(config.speech_config)

    model = ContextNet(vocabulary_size=text_featurizer.num_classes, **config.model_config)
    model._build(speech_featurizer.shape)
    model.summary(line_length=150)
    model.add_featurizers(speech_featurizer=speech_featurizer, text_featurizer=text_featurizer)

    # convert without timestamp output
    concrete_func = model.make_tflite_function(timestamp=False).get_concrete_function()
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.experimental_new_converter = True
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS
    ]
    tflite = converter.convert()
    print("Converted successfully with no timestamp")

    # convert with timestamp output
    concrete_func = model.make_tflite_function(timestamp=True).get_concrete_function()
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.experimental_new_converter = True
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS
    ]
    converter.convert()
    print("Converted successfully with timestamp")

    # run the no-timestamp flatbuffer through the TFLite interpreter on a random signal
    tflitemodel = tf.lite.Interpreter(model_content=tflite)
    signal = tf.random.normal([4000])

    input_details = tflitemodel.get_input_details()
    output_details = tflitemodel.get_output_details()
    tflitemodel.resize_tensor_input(input_details[0]["index"], [4000])
    tflitemodel.allocate_tensors()
    tflitemodel.set_tensor(input_details[0]["index"], signal)
    tflitemodel.set_tensor(
        input_details[1]["index"],
        tf.constant(text_featurizer.blank, dtype=tf.int32)
    )
    # initial prediction-network states: zeros of shape [num_rnns, 2, 1, rnn_units]
    tflitemodel.set_tensor(
        input_details[2]["index"],
        tf.zeros(
            [
                config.model_config["prediction_num_rnns"],
                2,
                1,
                config.model_config["prediction_rnn_units"]
            ],
            dtype=tf.float32
        )
    )
    tflitemodel.invoke()
    hyp = tflitemodel.get_tensor(output_details[0]["index"])

    print(hyp)
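# A minimal follow-up sketch, assuming (as in some TensorFlowASR TFLite demos) that the
# exported greedy function returns the hypothesis as unicode code points rather than raw
# label indices; under that assumption, a hypothetical helper like this could turn the
# `hyp` array printed above into a readable transcript:
def decode_unicode_points(hyp) -> str:
    # treat each entry as a unicode code point and join them into a string
    return "".join(chr(int(u)) for u in hyp)

# Usage inside test_contextnet(), right after `hyp` is read from the interpreter:
#   print(decode_unicode_points(hyp))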