def test_contextnet():
    """Build a ContextNet from ``DEFAULT_YAML`` and check it converts to TFLite.

    The model is converted twice, once from
    ``make_tflite_function(timestamp=False)`` and once from
    ``make_tflite_function(timestamp=True)``. Both conversions enable the
    default optimizations and the new converter, and allow TFLite builtin ops
    plus select TF ops. The converted flatbuffers are discarded. The test
    passes if neither ``convert()`` call raises.
    """
    config = Config(DEFAULT_YAML, learning=False)
    text_featurizer = CharFeaturizer(config.decoder_config)
    speech_featurizer = TFSpeechFeaturizer(config.speech_config)

    model = ContextNet(vocabulary_size=text_featurizer.num_classes, **config.model_config)
    model._build(speech_featurizer.shape)
    model.summary(line_length=150)
    model.add_featurizers(
        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer
    )

    def _convert(timestamp):
        """Convert ``model``'s TFLite function (with/without timestamps); return the flatbuffer bytes."""
        concrete_func = model.make_tflite_function(timestamp=timestamp).get_concrete_function()
        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.experimental_new_converter = True
        # SELECT_TF_OPS lets ops without a TFLite builtin fall back to TF kernels.
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS,
            tf.lite.OpsSet.SELECT_TF_OPS,
        ]
        return converter.convert()

    _convert(timestamp=False)
    print("Converted successfully with no timestamp")

    _convert(timestamp=True)
    print("Converted successfully with timestamp")
parser.add_argument("output", type=str, default=None, help="TFLite file path to be exported") args = parser.parse_args() assert args.saved and args.output config = Config(args.config, learning=True) speech_featurizer = TFSpeechFeaturizer(config.speech_config) text_featurizer = CharFeaturizer(config.decoder_config) # build model contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes) contextnet._build(speech_featurizer.shape) contextnet.load_weights(args.saved) contextnet.summary(line_length=150) contextnet.add_featurizers(speech_featurizer, text_featurizer) concrete_func = contextnet.make_tflite_function( greedy=True).get_concrete_function() converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) converter.optimizations = [tf.lite.Optimize.DEFAULT] converter.target_spec.supported_ops = [ tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS ] tflite_model = converter.convert() if not os.path.exists(os.path.dirname(args.output)):
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, **vars(config.learning_config.train_dataset_config) ) eval_dataset = ASRSliceDataset( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, **vars(config.learning_config.eval_dataset_config) ) contextnet_trainer = TransducerTrainerGA( config=config.learning_config.running_config, text_featurizer=text_featurizer, strategy=strategy ) with contextnet_trainer.strategy.scope(): # build model contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes) contextnet._build(speech_featurizer.shape) contextnet.summary(line_length=120) optimizer = tf.keras.optimizers.Adam( TransformerSchedule( d_model=contextnet.dmodel, warmup_steps=config.learning_config.optimizer_config["warmup_steps"], max_lr=(0.05 / math.sqrt(contextnet.dmodel)) ), beta_1=config.learning_config.optimizer_config["beta1"], beta_2=config.learning_config.optimizer_config["beta2"], epsilon=config.learning_config.optimizer_config["epsilon"] ) contextnet_trainer.compile(model=contextnet, optimizer=optimizer,
args.subwords) else: raise ValueError("subwords must be set") tf.random.set_seed(0) assert args.saved if args.tfrecords: test_dataset = ASRTFRecordDataset( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, **vars(config.learning_config.test_dataset_config)) else: test_dataset = ASRSliceDataset( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, **vars(config.learning_config.test_dataset_config)) # build model contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes) contextnet._build(speech_featurizer.shape) contextnet.load_weights(args.saved) contextnet.summary(line_length=120) contextnet.add_featurizers(speech_featurizer, text_featurizer) contextnet_tester = BaseTester(config=config.learning_config.running_config, output_name=args.output_name) contextnet_tester.compile(contextnet) contextnet_tester.run(test_dataset)
def test_contextnet():
    """Convert a ContextNet built from ``DEFAULT_YAML`` to TFLite and run it once.

    The model is converted twice, once with ``timestamp=False`` and once with
    ``timestamp=True``. Both use default optimizations, the new converter, and
    TFLite builtin plus select TF ops. The ``timestamp=False`` flatbuffer is
    then loaded into a ``tf.lite.Interpreter`` and invoked once on a random
    signal. The first output tensor is printed.
    """
    config = Config(DEFAULT_YAML, learning=False)
    text_featurizer = CharFeaturizer(config.decoder_config)
    speech_featurizer = TFSpeechFeaturizer(config.speech_config)

    model = ContextNet(vocabulary_size=text_featurizer.num_classes, **config.model_config)
    model._build(speech_featurizer.shape)
    model.summary(line_length=150)
    model.add_featurizers(speech_featurizer=speech_featurizer, text_featurizer=text_featurizer)

    def _convert(timestamp):
        """Convert ``model``'s TFLite function (with/without timestamps); return the flatbuffer bytes."""
        concrete_func = model.make_tflite_function(timestamp=timestamp).get_concrete_function()
        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.experimental_new_converter = True
        # SELECT_TF_OPS lets ops without a TFLite builtin fall back to TF kernels.
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS,
            tf.lite.OpsSet.SELECT_TF_OPS,
        ]
        return converter.convert()

    # Only the timestamp=False model is executed below; the timestamp=True
    # conversion is just checked for not raising.
    tflite = _convert(timestamp=False)
    print("Converted successfully with no timestamp")

    _convert(timestamp=True)
    print("Converted successfully with timestamp")

    tflitemodel = tf.lite.Interpreter(model_content=tflite)

    # One constant for both the generated signal and the resized input tensor,
    # so the two lengths cannot drift apart.
    num_samples = 4000
    signal = tf.random.normal([num_samples])
    input_details = tflitemodel.get_input_details()
    output_details = tflitemodel.get_output_details()
    tflitemodel.resize_tensor_input(input_details[0]["index"], [num_samples])
    tflitemodel.allocate_tensors()

    # Input 0: the raw signal.
    tflitemodel.set_tensor(input_details[0]["index"], signal)
    # Input 1: an int32 scalar set to the blank token
    # (presumably the "previous prediction" seed — TODO confirm).
    tflitemodel.set_tensor(
        input_details[1]["index"],
        tf.constant(text_featurizer.blank, dtype=tf.int32),
    )
    # Input 2: zero-initialised float32 tensor of shape
    # [prediction_num_rnns, 2, 1, prediction_rnn_units]; presumably the
    # prediction-network RNN states — NOTE(review): confirm meaning of the "2" axis.
    tflitemodel.set_tensor(
        input_details[2]["index"],
        tf.zeros(
            [
                config.model_config["prediction_num_rnns"],
                2,
                1,
                config.model_config["prediction_rnn_units"],
            ],
            dtype=tf.float32,
        ),
    )
    tflitemodel.invoke()

    hyp = tflitemodel.get_tensor(output_details[0]["index"])
    print(hyp)