Example #1
    def test_fixed_attention_parallel_decoder_inference(self):
        encoder_dim = 12
        vae_dim = 6
        # The next three values are not used by the module configs in this test.
        audio_encoder_dim = 6
        n_frames_per_step = 5
        p_teacher_forcing = 0.5
        decoder_dim = 15

        model_config = enc_dec_dyn.Config(modules=[
            self._get_encoder_config(encoder_dim),
            self._get_encoder_embedding_config(vae_dim),
            self._get_fixed_attention_config(),
            self._get_parallel_decoder_config(encoder_dim + vae_dim,
                                              decoder_dim),
        ])
        model = model_config.create_model()

        phoneme_seq_length = torch.tensor((10, 12), dtype=torch.long)
        phoneme_max_length = torch.tensor(12, dtype=torch.long)
        seq_length = torch.tensor((100, 75), dtype=torch.long)
        max_length = torch.tensor(100, dtype=torch.long)
        batch_size = 2

        test_input = {}
        test_input["emb_idx"] = torch.zeros([batch_size, 1]).long()
        test_input["phonemes"] = torch.ones(
            [batch_size, phoneme_seq_length.max(), 1]).long()
        test_input["attention_matrix"] = torch.zeros(
            (batch_size, max_length, phoneme_max_length))
        seq_length_dict = {
            "phonemes": phoneme_seq_length,
            "emb_idx": 1,
            "attention_matrix": seq_length
        }
        max_length_dict = {
            "phonemes": phoneme_max_length,
            "emb_idx": 1,
            "attention_matrix": max_length
        }
        # Deep copy the inputs so model-generated outputs can be told apart later.
        org_test_input = copy.deepcopy(test_input)

        model.init_hidden(batch_size)
        # Free-running inference: no ground-truth acoustic features are provided.
        output = model.inference(test_input, seq_length_dict, max_length_dict)

        # Every output the model generated (i.e. every key not already present
        # in the inputs) must have registered its lengths in both dictionaries.
        output_filtered = {
            k: v
            for k, v in output.items() if k not in org_test_input
        }

        for key in output_filtered:
            self.assertIn(key,
                          seq_length_dict,
                          msg="{} not found in seq_length_dict".format(key))
            self.assertIn(key,
                          max_length_dict,
                          msg="{} not found in max_length_dict".format(key))
        self.assertEqual(
            torch.Size([batch_size, seq_length.max(), decoder_dim]),
            output["pred_acoustic_features"].shape)
Example #2
    def test_fixed_attention_batched_b1(self):
        encoder_dim = 12
        vae_dim = 6
        audio_encoder_dim = 6
        n_frames_per_step = 5
        p_teacher_forcing = 1
        decoder_dim = 15

        model_config = enc_dec_dyn.Config(modules=[
            self._get_encoder_config(encoder_dim),
            self._get_encoder_vae_pool_last_config(decoder_dim, vae_dim),
            self._get_fixed_attention_decoder_config(
                audio_encoder_dim,
                encoder_dim + audio_encoder_dim + vae_dim,
                decoder_dim, n_frames_per_step, p_teacher_forcing),
            self._get_postnet_config(decoder_dim)
        ])
        model = model_config.create_model()

        phoneme_seq_length = torch.tensor((10, ), dtype=torch.long)
        phoneme_max_length = torch.tensor(10, dtype=torch.long)
        seq_length = torch.tensor((100, ), dtype=torch.long)
        max_length = torch.tensor(100, dtype=torch.long)
        batch_size = 1

        test_input = {}
        test_input["phonemes"] = torch.ones(
            [batch_size, phoneme_seq_length.max(), 1]).long()
        test_input["acoustic_features"] = torch.ones(
            [batch_size, seq_length.max(), decoder_dim])
        test_input["attention_matrix"] = torch.zeros(
            (batch_size, max_length, phoneme_max_length))
        seq_length_dict = {
            "phonemes": phoneme_seq_length,
            "acoustic_features": seq_length,
            "attention_matrix": seq_length
        }
        max_length_dict = {
            "phonemes": phoneme_max_length,
            "acoustic_features": max_length,
            "attention_matrix": max_length
        }

        model.init_hidden(batch_size)
        # Forward pass with full teacher forcing (p_teacher_forcing = 1).
        output = model(test_input, seq_length_dict, max_length_dict)

        self.assertEqual(
            torch.Size([batch_size, seq_length.max(), decoder_dim]),
            output["pred_acoustic_features"].shape)
Example #3
    def test_save_load(self):
        encoder_dim = 12
        vae_dim = 6
        decoder_dim = 15

        def ordered(obj):
            # Recursively sort dicts and lists so the JSON comparison below is
            # insensitive to key order.
            if isinstance(obj, dict):
                return sorted((k, ordered(v)) for k, v in obj.items())
            elif isinstance(obj, list):
                return sorted(ordered(x) for x in obj)
            else:
                return obj

        model_config = enc_dec_dyn.Config(modules=[
            self._get_encoder_config(encoder_dim),
            self._get_encoder_vae_config(decoder_dim, vae_dim),
            self._get_fixed_attention_config(),
            self._get_parallel_decoder_config(encoder_dim + vae_dim,
                                              decoder_dim)
        ])
        model = model_config.create_model()

        # A second model built from the same config must get its own random
        # initialization; index 15 is just an arbitrary parameter tensor.
        other_model = model_config.create_model()
        self.assertTrue((list(model.parameters())[15] != list(
            other_model.parameters())[15]).any())

        config_json = model.get_config_as_json()
        # out_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), type(self).__name__, "test_save_load")
        # makedirs_safe(out_dir)
        # with open(os.path.join(out_dir, "model_config.json"), "w") as f:
        #     f.write(jsonpickle.encode(model_config, indent=4))
        # with open(os.path.join(out_dir, "config.json"), "w") as f:
        #     f.write(config_json)
        self.assertEqual(ordered(jsonpickle.encode(model_config, indent=4)),
                         ordered(config_json))
        # Round trip: serialize the config to JSON, rebuild the model from it,
        # and reload the original weights.
        params = model.state_dict()
        recreated_config = jsonpickle.decode(config_json)
        recreated_model = recreated_config.create_model()
        recreated_model.load_state_dict(params)

        self.assertTrue((list(model.parameters())[0] == list(
            recreated_model.parameters())[0]).all())
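
Example #3 depends on jsonpickle round-tripping the config object:
jsonpickle.encode writes the object to a JSON string (with py/object type
tags), and jsonpickle.decode reconstructs an equivalent instance whose
create_model builds an identically-shaped model. Below is a self-contained
sketch with a toy config class; ToyConfig is illustrative, not part of
enc_dec_dyn:

import jsonpickle
import torch

class ToyConfig:
    # Illustrative stand-in for enc_dec_dyn.Config.
    def __init__(self, dim):
        self.dim = dim

    def create_model(self):
        return torch.nn.Linear(self.dim, self.dim)

config = ToyConfig(8)
encoded = jsonpickle.encode(config, indent=4)  # JSON string with type tags
recreated = jsonpickle.decode(encoded)         # rebuilds a ToyConfig instance

# Weights transfer between models built from the original and recreated configs.
model = config.create_model()
recreated_model = recreated.create_model()
recreated_model.load_state_dict(model.state_dict())
assert (list(model.parameters())[0] == list(recreated_model.parameters())[0]).all()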
Example #4
    def test_fixed_attention_parallel_decoder(self):
        encoder_dim = 12
        vae_dim = 6
        audio_encoder_dim = 6
        n_frames_per_step = 5
        p_teacher_forcing = 1
        decoder_dim = 15

        model_config = enc_dec_dyn.Config(modules=[
            self._get_encoder_config(encoder_dim),
            self._get_encoder_vae_config(decoder_dim, vae_dim),
            self._get_fixed_attention_config(),
            self._get_parallel_decoder_config(encoder_dim + vae_dim,
                                              decoder_dim)
        ])
        model = model_config.create_model()

        phoneme_seq_length = torch.tensor((10, 12), dtype=torch.long)
        phoneme_max_length = torch.tensor(12, dtype=torch.long)
        seq_length = torch.tensor((100, 75), dtype=torch.long)
        max_length = torch.tensor(100, dtype=torch.long)
        batch_size = 2

        test_input = {}
        test_input["phonemes"] = torch.ones(
            [batch_size, phoneme_seq_length.max(), 1]).long()
        test_input["acoustic_features"] = torch.ones(
            [batch_size, seq_length.max(), decoder_dim])
        test_input["attention_matrix"] = torch.zeros(
            (batch_size, max_length, phoneme_max_length))
        seq_length_dict = {
            "phonemes": phoneme_seq_length,
            "acoustic_features": seq_length,
            "attention_matrix": seq_length
        }
        max_length_dict = {
            "phonemes": phoneme_max_length,
            "acoustic_features": max_length,
            "attention_matrix": max_length
        }

        model.init_hidden(batch_size)
        output = model(test_input, seq_length_dict, max_length_dict)

        self.assertEqual(
            torch.Size([batch_size,
                        phoneme_seq_length.max(), encoder_dim]),
            output["phoneme_embeddings"].shape)
        self.assertEqual(torch.Size([batch_size,
                                     seq_length.max(), vae_dim]),
                         output["emb_mu"].shape)
        self.assertEqual(
            torch.Size([batch_size, seq_length.max(), decoder_dim]),
            output["pred_acoustic_features"].shape)
        # The model must have registered sequence and max lengths for each
        # output it produced.
        self.assertTrue((
            phoneme_seq_length == seq_length_dict["phoneme_embeddings"]).any())
        self.assertTrue((
            phoneme_max_length == max_length_dict["phoneme_embeddings"]).any())
        self.assertTrue((seq_length == max_length_dict["emb_mu"]).any())
        self.assertTrue(
            (seq_length == seq_length_dict["pred_acoustic_features"]).any())

        expected_params = 1 + 2 * 3  # Phoneme encoder: 1 Emb + 3 Conv (weight & bias)
        expected_params += 2 * 3 + 4 + 1  # Acoustic encoder: 3 Conv + GRU + VAE projection
        expected_params += 2 + 4  # Parallel decoder: Linear + LSTM
        self.assertEqual(expected_params, len(list(model.named_parameters())))
        # Ensure gradients flow through the full model.
        output["pred_acoustic_features"].sum().backward()