コード例 #1
0
# Evaluate a trained StreamingTransducer on the configured test split.
# NOTE(review): this fragment relies on `args`, `config`, `speech_featurizer`
# and `text_featurizer` being defined earlier in the script.

# Validate explicitly instead of with `assert`: assert statements are removed
# when Python runs with -O, which would silently skip this check.
if not args.saved:
    raise ValueError("a path to saved model weights (args.saved) is required")

# Keyword arguments shared by both test-dataset implementations.
common_dataset_kwargs = dict(
    data_paths=config.learning_config.dataset_config.test_paths,
    speech_featurizer=speech_featurizer,
    text_featurizer=text_featurizer,
    stage="test",
    shuffle=False,
)

# Read test data from TFRecords when requested, otherwise from the raw paths.
if args.tfrecords:
    test_dataset = ASRTFRecordTestDataset(
        tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
        **common_dataset_kwargs)
else:
    test_dataset = ASRSliceTestDataset(**common_dataset_kwargs)

# Build the model for the featurizer's input shape, then restore the trained
# weights (by_name=True) and attach the featurizers used for decoding.
streaming_transducer = StreamingTransducer(
    vocabulary_size=text_featurizer.num_classes, **config.model_config)
streaming_transducer._build(speech_featurizer.shape)
streaming_transducer.load_weights(args.saved, by_name=True)
streaming_transducer.summary(line_length=150)
streaming_transducer.add_featurizers(speech_featurizer, text_featurizer)

# Run the evaluation; results are named after args.output_name.
streaming_transducer_tester = BaseTester(
    config=config.learning_config.running_config, output_name=args.output_name)
streaming_transducer_tester.compile(streaming_transducer)
streaming_transducer_tester.run(test_dataset)
コード例 #2
0
# Convert a trained StreamingTransducer into a TFLite model (greedy decoding)
# and write the serialized flatbuffer to args.output.
# NOTE(review): relies on `parser`, `DEFAULT_YAML` and the project classes
# being defined/imported earlier in the script.
args = parser.parse_args()

# Validate explicitly instead of with `assert`, which is stripped under -O.
# parser.error() prints the usage message and exits with status 2.
if not args.saved:
    parser.error("a path to saved model weights (saved) is required")
if not args.output:
    parser.error("an output path for the .tflite file (output) is required")

config = UserConfig(DEFAULT_YAML, args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
text_featurizer = CharFeaturizer(config["decoder_config"])

# Build the model for the featurizer's input shape and restore trained weights.
streaming_transducer = StreamingTransducer(
    **config["model_config"],
    vocabulary_size=text_featurizer.num_classes
)
streaming_transducer._build(speech_featurizer.shape)
streaming_transducer.load_weights(args.saved)
streaming_transducer.summary(line_length=150)
streaming_transducer.add_featurizers(speech_featurizer, text_featurizer)

# Trace the greedy-decoding inference function and convert it. SELECT_TF_OPS
# lets ops without a TFLite builtin fall back to TensorFlow kernels.
concrete_func = streaming_transducer.make_tflite_function(greedy=True).get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                       tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()

# os.path.dirname() returns "" for a bare filename such as "model.tflite", and
# os.makedirs("") raises FileNotFoundError, so only create a directory when the
# path has one. exist_ok=True also avoids the exists()/makedirs() race.
output_dir = os.path.dirname(args.output)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
with open(args.output, "wb") as tflite_out:
    tflite_out.write(tflite_model)