Example #1
def get_tacotron_model(model_path):
    # Fall back to CPU when CUDA is unavailable
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Tacotron(embed_dims=hp.tts_embed_dims,
                     num_chars=len(phonemes),
                     encoder_dims=hp.tts_encoder_dims,
                     decoder_dims=hp.tts_decoder_dims,
                     n_mels=hp.num_mels,
                     fft_bins=hp.num_mels,
                     postnet_dims=hp.tts_postnet_dims,
                     encoder_K=hp.tts_encoder_K,
                     lstm_dims=hp.tts_lstm_dims,
                     postnet_K=hp.tts_postnet_K,
                     num_highways=hp.tts_num_highways,
                     dropout=hp.tts_dropout,
                     stop_threshold=hp.tts_stop_threshold).to(device)
    model.load(model_path)
    return model
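
A minimal usage sketch for the helper above; `hp` is assumed to be configured already, and the
checkpoint path is a placeholder rather than a path from the original project:

# hypothetical checkpoint path
model = get_tacotron_model('checkpoints/tacotron/latest_weights.pyt')
print('Loaded Tacotron at step', model.get_step())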
Example #2
    # Instantiate Tacotron Model
    tts_model = Tacotron(embed_dims=hp.tts_embed_dims,
                         num_chars=len(symbols),
                         encoder_dims=hp.tts_encoder_dims,
                         decoder_dims=hp.tts_decoder_dims,
                         n_mels=hp.num_mels,
                         fft_bins=hp.num_mels,
                         postnet_dims=hp.tts_postnet_dims,
                         encoder_K=hp.tts_encoder_K,
                         lstm_dims=hp.tts_lstm_dims,
                         postnet_K=hp.tts_postnet_K,
                         num_highways=hp.tts_num_highways,
                         dropout=hp.tts_dropout,
                         stop_threshold=hp.tts_stop_threshold).to(device)

    tts_model.load('quick_start/tts_weights/latest_weights.pyt')

    if input_text:
        inputs = [text_to_sequence(input_text.strip(), hp.tts_cleaner_names)]
    else:
        with open('sentences.txt') as f:
            inputs = [
                text_to_sequence(l.strip(), hp.tts_cleaner_names) for l in f
            ]

    voc_k = voc_model.get_step() // 1000
    tts_k = tts_model.get_step() // 1000

    r = tts_model.r

    simple_table([('WaveRNN', str(voc_k) + 'k'),
                  ('Tacotron', str(tts_k) + 'k'),
                  ('r', r)])  # the excerpt cut off mid-call; closed with the values in scope
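
For context, `text_to_sequence` converts a cleaned string into the list of integer symbol IDs the
model consumes. A small illustrative call (the cleaner name here is an assumption; the snippets
above take it from hp.tts_cleaner_names):

seq = text_to_sequence('Hello world.', ['english_cleaners'])
print(seq)  # a list of ints, one ID per symbol after cleaning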
Example #3
    tts_model = Tacotron(embed_dims=hp.tts_embed_dims,
                         num_chars=len(symbols),
                         encoder_dims=hp.tts_encoder_dims,
                         decoder_dims=hp.tts_decoder_dims,
                         n_mels=hp.num_mels,
                         fft_bins=hp.num_mels,
                         postnet_dims=hp.tts_postnet_dims,
                         encoder_K=hp.tts_encoder_K,
                         lstm_dims=hp.tts_lstm_dims,
                         postnet_K=hp.tts_postnet_K,
                         num_highways=hp.tts_num_highways,
                         dropout=hp.tts_dropout,
                         stop_threshold=hp.tts_stop_threshold).to(device)

    tts_load_path = tts_weights if tts_weights else paths.tts_latest_weights
    tts_model.load(tts_load_path)

    if input_text:
        inputs = [text_to_sequence(input_text.strip(), hp.tts_cleaner_names)]
    else:
        test_sentences_file = getattr(hp, 'test_sentences_file',
                                      'test_sentences/sentences.txt')
        with open(test_sentences_file) as f:
            inputs = [
                text_to_sequence(l.strip(), hp.tts_cleaner_names) for l in f
            ]

    if args.vocoder == 'wavernn':
        voc_k = voc_model.get_step() // 1000
        tts_k = tts_model.get_step() // 1000
Example #4
    # Instantiate Tacotron Model
    tts_model = Tacotron(embed_dims=hp.tts_embed_dims,
                         num_chars=len(symbols),
                         encoder_dims=hp.tts_encoder_dims,
                         decoder_dims=hp.tts_decoder_dims,
                         n_mels=hp.num_mels,
                         fft_bins=hp.num_mels,
                         postnet_dims=hp.tts_postnet_dims,
                         encoder_K=hp.tts_encoder_K,
                         lstm_dims=hp.tts_lstm_dims,
                         postnet_K=hp.tts_postnet_K,
                         num_highways=hp.tts_num_highways,
                         dropout=hp.tts_dropout,
                         stop_threshold=hp.tts_stop_threshold).to(device)

    tts_model.load(f'checkpoints/{checkpoint}')

    if input_text:
        inputs = [text_to_sequence(input_text.strip(), hp.tts_cleaner_names)]
    else:
        with open('phoneme_test.txt') as f:
            inputs = [
                text_to_sequence(l.strip(), hp.tts_cleaner_names) for l in f
            ]

    voc_k = voc_model.get_step() // 1000
    tts_k = tts_model.get_step() // 1000

    r = tts_model.r

    simple_table([('WaveRNN', str(voc_k) + 'k'),
                  ('Tacotron', str(tts_k) + 'k'),
                  ('r', r)])  # the excerpt cut off mid-call; closed with the values in scope
Example #5
class TaiwaneseTacotron():
    def __init__(self):
        # Parse Arguments
        parser = argparse.ArgumentParser(description='TTS')
        self.args = parser.parse_args()
        self.args.vocoder = 'wavernn'
        self.args.hp_file = 'hparams.py'
        self.args.voc_weights = False
        self.args.tts_weights = False
        self.args.save_attn = False
        self.args.batched = True
        self.args.target = None
        self.args.overlap = None
        self.args.force_cpu = False
        #================ vocoder ================#
        if self.args.vocoder in ['griffinlim', 'gl']:
            self.args.vocoder = 'griffinlim'
        elif self.args.vocoder in ['wavernn', 'wr']:
            self.args.vocoder = 'wavernn'
        else:
            # argparse.ArgumentError needs an argument object; ValueError fits here
            raise ValueError('Must provide a valid vocoder type!')

        hp.configure(self.args.hp_file)  # Load hparams from file

        # set defaults for any arguments that depend on hparams
        if self.args.vocoder == 'wavernn':
            if self.args.target is None:
                self.args.target = hp.voc_target
            if self.args.overlap is None:
                self.args.overlap = hp.voc_overlap
            if self.args.batched is None:
                self.args.batched = hp.voc_gen_batched

        #================ others ================#
        paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)
        print("hello")
        print(paths.base)
        if not self.args.force_cpu and torch.cuda.is_available():
            device = torch.device('cuda')
        else:
            device = torch.device('cpu')
        print('Using device:', device)

        # === Wavernn === #
        if self.args.vocoder == 'wavernn':
            print('\nInitialising WaveRNN Model...\n')
            self.voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                                     fc_dims=hp.voc_fc_dims,
                                     bits=hp.bits,
                                     pad=hp.voc_pad,
                                     upsample_factors=hp.voc_upsample_factors,
                                     feat_dims=hp.num_mels,
                                     compute_dims=hp.voc_compute_dims,
                                     res_out_dims=hp.voc_res_out_dims,
                                     res_blocks=hp.voc_res_blocks,
                                     hop_length=hp.hop_length,
                                     sample_rate=hp.sample_rate,
                                     mode=hp.voc_mode).to(device)

            voc_load_path = self.args.voc_weights if self.args.voc_weights else paths.voc_latest_weights
            self.voc_model.load(voc_load_path)

        # === Tacotron === #
        if hp.tts_model == 'tacotron':
            print('\nInitialising Tacotron Model...\n')
            self.tts_model = Tacotron(
                embed_dims=hp.tts_embed_dims,
                num_chars=len(symbols),
                encoder_dims=hp.tts_encoder_dims,
                decoder_dims=hp.tts_decoder_dims,
                n_mels=hp.num_mels,
                fft_bins=hp.num_mels,
                postnet_dims=hp.tts_postnet_dims,
                encoder_K=hp.tts_encoder_K,
                lstm_dims=hp.tts_lstm_dims,
                postnet_K=hp.tts_postnet_K,
                num_highways=hp.tts_num_highways,
                dropout=hp.tts_dropout,
                stop_threshold=hp.tts_stop_threshold).to(device)

            tts_load_path = self.args.tts_weights if self.args.tts_weights else paths.tts_latest_weights
            self.tts_model.load(tts_load_path)

        # === Tacotron2 === #
        elif hp.tts_model == 'tacotron2':
            print('\nInitialising Tacotron2 Model...\n')
            self.tts_model = Tacotron2().to(device)
            tts_load_path = self.args.tts_weights if self.args.tts_weights else paths.tts_latest_weights
            self.tts_model.load(tts_load_path)

        # === Information === #
        if hp.tts_model == 'tacotron':
            if self.args.vocoder == 'wavernn':
                voc_k = self.voc_model.get_step() // 1000
                tts_k = self.tts_model.get_step() // 1000

                simple_table([
                    ('Tacotron', str(tts_k) + 'k'), ('r', self.tts_model.r),
                    ('Vocoder Type', 'WaveRNN'), ('WaveRNN', str(voc_k) + 'k'),
                    ('Generation Mode',
                     'Batched' if self.args.batched else 'Unbatched'),
                    ('Target Samples',
                     self.args.target if self.args.batched else 'N/A'),
                    ('Overlap Samples',
                     self.args.overlap if self.args.batched else 'N/A')
                ])

            elif self.args.vocoder == 'griffinlim':
                tts_k = self.tts_model.get_step() // 1000
                simple_table([('Tacotron', str(tts_k) + 'k'),
                              ('r', self.tts_model.r),
                              ('Vocoder Type', 'Griffin-Lim'),
                              ('GL Iters', self.args.iters)])

        elif hp.tts_model == 'tacotron2':
            if self.args.vocoder == 'wavernn':
                voc_k = self.voc_model.get_step() // 1000
                tts_k = self.tts_model.get_step() // 1000

                simple_table([
                    ('Tacotron2', str(tts_k) + 'k'),
                    ('Vocoder Type', 'WaveRNN'), ('WaveRNN', str(voc_k) + 'k'),
                    ('Generation Mode',
                     'Batched' if self.args.batched else 'Unbatched'),
                    ('Target Samples',
                     self.args.target if self.args.batched else 'N/A'),
                    ('Overlap Samples',
                     self.args.overlap if self.args.batched else 'N/A')
                ])

            elif self.args.vocoder == 'griffinlim':
                tts_k = self.tts_model.get_step() // 1000
                simple_table([('Tacotron2', str(tts_k) + 'k'),
                              ('Vocoder Type', 'Griffin-Lim'),
                              ('GL Iters', self.args.iters)])

    def generate(self, 華, input_text):
        inputs = [text_to_sequence(input_text.strip(), ['basic_cleaners'])]
        # input_text is passed through so the output-naming fallback below can use it
        if hp.tts_model == 'tacotron2':
            self.gen_tacotron2(華, inputs, input_text)

        elif hp.tts_model == 'tacotron':
            self.gen_tacotron(華, inputs, input_text)

        else:
            print(f'Wrong tts model type {hp.tts_model}')

        print('\n\nDone.\n')

    # custom function
    def gen_tacotron2(self, 華, inputs, input_text):
        for i, x in enumerate(inputs, 1):
            print(f'\n| Generating {i}/{len(inputs)}')
            print(x)

            # torch.autograd.Variable is deprecated; a plain tensor behaves the same here
            x = torch.from_numpy(np.array(x)[None, :]).cuda().long()

            self.tts_model.eval()
            mel_outputs, mel_outputs_postnet, _, alignments = self.tts_model.inference(
                x)
            if self.args.vocoder == 'griffinlim':
                v_type = self.args.vocoder
            elif self.args.vocoder == 'wavernn' and self.args.batched:
                v_type = 'wavernn_batched'
            else:
                v_type = 'wavernn_unbatched'

            # == define output name == #
            if len(華) == 0:
                output_name = re.split(r'\,|\.|\!|\?| ', input_text)[0]
            elif 1 <= len(華) <= 9:
                output_name = 華[:-1]
            elif 9 < len(華):
                output_name = 華[:8]
            print(output_name)
            save_path = "output/{}.wav".format(output_name)
            ##

            if self.args.vocoder == 'wavernn':
                m = mel_outputs_postnet
                self.voc_model.generate(m, save_path, self.args.batched,
                                        hp.voc_target, hp.voc_overlap,
                                        hp.mu_law)

            elif self.args.vocoder == 'griffinlim':
                m = torch.squeeze(mel_outputs_postnet).detach().cpu().numpy()
                wav = reconstruct_waveform(m, n_iter=self.args.iters)
                save_wav(wav, save_path)

    # custom function
    def gen_tacotron(self, 華, inputs, input_text):
        for i, x in enumerate(inputs, 1):
            print(f'\n| Generating {i}/{len(inputs)}')
            _, m, attention = self.tts_model.generate(x)
            # Fix mel spectrogram scaling to be from 0 to 1
            m = (m + 4) / 8
            np.clip(m, 0, 1, out=m)

            if self.args.vocoder == 'griffinlim':
                v_type = self.args.vocoder
            elif self.args.vocoder == 'wavernn' and self.args.batched:
                v_type = 'wavernn_batched'
            else:
                v_type = 'wavernn_unbatched'
            # == define output name == #
            if len(華) == 0:
                output_name = re.split(r'\,|\.|\!|\?| ', input_text)[0]
            elif 1 <= len(華) <= 9:
                output_name = 華[:-1]
            elif 9 < len(華):
                output_name = 華[:8]
            print(output_name)
            save_path = "output/{}.wav".format(output_name)
            ##
            if self.args.vocoder == 'wavernn':
                m = torch.tensor(m).unsqueeze(0)
                self.voc_model.generate(m, save_path, self.args.batched,
                                        hp.voc_target, hp.voc_overlap,
                                        hp.mu_law)

            elif self.args.vocoder == 'griffinlim':
                wav = reconstruct_waveform(m, n_iter=self.args.iters)
                save_wav(wav, save_path)
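
A hedged driver sketch for the class above; both text arguments are placeholders (generate()
derives the output file name from its first argument, as the naming logic in gen_tacotron shows):

if __name__ == '__main__':
    tts = TaiwaneseTacotron()
    # illustrative values only
    tts.generate('測試,', input_text='li2-ho2')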
Example #6
def thak():
    class Tshamsoo():
        # Note: os.getenv returns raw strings when a variable is set, so these
        # boolean-style flags are truthy for any non-empty value.
        force_cpu = os.getenv('FORCE_CPU', False)
        hp_file = 'hparams.py'
        vocoder = os.getenv('VOCODER', 'wavernn')
        batched = os.getenv('BATCHED', True)
        target = os.getenv('TARGET', None)
        overlap = os.getenv('OVERLAP', None)
        tts_weights = None
        save_attn = os.getenv('SAVE_ATTN', False)
        voc_weights = None
        iters = int(os.getenv('GL_ITERS', 32))  # cast: env values arrive as strings

    args = Tshamsoo()
    if args.vocoder in ['griffinlim', 'gl']:
        args.vocoder = 'griffinlim'
    elif args.vocoder in ['wavernn', 'wr']:
        args.vocoder = 'wavernn'
    else:
        # argparse.ArgumentError needs an argument object; ValueError fits here
        raise ValueError('Must provide a valid vocoder type!')

    hp.configure(args.hp_file)  # Load hparams from file

    tts_weights = args.tts_weights
    save_attn = args.save_attn

    paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

    if not args.force_cpu and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print('Using device:', device)

    if args.vocoder == 'wavernn':
        # set defaults for any arguments that depend on hparams
        if args.target is None:
            args.target = hp.voc_target
        if args.overlap is None:
            args.overlap = hp.voc_overlap
        if args.batched is None:
            args.batched = hp.voc_gen_batched

        batched = args.batched
        target = int(args.target)
        overlap = int(args.overlap)

        print('\nInitialising WaveRNN Model...\n')
        # Instantiate WaveRNN Model
        voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                            fc_dims=hp.voc_fc_dims,
                            bits=hp.bits,
                            pad=hp.voc_pad,
                            upsample_factors=hp.voc_upsample_factors,
                            feat_dims=hp.num_mels,
                            compute_dims=hp.voc_compute_dims,
                            res_out_dims=hp.voc_res_out_dims,
                            res_blocks=hp.voc_res_blocks,
                            hop_length=hp.hop_length,
                            sample_rate=hp.sample_rate,
                            mode=hp.voc_mode).to(device)

        voc_load_path = args.voc_weights if args.voc_weights else paths.voc_latest_weights
        voc_model.load(voc_load_path)
    else:
        voc_model = None
        batched = None
        target = None
        overlap = None

    print('\nInitialising Tacotron Model...\n')

    # Instantiate Tacotron Model
    tts_model = Tacotron(embed_dims=hp.tts_embed_dims,
                         num_chars=len(symbols),
                         encoder_dims=hp.tts_encoder_dims,
                         decoder_dims=hp.tts_decoder_dims,
                         n_mels=hp.num_mels,
                         fft_bins=hp.num_mels,
                         postnet_dims=hp.tts_postnet_dims,
                         encoder_K=hp.tts_encoder_K,
                         lstm_dims=hp.tts_lstm_dims,
                         postnet_K=hp.tts_postnet_K,
                         num_highways=hp.tts_num_highways,
                         dropout=hp.tts_dropout,
                         stop_threshold=hp.tts_stop_threshold).to(device)

    tts_load_path = tts_weights if tts_weights else paths.tts_latest_weights
    tts_model.load(tts_load_path)
    return args, voc_model, tts_model, batched, target, overlap, save_attn
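
A sketch of how the tuple returned by thak() might be consumed; the names simply mirror its
return statement:

args, voc_model, tts_model, batched, target, overlap, save_attn = thak()
print('Vocoder:', args.vocoder, '| batched generation:', batched)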
Example #7
    # Instantiate Tacotron Model
    tts_model = Tacotron(embed_dims=hp.tts_embed_dims,
                         num_chars=len(symbols),
                         encoder_dims=hp.tts_encoder_dims,
                         decoder_dims=hp.tts_decoder_dims,
                         n_mels=hp.num_mels,
                         fft_bins=hp.num_mels,
                         postnet_dims=hp.tts_postnet_dims,
                         encoder_K=hp.tts_encoder_K,
                         lstm_dims=hp.tts_lstm_dims,
                         postnet_K=hp.tts_postnet_K,
                         num_highways=hp.tts_num_highways,
                         dropout=hp.tts_dropout,
                         stop_threshold=hp.tts_stop_threshold).to(device)

    tts_model.load('quick_start/tts_weights/blizz.tacotron.JE.200k.pyt')

    if input_text:
        inputs = [text_to_sequence(input_text.strip(), hp.tts_cleaner_names)]
    else:
        with open('phonemes.txt') as f:
            inputs = [
                text_to_sequence(l.strip(), hp.tts_cleaner_names) for l in f
            ]

    voc_k = voc_model.get_step() // 1000
    tts_k = tts_model.get_step() // 1000

    r = tts_model.r

    simple_table([('WaveRNN', str(voc_k) + 'k'),
                  ('Tacotron', str(tts_k) + 'k'),
                  ('r', r)])  # the excerpt cut off mid-call; closed with the values in scope
Example #8
def main():
    # Parse Arguments
    parser = argparse.ArgumentParser(description='Train Tacotron TTS')
    parser.add_argument('--force_train',
                        '-f',
                        action='store_true',
                        help='Forces the model to train past total steps')
    parser.add_argument('--force_gta',
                        '-g',
                        action='store_true',
                        help='Force the model to create GTA features')
    parser.add_argument('--force_attn',
                        '-a',
                        action='store_true',
                        help='Force the model to create attn_ref')
    parser.add_argument(
        '--force_cpu',
        '-c',
        action='store_true',
        help='Forces CPU-only training, even when in CUDA capable environment')
    parser.add_argument('--hp_file',
                        metavar='FILE',
                        default='hparams.py',
                        help='The file to use for the hyperparameters')
    args = parser.parse_args()

    hp.configure(args.hp_file)  # Load hparams from file
    paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

    if hasattr(hp, 'random_seed'):
        set_global_seeds(hp.random_seed)

    force_train = args.force_train
    force_gta = args.force_gta
    force_attn = args.force_attn

    if not args.force_cpu and torch.cuda.is_available():
        device = torch.device('cuda')
        for session in hp.tts_schedule:
            _, _, _, batch_size = session
            if batch_size % torch.cuda.device_count() != 0:
                raise ValueError(
                    '`batch_size` must be evenly divisible by n_gpus!')
    else:
        device = torch.device('cpu')
    print('Using device:', device)

    # Instantiate Tacotron Model
    print('\nInitialising Tacotron Model...\n')
    model = Tacotron(embed_dims=hp.tts_embed_dims,
                     num_chars=len(symbols),
                     encoder_dims=hp.tts_encoder_dims,
                     decoder_dims=hp.tts_decoder_dims,
                     n_mels=hp.num_mels,
                     fft_bins=hp.num_mels,
                     postnet_dims=hp.tts_postnet_dims,
                     encoder_K=hp.tts_encoder_K,
                     lstm_dims=hp.tts_lstm_dims,
                     postnet_K=hp.tts_postnet_K,
                     num_highways=hp.tts_num_highways,
                     dropout=hp.tts_dropout,
                     stop_threshold=hp.tts_stop_threshold,
                     mode=hp.mode).to(device)

    optimizer = optim.Adam(model.parameters())
    restore_checkpoint('tts',
                       paths,
                       model,
                       optimizer,
                       create_if_missing=True,
                       init_weights_path=hp.tts_init_weights_path)

    if hp.mode != 'attention_forcing_online':
        model_tf = None
    else:
        model_tf = Tacotron(embed_dims=hp.tts_embed_dims,
                            num_chars=len(symbols),
                            encoder_dims=hp.tts_encoder_dims,
                            decoder_dims=hp.tts_decoder_dims,
                            n_mels=hp.num_mels,
                            fft_bins=hp.num_mels,
                            postnet_dims=hp.tts_postnet_dims,
                            encoder_K=hp.tts_encoder_K,
                            lstm_dims=hp.tts_lstm_dims,
                            postnet_K=hp.tts_postnet_K,
                            num_highways=hp.tts_num_highways,
                            dropout=hp.tts_dropout,
                            stop_threshold=hp.tts_stop_threshold,
                            mode='teacher_forcing').to(device)
        model_tf.load(hp.model_tf_path)

    if not (force_gta or force_attn):
        for i, session in enumerate(hp.tts_schedule):
            current_step = model.get_step()

            r, lr, max_step, batch_size = session

            training_steps = max_step - current_step

            # Do we need to change to the next session?
            if current_step >= max_step:
                # Are there no further sessions than the current one?
                if i == len(hp.tts_schedule) - 1:
                    # There are no more sessions. Check if we force training.
                    if force_train:
                        # Don't finish the loop - train forever
                        training_steps = 999_999_999
                    else:
                        # We have completed training. Breaking is same as continue
                        break
                else:
                    # There is a following session, go to it
                    continue

            model.r = r
            # if model_tf is not None: model_tf.r = r

            simple_table([(f'Steps with r={r}',
                           str(training_steps // 1000) + 'k Steps'),
                          ('Batch Size', batch_size), ('Learning Rate', lr),
                          ('Outputs/Step (r)', model.r)])

            train_set, attn_example = get_tts_datasets(paths.data, batch_size,
                                                       r)
            tts_train_loop(paths,
                           model,
                           optimizer,
                           train_set,
                           lr,
                           training_steps,
                           attn_example,
                           hp=hp,
                           model_tf=model_tf)

        print('Training Complete.')
        print(
            'To continue training increase tts_total_steps in hparams.py or use --force_train\n'
        )

    train_set, attn_example = get_tts_datasets(paths.data, 8, model.r)
    if force_gta:
        print(
            f'Creating Ground Truth Aligned Dataset at {paths.gta_model}...\n')
        create_gta_features(model, train_set, paths.gta_model)
    elif force_attn:
        print(f'Creating Reference Attention at {paths.attn_model}...\n')
        create_attn_ref(model, train_set, paths.attn_model)

    print(
        '\n\nYou can now train WaveRNN on GTA features - use python train_wavernn.py --gta\n'
    )
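
The script would then close with the usual entry-point guard (assumed; not part of the excerpt):

if __name__ == '__main__':
    main()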