def train():
    torch.cuda.set_device(0)
    iteration = 0
    model = WaveRNN(HPARAMS)
    model = model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=HPARAMS.lr)

    if ARGS.checkpoint:
        if os.path.basename(ARGS.checkpoint).startswith('ema_model'):
            ema_checkpoint = ARGS.checkpoint
        else:
            ema_checkpoint = 'ema_model_' + os.path.basename(ARGS.checkpoint)
            ema_checkpoint = os.path.join(os.path.dirname(ARGS.checkpoint), ema_checkpoint)

        # Initialise EMA from the ema checkpoint.
        logging.info('Initialising ema model {}'.format(ema_checkpoint))
        ema_model = WaveRNN(HPARAMS).cuda()
        ema_base_model, _ = load_checkpoint(ema_checkpoint, ema_model)
        ema = init_ema(ema_base_model, HPARAMS.ema_rate)

        # Initialise vanilla model
        logging.info('Loading checkpoint {}'.format(ARGS.checkpoint))
        model, iteration, optimizer = load_checkpoint(ARGS.checkpoint, model, optimizer)

    else:
        # Initialise EMA from scratch.
        ema = init_ema(model, HPARAMS.ema_rate)

    criterion = nn.NLLLoss(reduction='sum').cuda()
    train_loader = get_loader(ARGS.data, 'train', HPARAMS)
    test_loader = get_loader(ARGS.data, 'valid', HPARAMS)
    whole_loader = get_loader(ARGS.data, 'valid', HPARAMS, whole=True)
    model = nn.DataParallel(model)

    epoch_offset = max(0, iteration // len(train_loader))
    for _ in range(epoch_offset, ARGS.epochs):
        iteration = train_step(
            train_loader, test_loader, whole_loader, model, optimizer,
            criterion, iteration, ema=ema
        )

        averaged_model = clone_as_averaged_model(model, ema)
        save_checkpoint(
            {
                'state_dict': model.module.state_dict(),
                'iteration': iteration,
                'dataset': ARGS.data,
                'optimizer': optimizer.state_dict(),
            }, iteration,
            'checkpoints/{}/lastmodel.pth'.format(ARGS.expName), ARGS.expName,
        )
        save_checkpoint(
            {
                'state_dict': averaged_model.state_dict(),
                'iteration': iteration,
                'dataset': ARGS.data,
                'optimizer': optimizer.state_dict(),
            }, iteration,
            'checkpoints/{}/ema_model_lastmodel.pth'.format(ARGS.expName), ARGS.expName,
        )
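
train() depends on a load_checkpoint helper that is not part of this excerpt. The call sites above imply it returns (model, iteration) when no optimizer is passed and (model, iteration, optimizer) when one is, reading the same keys that save_checkpoint writes. A minimal sketch under that assumption:

def load_checkpoint(path, model, optimizer=None):
    # Sketch only: reconstructs the interface implied by the calls above.
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
        return model, checkpoint['iteration'], optimizer
    return model, checkpoint['iteration']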
Example #2
def test_subscale_vs_standard_inference_parity():
    hparams = create_hparams()
    model = WaveRNN(hparams, debug=True).cuda()
    seq_len = 100
    m = torch.rand(1, hparams.feat_dims, seq_len).cuda()
    x = torch.rand(1, seq_len * hparams.hop_length).cuda()
    _, _, standard_x = model.inference(m, gt=x)
    _, _, subscale_x = model.subscale_inference(m, gt=x)

    assert abs(standard_x - subscale_x).mean() < 1e-6


def clone_as_averaged_model(model, ema):
    averaged_model = WaveRNN(HPARAMS)
    averaged_model.cuda()
    averaged_model.load_state_dict(model.module.state_dict())
    for name, param in averaged_model.named_parameters():
        if name in ema.shadow:
            param.data = ema.shadow[name].clone()
    return averaged_model
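
init_ema is also defined outside this excerpt. Since clone_as_averaged_model reads ema.shadow, a dict mapping parameter names to averaged tensors, the helper plausibly follows the standard exponential-moving-average pattern; this is a sketch under that assumption, not the repo's exact code:

class EMA:
    def __init__(self, rate):
        self.rate = rate
        self.shadow = {}  # parameter name -> averaged tensor

    def update(self, name, value):
        # shadow <- rate * shadow + (1 - rate) * value
        self.shadow[name] = self.rate * self.shadow[name] + (1.0 - self.rate) * value

def init_ema(model, rate):
    ema = EMA(rate)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.shadow[name] = param.data.clone()
    return ema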
Example #4
def test_inference_forward_parity():
    hparams = create_hparams()
    model = WaveRNN(hparams, debug=True).cuda()
    model.train()
    data_path = '../data/short_sens/'
    whole_segments = get_loader(data_path, 'valid', hparams, whole=True)
    for i, (x, m, _) in enumerate(whole_segments):
        x, m = x.cuda(), m.cuda()
        forward_output, f_context, f_x = model.train_mode_generate(x, m)
        inference_output, i_cont_dict, i_x = model.inference(m, gt=x)
        assert abs(i_x - f_x).mean() < 1e-6
coarse_classes, fine_classes = split_signal_PJ(sample)

# In[26]:

plot(coarse_classes[73000:73100])

# In[27]:

plot(fine_classes[73000:73100])

# ### Train Model

# In[28]:

model = WaveRNN().cuda()

# In[29]:

coarse_classes, fine_classes = split_signal(sample)

# In[30]:

batch_size = 128  # 8gb gpu
coarse_classes = coarse_classes[:len(coarse_classes) // batch_size *
                                batch_size]
fine_classes = fine_classes[:len(fine_classes) // batch_size * batch_size]
coarse_classes = np.reshape(coarse_classes, (batch_size, -1))
fine_classes = np.reshape(fine_classes, (batch_size, -1))

# In[31]:
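
split_signal is not shown in this excerpt. In the WaveRNN paper a 16-bit sample is split into a coarse (high) byte and a fine (low) byte, which matches the coarse/fine class plots above. A plausible NumPy implementation, assuming sample holds signed 16-bit integers:

def split_signal(x):
    # Shift signed 16-bit samples to unsigned, then split into two bytes.
    unsigned = x.astype(np.int64) + 2**15
    coarse = unsigned // 256  # top 8 bits
    fine = unsigned % 256     # bottom 8 bits
    return coarse, fine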
Example #7
class TaiwaneseTacotron():
    def __init__(self):
        # Parse Arguments
        parser = argparse.ArgumentParser(description='TTS')
        self.args = parser.parse_args()
        self.args.vocoder = 'wavernn'
        self.args.hp_file = 'hparams.py'
        self.args.voc_weights = False
        self.args.tts_weights = False
        self.args.save_attn = False
        self.args.batched = True
        self.args.target = None
        self.args.overlap = None
        self.args.iters = 32  # Griffin-Lim iterations; used below but never set (assumed default)
        self.args.force_cpu = False
        #================ vocoder ================#
        if self.args.vocoder in ['griffinlim', 'gl']:
            self.args.vocoder = 'griffinlim'
        elif self.args.vocoder in ['wavernn', 'wr']:
            self.args.vocoder = 'wavernn'
        else:
            # argparse.ArgumentError requires an argument object, so raise a plain ValueError.
            raise ValueError('Must provide a valid vocoder type!')

        hp.configure(self.args.hp_file)  # Load hparams from file

        # set defaults for any arguments that depend on hparams
        if self.args.vocoder == 'wavernn':
            if self.args.target is None:
                self.args.target = hp.voc_target
            if self.args.overlap is None:
                self.args.overlap = hp.voc_overlap
            if self.args.batched is None:
                self.args.batched = hp.voc_gen_batched

        #================ others ================#
        paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)
        print("hello")
        print(paths.base)
        if not self.args.force_cpu and torch.cuda.is_available():
            device = torch.device('cuda')
        else:
            device = torch.device('cpu')
        print('Using device:', device)

        # === Wavernn === #
        if self.args.vocoder == 'wavernn':
            print('\nInitialising WaveRNN Model...\n')
            self.voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                                     fc_dims=hp.voc_fc_dims,
                                     bits=hp.bits,
                                     pad=hp.voc_pad,
                                     upsample_factors=hp.voc_upsample_factors,
                                     feat_dims=hp.num_mels,
                                     compute_dims=hp.voc_compute_dims,
                                     res_out_dims=hp.voc_res_out_dims,
                                     res_blocks=hp.voc_res_blocks,
                                     hop_length=hp.hop_length,
                                     sample_rate=hp.sample_rate,
                                     mode=hp.voc_mode).to(device)

            voc_load_path = self.args.voc_weights if self.args.voc_weights else paths.voc_latest_weights
            #print(paths.voc_latest_weights)
            self.voc_model.load(voc_load_path)

        # === Tacotron === #
        if hp.tts_model == 'tacotron':
            print('\nInitialising Tacotron Model...\n')
            self.tts_model = Tacotron(
                embed_dims=hp.tts_embed_dims,
                num_chars=len(symbols),
                encoder_dims=hp.tts_encoder_dims,
                decoder_dims=hp.tts_decoder_dims,
                n_mels=hp.num_mels,
                fft_bins=hp.num_mels,
                postnet_dims=hp.tts_postnet_dims,
                encoder_K=hp.tts_encoder_K,
                lstm_dims=hp.tts_lstm_dims,
                postnet_K=hp.tts_postnet_K,
                num_highways=hp.tts_num_highways,
                dropout=hp.tts_dropout,
                stop_threshold=hp.tts_stop_threshold).to(device)

            tts_load_path = self.args.tts_weights if self.args.tts_weights else paths.tts_latest_weights
            self.tts_model.load(tts_load_path)

        # === Tacotron2 === #
        elif hp.tts_model == 'tacotron2':
            print('\nInitializing Tacotron2 Model...\n')
            self.tts_model = Tacotron2().to(device)
            tts_load_path = self.args.tts_weights if self.args.tts_weights else paths.tts_latest_weights
            self.tts_model.load(tts_load_path)

        # === Information === #
        if hp.tts_model == 'tacotron':
            if self.args.vocoder == 'wavernn':
                voc_k = self.voc_model.get_step() // 1000
                tts_k = self.tts_model.get_step() // 1000

                simple_table([
                    ('Tacotron', str(tts_k) + 'k'), ('r', self.tts_model.r),
                    ('Vocoder Type', 'WaveRNN'), ('WaveRNN', str(voc_k) + 'k'),
                    ('Generation Mode',
                     'Batched' if self.args.batched else 'Unbatched'),
                    ('Target Samples',
                     self.args.target if self.args.batched else 'N/A'),
                    ('Overlap Samples',
                     self.args.overlap if self.args.batched else 'N/A')
                ])

            elif self.args.vocoder == 'griffinlim':
                tts_k = self.tts_model.get_step() // 1000
                simple_table([('Tacotron', str(tts_k) + 'k'),
                              ('r', self.tts_model.r),
                              ('Vocoder Type', 'Griffin-Lim'),
                              ('GL Iters', self.args.iters)])

        elif hp.tts_model == 'tacotron2':
            if self.args.vocoder == 'wavernn':
                voc_k = self.voc_model.get_step() // 1000
                tts_k = self.tts_model.get_step() // 1000

                simple_table([
                    ('Tacotron2', str(tts_k) + 'k'),
                    ('Vocoder Type', 'WaveRNN'), ('WaveRNN', str(voc_k) + 'k'),
                    ('Generation Mode',
                     'Batched' if self.args.batched else 'Unbatched'),
                    ('Target Samples',
                     self.args.target if self.args.batched else 'N/A'),
                    ('Overlap Samples',
                     self.args.overlap if self.args.batched else 'N/A')
                ])

            elif self.args.vocoder == 'griffinlim':
                tts_k = self.tts_model.get_step() // 1000
                simple_table([('Tacotron2', str(tts_k) + 'k'),
                              ('Vocoder Type', 'Griffin-Lim'),
                              ('GL Iters', self.args.iters)])

    def generate(self, 華, input_text):
        inputs = [text_to_sequence(input_text.strip(), ['basic_cleaners'])]
        if hp.tts_model == 'tacotron2':
            self.gen_tacotron2(華, inputs, input_text)

        elif hp.tts_model == 'tacotron':
            self.gen_tacotron(華, inputs, input_text)

        else:
            print(f"Wrong tts model type {{{tts_model_type}}}")

        print('\n\nDone.\n')

    # custom function
    def gen_tacotron2(self, 華, inputs, input_text):
        for i, x in enumerate(inputs, 1):
            print(f'\n| Generating {i}/{len(inputs)}')
            print(x)

            x = np.array(x)[None, :]
            x = torch.from_numpy(x).cuda().long()

            self.tts_model.eval()
            mel_outputs, mel_outputs_postnet, _, alignments = self.tts_model.inference(
                x)
            if self.args.vocoder == 'griffinlim':
                v_type = self.args.vocoder
            elif self.args.vocoder == 'wavernn' and self.args.batched:
                v_type = 'wavernn_batched'
            else:
                v_type = 'wavernn_unbatched'

            # == define output name == #
            if len(華) == 0:
                output_name = re.split(r'\,|\.|\!|\?| ', input_text)[0]
            elif 1 <= len(華) <= 9:
                output_name = 華[:-1]
            elif 9 < len(華):
                output_name = 華[:8]
            print(output_name)
            save_path = "output/{}.wav".format(output_name)
            ##

            if self.args.vocoder == 'wavernn':
                m = mel_outputs_postnet
                self.voc_model.generate(m, save_path, self.args.batched,
                                        hp.voc_target, hp.voc_overlap,
                                        hp.mu_law)

            elif self.args.vocoder == 'griffinlim':
                m = torch.squeeze(mel_outputs_postnet).detach().cpu().numpy()
                wav = reconstruct_waveform(m, n_iter=self.args.iters)
                save_wav(wav, save_path)

    # custom function
    def gen_tacotron(self, 華, inputs, input_text):
        for i, x in enumerate(inputs, 1):
            print(f'\n| Generating {i}/{len(inputs)}')
            _, m, attention = self.tts_model.generate(x)
            # Fix mel spectrogram scaling to be from 0 to 1
            m = (m + 4) / 8
            np.clip(m, 0, 1, out=m)

            if self.args.vocoder == 'griffinlim':
                v_type = self.args.vocoder
            elif self.args.vocoder == 'wavernn' and self.args.batched:
                v_type = 'wavernn_batched'
            else:
                v_type = 'wavernn_unbatched'
            # == define output name == #
            if len(華) == 0:
                output_name = re.split(r'\,|\.|\!|\?| ', input_text)[0]
            elif 1 <= len(華) <= 9:
                output_name = 華[:-1]
            elif 9 < len(華):
                output_name = 華[:8]
            print(output_name)
            save_path = "output/{}.wav".format(output_name)
            ##
            if self.args.vocoder == 'wavernn':
                m = torch.tensor(m).unsqueeze(0)
                self.voc_model.generate(m, save_path, self.args.batched,
                                        hp.voc_target, hp.voc_overlap,
                                        hp.mu_law)

            elif self.args.vocoder == 'griffinlim':
                wav = reconstruct_waveform(m, n_iter=self.args.iters)
                save_wav(wav, save_path)
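
A hypothetical driver for the class above, assuming hparams.py is configured and checkpoints exist under the Paths layout. The 華 argument (the Chinese-character form of the text) only affects the output filename; the romanised input_text drives synthesis:

if __name__ == '__main__':
    tts = TaiwaneseTacotron()
    tts.generate('台灣。', 'tai5-uan5')  # writes output/台灣.wav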
Example #8
def load_model():
    model = WaveRNN(HPARAMS).cuda()
    model, _ = load_checkpoint(ARGS.checkpoint, model)
    return model
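
A hedged usage sketch for the loader above. The tests earlier always pass gt= to inference, so treating it as optional for free-running generation is an assumption; the dummy mel mirrors the parity test's shape:

model = load_model()
model.eval()
m = torch.rand(1, HPARAMS.feat_dims, 100).cuda()  # dummy conditioning mel, as in the tests
with torch.no_grad():
    _, _, wav = model.inference(m)  # assumes gt may be omitted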