    def test_get_model_symbols_count_uses_const(self):
        ref = get_model_symbols_count(n_symbols=10,
                                      n_accents=5,
                                      accents_use_own_symbols=True,
                                      shared_symbol_count=SHARED_SYMBOLS_COUNT)

        res = get_model_symbols_count(n_symbols=10,
                                      n_accents=5,
                                      accents_use_own_symbols=True)

        self.assertEqual(ref, res)
Example #2
    def __init__(self, hparams: HParams, logger: Logger):
        super(Tacotron2, self).__init__()
        model_symbols_count = get_model_symbols_count(
            hparams.n_symbols, hparams.n_accents,
            hparams.accents_use_own_symbols)

        if hparams.accents_use_own_symbols:
            logger.info(
                f"All {hparams.n_accents} accent(s) use an own symbolset. The number of total symbols increased from {hparams.n_symbols} to {model_symbols_count}."
            )
        else:
            logger.info(
                f"All {hparams.n_accents} accent(s) share the same symbolset. The number of total symbols is {model_symbols_count}."
            )

        self.logger = logger
        self.mask_padding = hparams.mask_padding
        self.n_mel_channels = hparams.n_mel_channels

        # TODO: rename to symbol_embeddings, but renaming would break all previously trained models
        symbol_emb_weights = get_symbol_weights(hparams)
        self.embedding = weights_to_embedding(symbol_emb_weights)
        logger.debug(f"is cuda: {self.embedding.weight.is_cuda}")

        speaker_emb_weights = get_speaker_weights(hparams)
        self.speakers_embedding = weights_to_embedding(speaker_emb_weights)

        #self.accent_embedding = nn.Embedding(hparams.n_accents, hparams.accents_embedding_dim)
        # torch.nn.init.xavier_uniform_(self.accent_embedding.weight)

        self.encoder = Encoder(hparams)
        self.decoder = Decoder(hparams, logger)
        self.postnet = Postnet(hparams)
Example #3
def get_symbol_weights(hparams: HParams) -> torch.Tensor:
    model_symbols_count = get_model_symbols_count(
        hparams.n_symbols, hparams.n_accents, hparams.accents_use_own_symbols)

    model_weights = get_uniform_weights(model_symbols_count,
                                        hparams.symbols_embedding_dim)
    return model_weights
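
get_uniform_weights and weights_to_embedding are not shown in this listing. The following is a minimal sketch of what they might look like, assuming the Xavier-style uniform initialisation used for the symbol embedding in the reference Tacotron 2 implementation; only the names and call signatures come from the code above, the bodies are assumptions.

import math

import torch
from torch import nn


def get_uniform_weights(symbols_count: int, embedding_dim: int) -> torch.Tensor:
    # Xavier-style uniform range over a (symbols_count, embedding_dim) matrix.
    std = math.sqrt(2.0 / (symbols_count + embedding_dim))
    val = math.sqrt(3.0) * std
    weights = torch.empty(symbols_count, embedding_dim)
    weights.uniform_(-val, val)
    return weights


def weights_to_embedding(weights: torch.Tensor) -> nn.Embedding:
    # Wrap a precomputed weight matrix in a trainable nn.Embedding layer.
    return nn.Embedding.from_pretrained(weights, freeze=False)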
    def test_get_model_symbols_count_ns10_na5_returns_10(self):
        res = get_model_symbols_count(n_symbols=10,
                                      n_accents=5,
                                      accents_use_own_symbols=False,
                                      shared_symbol_count=1)

        self.assertEqual(10, res)
    def test_get_model_symbols_count_acc_shared2_ns2_na50_returns_2(self):
        res = get_model_symbols_count(n_symbols=2,
                                      n_accents=50,
                                      accents_use_own_symbols=True,
                                      shared_symbol_count=2)

        self.assertEqual(2, res)
    def test_get_model_symbols_count_acc_shared2_ns10_na5_returns_42(self):
        res = get_model_symbols_count(n_symbols=10,
                                      n_accents=5,
                                      accents_use_own_symbols=True,
                                      shared_symbol_count=2)

        # 2 + 5 accents * 8 symbols
        self.assertEqual(42, res)
    def test_get_model_symbols_count_acc_shared1_ns10_na5_returns_46(self):
        res = get_model_symbols_count(n_symbols=10,
                                      n_accents=5,
                                      accents_use_own_symbols=True,
                                      shared_symbol_count=1)

        # 1 + 5 accents * 9 symbols
        self.assertEqual(46, res)
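
Taken together, the tests above pin down the behaviour of get_model_symbols_count: with accents_use_own_symbols=False the count is simply n_symbols; otherwise the shared symbols are counted once and every accent gets its own copy of the remaining symbols. Below is a minimal sketch consistent with those tests; the concrete value of SHARED_SYMBOLS_COUNT is an assumption, since the first test only requires that the default parameter equals the constant.

SHARED_SYMBOLS_COUNT = 1  # assumed value; the tests only use it as the default


def get_model_symbols_count(n_symbols: int, n_accents: int,
                            accents_use_own_symbols: bool,
                            shared_symbol_count: int = SHARED_SYMBOLS_COUNT) -> int:
    if accents_use_own_symbols:
        # Shared symbols are stored once; each accent owns a copy of the rest.
        # E.g. 10 symbols, 5 accents, 2 shared -> 2 + 5 * 8 = 42.
        return shared_symbol_count + n_accents * (n_symbols - shared_symbol_count)
    return n_symbols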