Example #1
# NOTE: standard-library/torch imports needed by this example. The remaining
# names (hp, Paths, Tacotron, symbols, restore_checkpoint, simple_table,
# get_tts_datasets, tts_train_loop, create_gta_features) are assumed to come
# from the project's own modules, as in a WaveRNN/Tacotron repository layout.
import argparse

import torch
from torch import optim


def main():
    # Parse Arguments
    parser = argparse.ArgumentParser(description='Train Tacotron TTS')
    parser.add_argument('--force_train',
                        '-f',
                        action='store_true',
                        help='Forces the model to train past total steps')
    parser.add_argument('--force_gta',
                        '-g',
                        action='store_true',
                        help='Force the model to create GTA features')
    parser.add_argument(
        '--force_cpu',
        '-c',
        action='store_true',
        help='Forces CPU-only training, even when in CUDA capable environment')
    parser.add_argument('--hp_file',
                        metavar='FILE',
                        default='hparams.py',
                        help='The file to use for the hyperparameters')
    args = parser.parse_args()

    hp.configure(args.hp_file)  # Load hparams from file
    paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

    force_train = args.force_train
    force_gta = args.force_gta

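    # If CUDA is available and CPU training is not forced, check that every
    # scheduled batch size divides evenly across the visible GPUs.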
    if not args.force_cpu and torch.cuda.is_available():
        device = torch.device('cuda')
        for session in hp.tts_schedule:
            _, _, _, batch_size = session
            if batch_size % torch.cuda.device_count() != 0:
                raise ValueError(
                    '`batch_size` must be evenly divisible by n_gpus!')
    else:
        device = torch.device('cpu')
    print('Using device:', device)

    # Instantiate Tacotron Model
    print('\nInitialising Tacotron Model...\n')
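    # Build the Tacotron model from hparams; note that fft_bins is set equal to
    # num_mels in this configuration.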
    model = Tacotron(embed_dims=hp.tts_embed_dims,
                     num_chars=len(symbols),
                     encoder_dims=hp.tts_encoder_dims,
                     decoder_dims=hp.tts_decoder_dims,
                     n_mels=hp.num_mels,
                     fft_bins=hp.num_mels,
                     postnet_dims=hp.tts_postnet_dims,
                     encoder_K=hp.tts_encoder_K,
                     lstm_dims=hp.tts_lstm_dims,
                     postnet_K=hp.tts_postnet_K,
                     num_highways=hp.tts_num_highways,
                     dropout=hp.tts_dropout,
                     stop_threshold=hp.tts_stop_threshold).to(device)

    optimizer = optim.Adam(model.parameters())
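    # Resume from the latest 'tts' checkpoint if one exists; otherwise create it.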
    restore_checkpoint('tts', paths, model, optimizer, create_if_missing=True)

    if not force_gta:
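        # hp.tts_schedule is a list of training sessions; each entry is a
        # (reduction factor r, learning rate, max step, batch size) tuple.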
        for i, session in enumerate(hp.tts_schedule):
            current_step = model.get_step()

            r, lr, max_step, batch_size = session

            training_steps = max_step - current_step

            # Do we need to change to the next session?
            if current_step >= max_step:
                # Is this the last session in the schedule?
                if i == len(hp.tts_schedule) - 1:
                    # There are no more sessions. Check if we force training.
                    if force_train:
                        # Don't finish the loop - train forever
                        training_steps = 999_999_999
                    else:
                        # Training is complete; breaking here has the same effect as continuing.
                        break
                else:
                    # There is a following session, go to it
                    continue

            model.r = r

            simple_table([('Steps with r=%s' % (repr(r)),
                           str(training_steps // 1000) + 'k Steps'),
                          ('Batch Size', batch_size), ('Learning Rate', lr),
                          ('Outputs/Step (r)', model.r)])

            train_set, attn_example = get_tts_datasets(paths.data, batch_size,
                                                       r)
            tts_train_loop(paths, model, optimizer, train_set, lr,
                           training_steps, attn_example)

        print('Training Complete.')
        print(
            'To continue training increase tts_total_steps in hparams.py or use --force_train\n'
        )

    print('Creating Ground Truth Aligned Dataset...\n')

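    # Run the trained model over the dataset to produce ground-truth-aligned (GTA)
    # mel features, which are later used to train the WaveRNN vocoder.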
    train_set, attn_example = get_tts_datasets(paths.data, 8, model.r)
    create_gta_features(model, train_set, paths.gta)

    print(
        '\n\nYou can now train WaveRNN on GTA features - use python train_wavernn.py --gta\n'
    )


if __name__ == '__main__':
    main()
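
# Example invocation (the script name is an assumption; adjust to wherever this
# example is saved in your project):
#   python train_tacotron.py --hp_file hparams.py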

Example #2
    # Method of a Tacotron trainer class: relies on self.paths, self.writer,
    # self.generate_plots and self.evaluate, plus project helpers such as
    # Averager, attention_score, save_checkpoint, simple_table and stream.
    def train_session(self, model: Tacotron, optimizer: Optimizer,
                      session: TTSSession) -> None:
        current_step = model.get_step()
        training_steps = session.max_step - current_step
        total_iters = len(session.train_set)
        epochs = training_steps // total_iters + 1
        model.r = session.r
        simple_table([(f'Steps with r={session.r}',
                       str(training_steps // 1000) + 'k Steps'),
                      ('Batch Size', session.bs),
                      ('Learning Rate', session.lr),
                      ('Outputs/Step (r)', model.r)])
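        # Apply this session's learning rate to every optimizer parameter group.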
        for g in optimizer.param_groups:
            g['lr'] = session.lr

        loss_avg = Averager()
        duration_avg = Averager()
        device = next(
            model.parameters()).device  # use same device as model parameters
        for e in range(1, epochs + 1):
            for i, (x, m, ids, x_lens,
                    mel_lens) in enumerate(session.train_set, 1):
                start = time.time()
                model.train()
                x, m = x.to(device), m.to(device)

                m1_hat, m2_hat, attention = model(x, m)

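                # L1 loss on both model outputs against the target mel spectrogram.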
                m1_loss = F.l1_loss(m1_hat, m)
                m2_loss = F.l1_loss(m2_hat, m)
                loss = m1_loss + m2_loss
                optimizer.zero_grad()
                loss.backward()
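                # Clip the global gradient norm (threshold from hparams) before stepping the optimizer.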
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               hp.tts_clip_grad_norm)
                optimizer.step()
                loss_avg.add(loss.item())
                step = model.get_step()
                k = step // 1000

                duration_avg.add(time.time() - start)
                speed = 1. / duration_avg.get()
                msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {loss_avg.get():#.4} ' \
                      f'| {speed:#.2} steps/s | Step: {k}k | '

                if step % hp.tts_checkpoint_every == 0:
                    ckpt_name = f'taco_step{k}K'
                    save_checkpoint('tts',
                                    self.paths,
                                    model,
                                    optimizer,
                                    name=ckpt_name,
                                    is_silent=True)

                if step % hp.tts_plot_every == 0:
                    self.generate_plots(model, session)

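                # Log a diagnostic attention alignment score plus the loss and
                # session hyperparameters to the summary writer.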
                _, att_score = attention_score(attention, mel_lens)
                att_score = torch.mean(att_score)
                self.writer.add_scalar('Attention_Score/train', att_score,
                                       model.get_step())
                self.writer.add_scalar('Loss/train', loss, model.get_step())
                self.writer.add_scalar('Params/reduction_factor', session.r,
                                       model.get_step())
                self.writer.add_scalar('Params/batch_size', session.bs,
                                       model.get_step())
                self.writer.add_scalar('Params/learning_rate', session.lr,
                                       model.get_step())

                stream(msg)

            val_loss, val_att_score = self.evaluate(model, session.val_set)
            self.writer.add_scalar('Loss/val', val_loss, model.get_step())
            self.writer.add_scalar('Attention_Score/val', val_att_score,
                                   model.get_step())
            save_checkpoint('tts',
                            self.paths,
                            model,
                            optimizer,
                            is_silent=True)

            loss_avg.reset()
            duration_avg.reset()
            print(' ')