コード例 #1
0
                    default=None,
                    help="TFLite file path to be exported")

args = parser.parse_args()

# Validate explicitly rather than with `assert`: asserts are stripped under
# `python -O`, which would let the script continue with missing paths.
if not (args.saved and args.output):
    parser.error("both 'saved' and 'output' must be provided")

config = Config(args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
text_featurizer = CharFeaturizer(config.decoder_config)

# Build the model for the featurizer's input shape, restore the trained
# weights from `args.saved`, then attach the speech/text featurizers
# (presumably needed by make_tflite_function below — confirm in ContextNet).
contextnet = ContextNet(**config.model_config,
                        vocabulary_size=text_featurizer.num_classes)
contextnet._build(speech_featurizer.shape)
contextnet.load_weights(args.saved)
contextnet.summary(line_length=150)
contextnet.add_featurizers(speech_featurizer, text_featurizer)

# Convert the greedy-decoding function to TFLite. SELECT_TF_OPS permits
# TensorFlow ops that have no TFLite builtin equivalent.
concrete_func = contextnet.make_tflite_function(
    greedy=True).get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
]
tflite_model = converter.convert()

# os.path.dirname() is "" for a bare filename like "model.tflite", and
# os.makedirs("") raises FileNotFoundError, so only create a directory when
# the output path actually contains one. exist_ok=True avoids the
# check-then-create race of a separate os.path.exists() test.
output_dir = os.path.dirname(args.output)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
with open(args.output, "wb") as tflite_out:
コード例 #2
0
config = Config(args.config)
speech_featurizer = TFSpeechFeaturizer(config.speech_config)

# A subword vocabulary file is required to build the text featurizer.
# Report "not set" and "file missing" separately so the error is actionable;
# both remain ValueError as in the original.
if not args.subwords:
    raise ValueError("subwords must be set")
if not os.path.exists(args.subwords):
    raise ValueError(f"subwords file not found: {args.subwords}")
print("Loading subwords ...")
text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config,
                                                   args.subwords)

# Build the model for the featurizer's input shape, restore weights by layer
# name from `args.saved`, then attach the speech/text featurizers
# (presumably needed by make_tflite_function below — confirm in ContextNet).
contextnet = ContextNet(**config.model_config,
                        vocabulary_size=text_featurizer.num_classes)
contextnet._build(speech_featurizer.shape)
contextnet.load_weights(args.saved, by_name=True)
contextnet.summary(line_length=150)
contextnet.add_featurizers(speech_featurizer, text_featurizer)

# Convert the greedy-decoding function to TFLite. SELECT_TF_OPS permits
# TensorFlow ops that have no TFLite builtin equivalent.
concrete_func = contextnet.make_tflite_function(
    greedy=True).get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
]
tflite_model = converter.convert()

# os.path.dirname() is "" for a bare filename like "model.tflite", and
# os.makedirs("") raises FileNotFoundError, so only create a directory when
# the output path actually contains one. exist_ok=True avoids the
# check-then-create race of a separate os.path.exists() test.
output_dir = os.path.dirname(args.output)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
with open(args.output, "wb") as tflite_out: