コード例 #1
0
 def test_tflite_conversion(
     self,
 ):  # pylint:disable=no-self-use
     """Convert a small Tacotron2 model to TFLite and run one inference pass.

     Builds a single-speaker (``num_speakers=0``) Tacotron2 with
     ``enable_tflite=True``, converts it with the experimental converter,
     reloads the binary as a TFLite interpreter, feeds a random int32 input of
     shape ``[1, 4]`` and reads back the first two output tensors. The
     generated binary is removed in a ``finally`` clause so a failing
     conversion or inference does not leave it behind in the working
     directory.
     """
     tflite_path = "test_tacotron2.tflite"
     model = Tacotron2(
         num_chars=24,
         num_speakers=0,
         r=3,
         out_channels=80,
         decoder_output_dim=80,
         attn_type="original",
         attn_win=False,
         attn_norm="sigmoid",
         prenet_type="original",
         prenet_dropout=True,
         forward_attn=False,
         trans_agent=False,
         forward_attn_mask=False,
         location_attn=True,
         attn_K=0,
         separate_stopnet=True,
         bidirectional_decoder=False,
         enable_tflite=True,
     )
     model.build_inference()
     try:
         convert_tacotron2_to_tflite(model, output_path=tflite_path, experimental_converter=True)
         # init tflite model
         tflite_model = load_tflite_model(tflite_path)
         # fake input: random int32 ids in [0, 10), shape [1, 4]
         inputs = tf.random.uniform([1, 4], maxval=10, dtype=tf.int32)  # pylint:disable=unexpected-keyword-arg
         # get input and output details
         input_details = tflite_model.get_input_details()
         output_details = tflite_model.get_output_details()
         # resize the input tensor to the fake input's shape before allocating buffers
         tflite_model.resize_tensor_input(
             input_details[0]["index"], inputs.shape
         )  # pylint:disable=unexpected-keyword-arg
         tflite_model.allocate_tensors()
         tflite_model.set_tensor(input_details[0]["index"], inputs)
         # run the tflite_model
         tflite_model.invoke()
         # collect outputs; reading them checks that the output tensors are accessible
         decoder_output = tflite_model.get_tensor(output_details[0]["index"])  # pylint:disable=unused-variable
         postnet_output = tflite_model.get_tensor(output_details[1]["index"])  # pylint:disable=unused-variable
     finally:
         # remove tflite binary; it may not exist if the conversion itself failed
         if os.path.exists(tflite_path):
             os.remove(tflite_path)
コード例 #2
0
from TTS.tts.tf.utils.generic_utils import setup_model
from TTS.tts.tf.utils.io import load_checkpoint
from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite

parser = argparse.ArgumentParser(
    description='Convert a TF Tacotron2 checkpoint to a TFLite binary.')
parser.add_argument('--tf_model',
                    type=str,
                    required=True,
                    help='Path to the TF Tacotron2 checkpoint to be converted to TFLite.')
parser.add_argument('--config_path',
                    type=str,
                    required=True,
                    help='Path to the config file of the model.')
parser.add_argument('--output_path',
                    type=str,
                    required=True,
                    help='Path to the TFLite output binary.')
parser.add_argument('--max_decoder_steps',
                    type=int,
                    default=1000,
                    help='Maximum number of decoder steps set on the model before conversion.')
args = parser.parse_args()

# Load the model configuration
CONFIG = load_config(args.config_path)

# load the model
c = CONFIG
# single-speaker conversion: the model is set up with num_speakers=0
num_speakers = 0
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model = setup_model(num_chars, num_speakers, c, enable_tflite=True)
model.build_inference()
model = load_checkpoint(model, args.tf_model)
model.decoder.set_max_decoder_steps(args.max_decoder_steps)

# create tflite model (the binary is written to args.output_path)
tflite_model = convert_tacotron2_to_tflite(model, output_path=args.output_path)