def test_get_model_symbols_count_uses_const(self):
    # Leaving out shared_symbol_count must yield the same total as passing
    # SHARED_SYMBOLS_COUNT explicitly.
    common_kwargs = dict(
        n_symbols=10,
        n_accents=5,
        accents_use_own_symbols=True,
    )
    expected = get_model_symbols_count(
        shared_symbol_count=SHARED_SYMBOLS_COUNT,
        **common_kwargs,
    )
    actual = get_model_symbols_count(**common_kwargs)
    self.assertEqual(expected, actual)
def __init__(self, hparams: HParams, logger: Logger):
    """Assemble the Tacotron2 model from the given hyperparameters.

    Logs how the accents relate to the symbol set, stores the logger and a
    couple of hyperparameters on the instance, then creates the symbol and
    speaker embeddings followed by the encoder, decoder and postnet.

    Args:
        hparams: Hyperparameters; this constructor reads ``n_symbols``,
            ``n_accents``, ``accents_use_own_symbols``, ``mask_padding``
            and ``n_mel_channels`` and hands the whole object on to the
            weight helpers and sub-modules.
        logger: Logger used for the info/debug messages; it is also kept
            on the instance and passed to the decoder.
    """
    super(Tacotron2, self).__init__()

    model_symbols_count = get_model_symbols_count(
        hparams.n_symbols,
        hparams.n_accents,
        hparams.accents_use_own_symbols,
    )

    # Pick the message first, then emit it with a single info() call.
    if hparams.accents_use_own_symbols:
        symbolset_msg = f"All {hparams.n_accents} accent(s) use an own symbolset. The number of total symbols increased from {hparams.n_symbols} to {model_symbols_count}."
    else:
        symbolset_msg = f"All {hparams.n_accents} accent(s) share the same symbolset. The number of total symbols is {model_symbols_count}."
    logger.info(symbolset_msg)

    self.logger = logger
    self.mask_padding = hparams.mask_padding
    self.n_mel_channels = hparams.n_mel_channels

    # TODO: rename to symbol_embeddings -- note that doing so will break all
    # previously trained models.
    self.embedding = weights_to_embedding(get_symbol_weights(hparams))
    logger.debug(f"is cuda: {self.embedding.weight.is_cuda}")

    self.speakers_embedding = weights_to_embedding(get_speaker_weights(hparams))

    # Disabled: dedicated accent embedding.
    # self.accent_embedding = nn.Embedding(hparams.n_accents, hparams.accents_embedding_dim)
    # torch.nn.init.xavier_uniform_(self.accent_embedding.weight)

    # Sub-modules are created in the same order as before (encoder, decoder,
    # postnet); keep it that way.
    self.encoder = Encoder(hparams)
    self.decoder = Decoder(hparams, logger)
    self.postnet = Postnet(hparams)
def get_symbol_weights(hparams: HParams) -> torch.Tensor:
    """Create the initial weights for the symbol embedding.

    The number of symbols is derived via ``get_model_symbols_count`` from
    ``hparams.n_symbols``, ``hparams.n_accents`` and
    ``hparams.accents_use_own_symbols``; that count and
    ``hparams.symbols_embedding_dim`` are then passed to
    ``get_uniform_weights``.

    Args:
        hparams: Hyperparameters providing the fields named above.

    Returns:
        The tensor produced by ``get_uniform_weights``.
    """
    total_symbols = get_model_symbols_count(
        hparams.n_symbols,
        hparams.n_accents,
        hparams.accents_use_own_symbols,
    )
    return get_uniform_weights(total_symbols, hparams.symbols_embedding_dim)
def test_get_model_symbols_count_ns10_na5_returns_10(self):
    # Accents do not use their own symbols here, so 10 symbols and
    # 5 accents are expected to give a total of 10.
    expected_count = 10
    actual_count = get_model_symbols_count(
        n_symbols=10,
        n_accents=5,
        accents_use_own_symbols=False,
        shared_symbol_count=1,
    )
    self.assertEqual(expected_count, actual_count)
def test_get_model_symbols_count_acc_shared2_ns2_na50_returns_1(self):
    # Both symbols are shared, so no accent contributes symbols of its own:
    # 2 shared + 50 accents * 0 symbols = 2.
    # NOTE(review): the method name ends in "returns_1", yet the asserted
    # total is 2 -- the name looks stale.
    call_kwargs = {
        "n_symbols": 2,
        "n_accents": 50,
        "accents_use_own_symbols": True,
        "shared_symbol_count": 2,
    }
    total = get_model_symbols_count(**call_kwargs)
    self.assertEqual(2, total)
def test_get_model_symbols_count_acc_shared2_ns10_na5_returns_42(self):
    # Expected total: 2 shared + 5 accents * 8 symbols = 42.
    expected_total = 42
    observed_total = get_model_symbols_count(
        n_symbols=10,
        n_accents=5,
        accents_use_own_symbols=True,
        shared_symbol_count=2,
    )
    self.assertEqual(expected_total, observed_total)
def test_get_model_symbols_count_acc_shared1_ns10_na5_returns_46(self):
    # Expected total: 1 shared + 5 accents * 9 symbols = 46.
    expected_total = 46
    observed_total = get_model_symbols_count(
        n_symbols=10,
        n_accents=5,
        accents_use_own_symbols=True,
        shared_symbol_count=1,
    )
    self.assertEqual(expected_total, observed_total)