示例#1
0
    def test_forward_attention(self):
        """Run Tacotron2 with ``forward_attn=True`` through a training and an inference pass.

        The outputs of the training pass are checked against the shapes of the
        dummy inputs; the inference pass is only executed, its result is not
        inspected.
        """
        (
            text_ids,
            text_lengths,
            mel_target,
            _mel_postnet_target,
            frame_counts,
            stop_flags,
            _speaker_idx,
        ) = self.generate_dummy_inputs()

        # set channel 0 to 1.0 from each length onwards, for every row
        for frame_count in frame_counts:
            first_stop = int(frame_count.item())
            stop_flags[:, first_stop:, 0] = 1.0

        # regroup the flags by the reduction factor c.r and mark a group as soon
        # as any of its members is set
        n_rows = text_ids.shape[0]
        grouped_flags = stop_flags.view(n_rows, stop_flags.size(1) // c.r, -1)
        group_is_set = grouped_flags.sum(2) > 0.0
        stop_flags = group_is_set.unsqueeze(2).float().squeeze()

        tacotron = Tacotron2(num_chars=24, r=c.r, num_speakers=5, forward_attn=True)

        # training pass
        train_out = tacotron(text_ids, text_lengths, mel_target, training=True)

        # outputs 0 and 1 must match the mel target shape exactly
        assert np.all(train_out[0].shape == mel_target.shape)
        assert np.all(train_out[1].shape == mel_target.shape)

        # output 2: last axis matches the input sequence axis
        assert train_out[2].shape[2] == text_ids.shape[1]

        # outputs 2 and 3: axis 1 is the mel axis divided by the decoder's r
        reduced_steps = mel_target.shape[1] // tacotron.decoder.r
        assert train_out[2].shape[1] == reduced_steps
        assert train_out[3].shape[1] == reduced_steps

        # inference pass
        tacotron(text_ids, training=False)
示例#2
0
 def test_tflite_conversion(
     self,
 ):  # pylint:disable=no-self-use
     """Convert a Tacotron2 model to TFLite and run one inference with the result.

     The model is built for inference, written to ``test_tacotron2.tflite``,
     loaded back, fed a random integer input of shape ``[1, 4]`` and invoked.
     The generated file is removed when the test ends, whether the steps
     above succeeded or raised.
     """
     tflite_path = "test_tacotron2.tflite"
     model = Tacotron2(
         num_chars=24,
         num_speakers=0,
         r=3,
         out_channels=80,
         decoder_output_dim=80,
         attn_type="original",
         attn_win=False,
         attn_norm="sigmoid",
         prenet_type="original",
         prenet_dropout=True,
         forward_attn=False,
         trans_agent=False,
         forward_attn_mask=False,
         location_attn=True,
         attn_K=0,
         separate_stopnet=True,
         bidirectional_decoder=False,
         enable_tflite=True,
     )
     model.build_inference()
     try:
         convert_tacotron2_to_tflite(model, output_path=tflite_path, experimental_converter=True)
         # init tflite model
         tflite_model = load_tflite_model(tflite_path)
         # fake input
         inputs = tf.random.uniform([1, 4], maxval=10, dtype=tf.int32)  # pylint:disable=unexpected-keyword-arg
         # get input and output details
         input_details = tflite_model.get_input_details()
         output_details = tflite_model.get_output_details()
         # reshape input tensor for the new input shape
         tflite_model.resize_tensor_input(
             input_details[0]["index"], inputs.shape
         )  # pylint:disable=unexpected-keyword-arg
         tflite_model.allocate_tensors()
         detail = input_details[0]
         tflite_model.set_tensor(detail["index"], inputs)
         # run the tflite_model
         tflite_model.invoke()
         # collect outputs; fetching them checks that both output tensors are readable
         decoder_output = tflite_model.get_tensor(output_details[0]["index"])
         postnet_output = tflite_model.get_tensor(output_details[1]["index"])
     finally:
         # remove tflite binary even when a step above raised; the existence
         # check keeps a failed conversion from being masked by FileNotFoundError
         if os.path.exists(tflite_path):
             os.remove(tflite_path)
# Build the torch model and restore its weights from the checkpoint file given
# by args.torch_model_path. map_location forces all tensors onto the CPU.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
model = setup_model(num_chars, num_speakers, c)
checkpoint = torch.load(args.torch_model_path,
                        map_location=torch.device('cpu'))
# the weights are stored under the 'model' key of the checkpoint
state_dict = checkpoint['model']
model.load_state_dict(state_dict)

# init tf model
# The reduction factor r is read from the loaded torch decoder; both output
# dimensions are set to c.audio['num_mels']; every other option is taken from
# the config object c.
model_tf = Tacotron2(num_chars=num_chars,
                     num_speakers=num_speakers,
                     r=model.decoder.r,
                     postnet_output_dim=c.audio['num_mels'],
                     decoder_output_dim=c.audio['num_mels'],
                     attn_type=c.attention_type,
                     attn_win=c.windowing,
                     attn_norm=c.attention_norm,
                     prenet_type=c.prenet_type,
                     prenet_dropout=c.prenet_dropout,
                     forward_attn=c.use_forward_attn,
                     trans_agent=c.transition_agent,
                     forward_attn_mask=c.forward_attn_mask,
                     location_attn=c.location_attn,
                     attn_K=c.attention_heads,
                     separate_stopnet=c.separate_stopnet,
                     bidirectional_decoder=c.bidirectional_decoder)

# set initial layer mapping - these are not captured by the heuristic approach below
# TODO: set layer names so that we can remove this manual matching
# NOTE(review): presumably the suffix carried by TF checkpoint variable names;
# its use is not visible in this part of the file - confirm.
common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE'
var_map = [
    ('embedding/embeddings:0', 'embedding.weight'),
    ('encoder/lstm/forward_lstm/lstm_cell_1/kernel:0',