Example No. 1
 def load_tts(self, model_path, model_file, model_config, use_cuda):
     tts_config = os.path.join(model_path, model_config)
     self.model_file = os.path.join(model_path, model_file)
     print(" > Loading TTS model ...")
     print(" | > model config: ", tts_config)
     print(" | > model file: ", model_file)
     self.tts_config = load_config(tts_config)
     self.use_phonemes = self.tts_config.use_phonemes
     self.ap = AudioProcessor(**self.tts_config.audio)
     if self.use_phonemes:
         self.input_size = len(phonemes)
     else:
         self.input_size = len(symbols)
     # load speakers
     if self.config.tts_speakers is not None:
         self.tts_speakers = load_speaker_mapping(
             os.path.join(model_path, self.config.tts_speakers))
         num_speakers = len(self.tts_speakers)
     else:
         num_speakers = 0
     self.tts_model = setup_model(self.input_size,
                                  num_speakers=num_speakers,
                                  c=self.tts_config)
     # load model state
     cp = torch.load(self.model_file)
     # load the model
     self.tts_model.load_state_dict(cp['model'])
     if use_cuda:
         self.tts_model.cuda()
     self.tts_model.eval()
     self.tts_model.decoder.max_decoder_steps = 3000
     if 'r' in cp and self.tts_config.model in ["Tacotron", "TacotronGST"]:
         self.tts_model.decoder.set_r(cp['r'])
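
For orientation, here is a minimal, hypothetical call site for the method above (the `Synthesizer` class name, directory layout, and file names are illustrative assumptions, not part of the example):

synthesizer = Synthesizer(config)  # assumed wrapper class holding self.config
synthesizer.load_tts(model_path="models/ljspeech/",
                     model_file="checkpoint_261000.pth.tar",
                     model_config="config.json",
                     use_cuda=torch.cuda.is_available())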
Example No. 2
 def load_tts(self, tts_checkpoint, tts_config, use_cuda):
     print(" > Loading TTS model ...")
     print(" | > model config: ", tts_config)
     print(" | > checkpoint file: ", tts_checkpoint)
     self.tts_config = load_config(tts_config)
     self.use_phonemes = self.tts_config.use_phonemes
     self.ap = AudioProcessor(**self.tts_config.audio)
     if self.use_phonemes:
         self.input_size = len(phonemes)
     else:
         self.input_size = len(symbols)
     # TODO: fix this for multi-speaker model - load speakers
     if self.config.tts_speakers is not None:
         self.tts_speakers = load_speaker_mapping(self.config.tts_speakers)
         num_speakers = len(self.tts_speakers)
     else:
         num_speakers = 0
     self.tts_model = setup_model(self.input_size,
                                  num_speakers=num_speakers,
                                  c=self.tts_config)
     # load model state
     cp = torch.load(tts_checkpoint, map_location=torch.device('cpu'))
     # load the model
     self.tts_model.load_state_dict(cp['model'])
     if use_cuda:
         self.tts_model.cuda()
     self.tts_model.eval()
     self.tts_model.decoder.max_decoder_steps = 3000
     if 'r' in cp:
         self.tts_model.decoder.set_r(cp['r'])
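
Note the difference from Example No. 1: the checkpoint path is passed in directly and loaded with `map_location=torch.device('cpu')`, so a GPU-trained checkpoint can be restored on a CPU-only machine before the weights are optionally moved to the GPU. The same pattern in isolation (a sketch; `model` stands for any module with a matching architecture):

import torch

# load a possibly GPU-trained checkpoint onto the CPU first ...
cp = torch.load("checkpoint.pth.tar", map_location=torch.device('cpu'))
model.load_state_dict(cp['model'])
# ... then move the weights to the GPU only if one is available
if torch.cuda.is_available():
    model.cuda()
model.eval()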
Example No. 3
def format_data(data):
    if c.use_speaker_embedding:
        speaker_mapping = load_speaker_mapping(OUT_PATH)

    # setup input data
    text_input = data[0]
    text_lengths = data[1]
    speaker_names = data[2]
    linear_input = data[3] if c.model in ["Tacotron"] else None
    mel_input = data[4]
    mel_lengths = data[5]
    stop_targets = data[6]
    avg_text_length = torch.mean(text_lengths.float())
    avg_spec_length = torch.mean(mel_lengths.float())

    if c.use_speaker_embedding:
        speaker_ids = [
            speaker_mapping[speaker_name] for speaker_name in speaker_names
        ]
        speaker_ids = torch.LongTensor(speaker_ids)
    else:
        speaker_ids = None

    # set stop-target view; we predict a single stop token per iteration.
    stop_targets = stop_targets.view(text_input.shape[0],
                                     stop_targets.size(1) // c.r, -1)
    stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)

    # dispatch data to GPU
    if use_cuda:
        text_input = text_input.cuda(non_blocking=True)
        text_lengths = text_lengths.cuda(non_blocking=True)
        mel_input = mel_input.cuda(non_blocking=True)
        mel_lengths = mel_lengths.cuda(non_blocking=True)
        linear_input = linear_input.cuda(
            non_blocking=True) if c.model in ["Tacotron"] else None
        stop_targets = stop_targets.cuda(non_blocking=True)
        if speaker_ids is not None:
            speaker_ids = speaker_ids.cuda(non_blocking=True)
    return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, avg_text_length, avg_spec_length
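
The stop-target reshaping is the subtle part of `format_data`: with reduction factor `r`, the decoder emits one stop token per group of `r` frames, so the per-frame targets are folded into groups and a group counts as "stop" if any frame inside it does. A self-contained sketch of that arithmetic:

import torch

batch_size, n_frames, r = 2, 8, 4
# per-frame stop targets: 1.0 marks end-of-utterance / padding frames
stop_targets = torch.zeros(batch_size, n_frames)
stop_targets[:, 6:] = 1.0  # the last two frames are "stop" frames

# fold frames into groups of r; a group is positive if any frame in it is
grouped = stop_targets.view(batch_size, n_frames // r, -1)
grouped = (grouped.sum(2) > 0.0).float()  # shape: (batch_size, n_frames // r)
print(grouped)  # tensor([[0., 1.], [0., 1.]])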
Example No. 4
    def load_tts(self, tts_checkpoint, tts_config, use_cuda):
        global symbols, phonemes

        print(" > Loading TTS model ...")
        print(" | > model config: ", tts_config)
        print(" | > checkpoint file: ", tts_checkpoint)
        self.tts_config = load_config(tts_config)

        if 'text' in self.tts_config.keys():
            symbols, phonemes = make_symbols(**self.tts_config.text)

        self.use_phonemes = self.tts_config.use_phonemes
        self.ap = AudioProcessor(**self.tts_config.audio)
        if self.use_phonemes:
            self.input_size = len(phonemes)
        else:
            self.input_size = len(symbols)
        # load speakers
        if self.config.tts_speakers is not None:
            # the speaker mapping file lives next to the checkpoint file
            self.tts_speakers = load_speaker_mapping(
                os.path.join(os.path.dirname(tts_checkpoint),
                             self.config.tts_speakers))
            num_speakers = len(self.tts_speakers)
        else:
            num_speakers = 0
        self.tts_model = setup_model(self.input_size,
                                     num_speakers=num_speakers,
                                     c=self.tts_config)
        # load model state
        cp = torch.load(tts_checkpoint, map_location=torch.device('cpu'))
        # load the model
        self.tts_model.load_state_dict(cp['model'])
        if use_cuda:
            self.tts_model.cuda()
        self.tts_model.eval()
        self.tts_model.decoder.max_decoder_steps = 3000
        if 'r' in cp:
            self.tts_model.decoder.set_r(cp['r'])
Example No. 5
File: train.py Project: zhxfl/TTS
def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
          ap, global_step, epoch):
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    if c.use_speaker_embedding:
        speaker_mapping = load_speaker_mapping(OUT_PATH)
    model.train()
    epoch_time = 0
    avg_postnet_loss = 0
    avg_decoder_loss = 0
    avg_stop_loss = 0
    avg_step_time = 0
    avg_loader_time = 0
    print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # setup input data
        text_input = data[0]
        text_lengths = data[1]
        speaker_names = data[2]
        linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"
                                              ] else None
        mel_input = data[4]
        mel_lengths = data[5]
        stop_targets = data[6]
        avg_text_length = torch.mean(text_lengths.float())
        avg_spec_length = torch.mean(mel_lengths.float())
        loader_time = time.time() - end_time

        if c.use_speaker_embedding:
            speaker_ids = [
                speaker_mapping[speaker_name] for speaker_name in speaker_names
            ]
            speaker_ids = torch.LongTensor(speaker_ids)
        else:
            speaker_ids = None

        # set stop-target view; we predict a single stop token per r predicted frames
        stop_targets = stop_targets.view(text_input.shape[0],
                                         stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) >
                        0.0).unsqueeze(2).float().squeeze(2)

        global_step += 1

        # setup lr
        if c.lr_decay:
            scheduler.step()
        optimizer.zero_grad()
        if optimizer_st:
            optimizer_st.zero_grad()

        # dispatch data to GPU
        if use_cuda:
            text_input = text_input.cuda(non_blocking=True)
            text_lengths = text_lengths.cuda(non_blocking=True)
            mel_input = mel_input.cuda(non_blocking=True)
            mel_lengths = mel_lengths.cuda(non_blocking=True)
            linear_input = linear_input.cuda(
                non_blocking=True) if c.model in ["Tacotron", "TacotronGST"
                                                  ] else None
            stop_targets = stop_targets.cuda(non_blocking=True)
            if speaker_ids is not None:
                speaker_ids = speaker_ids.cuda(non_blocking=True)

        # forward pass model
        decoder_output, postnet_output, alignments, stop_tokens = model(
            text_input, text_lengths, mel_input, speaker_ids=speaker_ids)

        # loss computation
        stop_loss = criterion_st(stop_tokens,
                                 stop_targets) if c.stopnet else torch.zeros(1)
        if c.loss_masking:
            decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
            if c.model in ["Tacotron", "TacotronGST"]:
                postnet_loss = criterion(postnet_output, linear_input,
                                         mel_lengths)
            else:
                postnet_loss = criterion(postnet_output, mel_input,
                                         mel_lengths)
        else:
            decoder_loss = criterion(decoder_output, mel_input)
            if c.model in ["Tacotron", "TacotronGST"]:
                postnet_loss = criterion(postnet_output, linear_input)
            else:
                postnet_loss = criterion(postnet_output, mel_input)
        loss = decoder_loss + postnet_loss
        if not c.separate_stopnet and c.stopnet:
            loss += stop_loss

        loss.backward()
        optimizer, current_lr = weight_decay(optimizer, c.wd)
        grad_norm, _ = check_update(model, c.grad_clip)
        optimizer.step()

        # backpass and check the grad norm for stop loss
        if c.separate_stopnet:
            stop_loss.backward()
            optimizer_st, _ = weight_decay(optimizer_st, c.wd)
            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
            optimizer_st.step()
        else:
            grad_norm_st = 0

        step_time = time.time() - start_time
        epoch_time += step_time

        if global_step % c.print_step == 0:
            print(
                "   | > Step:{}/{}  GlobalStep:{}  TotalLoss:{:.5f}  PostnetLoss:{:.5f}  "
                "DecoderLoss:{:.5f}  StopLoss:{:.5f}  GradNorm:{:.5f}  "
                "GradNormST:{:.5f}  AvgTextLen:{:.1f}  AvgSpecLen:{:.1f}  StepTime:{:.2f}  "
                "LoaderTime:{:.2f}  LR:{:.6f}".format(
                    num_iter, batch_n_iter, global_step, loss.item(),
                    postnet_loss.item(), decoder_loss.item(), stop_loss.item(),
                    grad_norm, grad_norm_st, avg_text_length, avg_spec_length,
                    step_time, loader_time, current_lr),
                flush=True)

        # aggregate losses from processes
        if num_gpus > 1:
            postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
            decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
            loss = reduce_tensor(loss.data, num_gpus)
            stop_loss = reduce_tensor(stop_loss.data,
                                      num_gpus) if c.stopnet else stop_loss

        if args.rank == 0:
            avg_postnet_loss += float(postnet_loss.item())
            avg_decoder_loss += float(decoder_loss.item())
            avg_stop_loss += stop_loss if isinstance(
                stop_loss, float) else float(stop_loss.item())
            avg_step_time += step_time
            avg_loader_time += loader_time

            # Plot Training Iter Stats
            # reduce TB load
            if global_step % 10 == 0:
                iter_stats = {
                    "loss_posnet": postnet_loss.item(),
                    "loss_decoder": decoder_loss.item(),
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "grad_norm_st": grad_norm_st,
                    "step_time": step_time
                }
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model, optimizer, optimizer_st,
                                    postnet_loss.item(), OUT_PATH, global_step,
                                    epoch)

                # Diagnostic visualizations
                const_spec = postnet_output[0].data.cpu().numpy()
                gt_spec = linear_input[0].data.cpu().numpy() if c.model in [
                    "Tacotron", "TacotronGST"
                ] else mel_input[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img)
                }
                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                if c.model in ["Tacotron", "TacotronGST"]:
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_mel_spectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    avg_postnet_loss /= (num_iter + 1)
    avg_decoder_loss /= (num_iter + 1)
    avg_stop_loss /= (num_iter + 1)
    avg_total_loss = avg_decoder_loss + avg_postnet_loss + avg_stop_loss
    avg_step_time /= (num_iter + 1)
    avg_loader_time /= (num_iter + 1)

    # print epoch stats
    print("   | > EPOCH END -- GlobalStep:{}  AvgTotalLoss:{:.5f}  "
          "AvgPostnetLoss:{:.5f}  AvgDecoderLoss:{:.5f}  "
          "AvgStopLoss:{:.5f}  EpochTime:{:.2f}  "
          "AvgStepTime:{:.2f}  AvgLoaderTime:{:.2f}".format(
              global_step, avg_total_loss, avg_postnet_loss, avg_decoder_loss,
              avg_stop_loss, epoch_time, avg_step_time, avg_loader_time),
          flush=True)

    # Plot Epoch Stats
    if args.rank == 0:
        # Plot Training Epoch Stats
        epoch_stats = {
            "loss_postnet": avg_postnet_loss,
            "loss_decoder": avg_decoder_loss,
            "stop_loss": avg_stop_loss,
            "epoch_time": epoch_time
        }
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return avg_postnet_loss, global_step
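
`check_update` is not shown in these examples. A minimal sketch of what such a helper plausibly does (clip the gradients in-place, then report the norm and a flag the caller could use to skip a bad step) is below; the repo's exact implementation may differ:

import math
import torch

def check_update(model, grad_clip):
    # clip gradients in-place; return the gradient norm and a flag
    # signalling a non-finite norm (e.g. after a bad batch)
    grad_norm = float(
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip))
    skip_flag = not math.isfinite(grad_norm)
    return grad_norm, skip_flag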
Example No. 6
File: train.py Project: zhxfl/TTS
def main(args):  #pylint: disable=redefined-outer-name
    # Audio processor
    ap = AudioProcessor(**c.audio)

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])
    num_chars = len(phonemes) if c.use_phonemes else len(symbols)

    if c.use_speaker_embedding:
        speakers = get_speakers(c.data_path, c.meta_file_train, c.dataset)
        if args.restore_path:
            prev_out_path = os.path.dirname(args.restore_path)
            speaker_mapping = load_speaker_mapping(prev_out_path)
            assert all(speaker in speaker_mapping
                       for speaker in speakers), \
                "As of now, you cannot introduce new speakers to " \
                "a previously trained model."
        else:
            speaker_mapping = {name: i for i, name in enumerate(speakers)}
        save_speaker_mapping(OUT_PATH, speaker_mapping)
        num_speakers = len(speaker_mapping)
        print("Training with {} speakers: {}".format(num_speakers,
                                                     ", ".join(speakers)))
    else:
        num_speakers = 0

    model = setup_model(num_chars, num_speakers, c)

    print(" | > Num output units : {}".format(ap.num_freq), flush=True)

    optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0)
    if c.stopnet and c.separate_stopnet:
        optimizer_st = RAdam(model.decoder.stopnet.parameters(),
                             lr=c.lr,
                             weight_decay=0)
    else:
        optimizer_st = None

    if c.loss_masking:
        criterion = L1LossMasked() if c.model in ["Tacotron", "TacotronGST"
                                                  ] else MSELossMasked()
    else:
        criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
                                               ] else nn.MSELoss()
    criterion_st = nn.BCEWithLogitsLoss() if c.stopnet else None

    if args.restore_path:
        checkpoint = torch.load(args.restore_path)
        try:
            # TODO: fix optimizer init, model.cuda() needs to be called before
            # optimizer restore
            # optimizer.load_state_dict(checkpoint['optimizer'])
            if c.reinit_layers:
                raise RuntimeError
            model.load_state_dict(checkpoint['model'])
        except Exception:
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint, c)
            model.load_state_dict(model_dict)
            del model_dict
        for group in optimizer.param_groups:
            group['lr'] = c.lr
        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model = model.cuda()
        criterion.cuda()
        if criterion_st:
            criterion_st.cuda()

    # DISTRIBUTED
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)

    if c.lr_decay:
        scheduler = NoamLR(optimizer,
                           warmup_steps=c.warmup_steps,
                           last_epoch=args.restore_step - 1)
    else:
        scheduler = None

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    best_loss = float('inf')  # best target loss seen so far

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        # set gradual training
        if c.gradual_training is not None:
            r, c.batch_size = gradual_training_scheduler(global_step, c)
            c.r = r
            model.decoder.set_r(r)
        print(" > Number of outputs per iteration:", model.decoder.r)

        train_loss, global_step = train(model, criterion, criterion_st,
                                        optimizer, optimizer_st, scheduler, ap,
                                        global_step, epoch)
        val_loss = evaluate(model, criterion, criterion_st, ap, global_step,
                            epoch)
        print(" | > Training Loss: {:.5f}   Validation Loss: {:.5f}".format(
            train_loss, val_loss),
              flush=True)
        target_loss = train_loss
        if c.run_eval:
            target_loss = val_loss
        best_loss = save_best_model(model, optimizer, target_loss, best_loss,
                                    OUT_PATH, global_step, epoch)
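
`gradual_training_scheduler` (called at the top of the epoch loop above) is also not shown. Assuming `c.gradual_training` is a list of `[start_step, r, batch_size]` entries sorted by step, a minimal sketch of the lookup:

def gradual_training_scheduler(global_step, config):
    # pick the last schedule entry whose start step has been reached;
    # assumes config.gradual_training = [[start_step, r, batch_size], ...]
    new_values = config.gradual_training[0]
    for values in config.gradual_training:
        if global_step >= values[0]:
            new_values = values
    return new_values[1], new_values[2]  # (r, batch_size)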
Example No. 7
File: train.py Project: zhxfl/TTS
def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
    data_loader = setup_loader(ap, is_val=True)
    if c.use_speaker_embedding:
        speaker_mapping = load_speaker_mapping(OUT_PATH)
    model.eval()
    epoch_time = 0
    avg_postnet_loss = 0
    avg_decoder_loss = 0
    avg_stop_loss = 0
    print("\n > Validation")
    if c.test_sentences_file is None:
        test_sentences = [
            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
            "Be a voice, not an echo.",
            "I'm sorry Dave. I'm afraid I can't do that.",
            "This cake is great. It's so delicious and moist."
        ]
    else:
        with open(c.test_sentences_file, "r") as f:
            test_sentences = [s.strip() for s in f.readlines()]
    with torch.no_grad():
        if data_loader is not None:
            for num_iter, data in enumerate(data_loader):
                start_time = time.time()

                # setup input data
                text_input = data[0]
                text_lengths = data[1]
                speaker_names = data[2]
                linear_input = data[3] if c.model in [
                    "Tacotron", "TacotronGST"
                ] else None
                mel_input = data[4]
                mel_lengths = data[5]
                stop_targets = data[6]

                if c.use_speaker_embedding:
                    speaker_ids = [
                        speaker_mapping[speaker_name]
                        for speaker_name in speaker_names
                    ]
                    speaker_ids = torch.LongTensor(speaker_ids)
                else:
                    speaker_ids = None

                # set stop-target view; we predict a single stop token per r predicted frames
                stop_targets = stop_targets.view(text_input.shape[0],
                                                 stop_targets.size(1) // c.r,
                                                 -1)
                stop_targets = (stop_targets.sum(2) >
                                0.0).unsqueeze(2).float().squeeze(2)

                # dispatch data to GPU
                if use_cuda:
                    text_input = text_input.cuda()
                    mel_input = mel_input.cuda()
                    mel_lengths = mel_lengths.cuda()
                    linear_input = linear_input.cuda() if c.model in [
                        "Tacotron", "TacotronGST"
                    ] else None
                    stop_targets = stop_targets.cuda()
                    if speaker_ids is not None:
                        speaker_ids = speaker_ids.cuda()

                # forward pass
                decoder_output, postnet_output, alignments, stop_tokens =\
                    model.forward(text_input, text_lengths, mel_input,
                                  speaker_ids=speaker_ids)

                # loss computation
                stop_loss = criterion_st(
                    stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
                if c.loss_masking:
                    decoder_loss = criterion(decoder_output, mel_input,
                                             mel_lengths)
                    if c.model in ["Tacotron", "TacotronGST"]:
                        postnet_loss = criterion(postnet_output, linear_input,
                                                 mel_lengths)
                    else:
                        postnet_loss = criterion(postnet_output, mel_input,
                                                 mel_lengths)
                else:
                    decoder_loss = criterion(decoder_output, mel_input)
                    if c.model in ["Tacotron", "TacotronGST"]:
                        postnet_loss = criterion(postnet_output, linear_input)
                    else:
                        postnet_loss = criterion(postnet_output, mel_input)
                loss = decoder_loss + postnet_loss + stop_loss

                step_time = time.time() - start_time
                epoch_time += step_time

                if num_iter % c.print_step == 0:
                    print(
                        "   | > TotalLoss: {:.5f}   PostnetLoss: {:.5f}   DecoderLoss:{:.5f}  "
                        "StopLoss: {:.5f}  ".format(loss.item(),
                                                    postnet_loss.item(),
                                                    decoder_loss.item(),
                                                    stop_loss.item()),
                        flush=True)

                # aggregate losses from processes
                if num_gpus > 1:
                    postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
                    decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
                    if c.stopnet:
                        stop_loss = reduce_tensor(stop_loss.data, num_gpus)

                avg_postnet_loss += float(postnet_loss.item())
                avg_decoder_loss += float(decoder_loss.item())
                avg_stop_loss += stop_loss.item()

            if args.rank == 0:
                # Diagnostic visualizations
                idx = np.random.randint(mel_input.shape[0])
                const_spec = postnet_output[idx].data.cpu().numpy()
                gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [
                    "Tacotron", "TacotronGST"
                ] else mel_input[idx].data.cpu().numpy()
                align_img = alignments[idx].data.cpu().numpy()

                eval_figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img)
                }
                tb_logger.tb_eval_figures(global_step, eval_figures)

                # Sample audio
                if c.model in ["Tacotron", "TacotronGST"]:
                    eval_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    eval_audio = ap.inv_mel_spectrogram(const_spec.T)
                tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
                                         c.audio["sample_rate"])

                # compute average losses
                avg_postnet_loss /= (num_iter + 1)
                avg_decoder_loss /= (num_iter + 1)
                avg_stop_loss /= (num_iter + 1)

                # Plot Validation Stats
                epoch_stats = {
                    "loss_postnet": avg_postnet_loss,
                    "loss_decoder": avg_decoder_loss,
                    "stop_loss": avg_stop_loss
                }
                tb_logger.tb_eval_stats(global_step, epoch_stats)

    if args.rank == 0 and epoch > c.test_delay_epochs:
        # test sentences
        test_audios = {}
        test_figures = {}
        print(" | > Synthesizing test sentences")
        speaker_id = 0 if c.use_speaker_embedding else None
        style_wav = c.get("style_wav_for_test")
        for idx, test_sentence in enumerate(test_sentences):
            try:
                wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis(
                    model,
                    test_sentence,
                    c,
                    use_cuda,
                    ap,
                    speaker_id=speaker_id,
                    style_wav=style_wav)
                file_path = os.path.join(AUDIO_PATH, str(global_step))
                os.makedirs(file_path, exist_ok=True)
                file_path = os.path.join(file_path,
                                         "TestSentence_{}.wav".format(idx))
                ap.save_wav(wav, file_path)
                test_audios['{}-audio'.format(idx)] = wav
                test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
                    postnet_output, ap)
                test_figures['{}-alignment'.format(idx)] = plot_alignment(
                    alignment)
            except Exception:
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
        tb_logger.tb_test_audios(global_step, test_audios,
                                 c.audio['sample_rate'])
        tb_logger.tb_test_figures(global_step, test_figures)
    return avg_postnet_loss
Example No. 8
def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
    data_loader = setup_loader(ap, model.decoder.r, is_val=True)
    if c.use_speaker_embedding:
        speaker_mapping = load_speaker_mapping(OUT_PATH)
    model.eval()
    epoch_time = 0
    eval_values_dict = {
        'avg_postnet_loss': 0,
        'avg_decoder_loss': 0,
        'avg_stop_loss': 0,
        'avg_align_score': 0
    }
    if c.bidirectional_decoder:
        eval_values_dict['avg_decoder_b_loss'] = 0  # decoder backward loss
        eval_values_dict['avg_decoder_c_loss'] = 0  # decoder consistency loss
    keep_avg = KeepAverage()
    keep_avg.add_values(eval_values_dict)
    print("\n > Validation")

    with torch.no_grad():
        if data_loader is not None:
            for num_iter, data in enumerate(data_loader):
                start_time = time.time()

                # format data
                text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, _, _ = format_data(
                    data)
                assert mel_input.shape[1] % model.decoder.r == 0

                # forward pass model
                if c.bidirectional_decoder:
                    decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
                        text_input,
                        text_lengths,
                        mel_input,
                        speaker_ids=speaker_ids)
                else:
                    decoder_output, postnet_output, alignments, stop_tokens = model(
                        text_input,
                        text_lengths,
                        mel_input,
                        speaker_ids=speaker_ids)

                # loss computation
                stop_loss = criterion_st(
                    stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
                if c.loss_masking:
                    decoder_loss = criterion(decoder_output, mel_input,
                                             mel_lengths)
                    if c.model in ["Tacotron", "TacotronGST"]:
                        postnet_loss = criterion(postnet_output, linear_input,
                                                 mel_lengths)
                    else:
                        postnet_loss = criterion(postnet_output, mel_input,
                                                 mel_lengths)
                else:
                    decoder_loss = criterion(decoder_output, mel_input)
                    if c.model in ["Tacotron", "TacotronGST"]:
                        postnet_loss = criterion(postnet_output, linear_input)
                    else:
                        postnet_loss = criterion(postnet_output, mel_input)
                loss = decoder_loss + postnet_loss + stop_loss

                # backward decoder loss
                if c.bidirectional_decoder:
                    if c.loss_masking:
                        decoder_backward_loss = criterion(
                            torch.flip(decoder_backward_output, dims=(1, )),
                            mel_input, mel_lengths)
                    else:
                        decoder_backward_loss = criterion(
                            torch.flip(decoder_backward_output, dims=(1, )),
                            mel_input)
                    decoder_c_loss = torch.nn.functional.l1_loss(
                        torch.flip(decoder_backward_output, dims=(1, )),
                        decoder_output)
                    loss += decoder_backward_loss + decoder_c_loss
                    keep_avg.update_values({
                        'avg_decoder_b_loss':
                        decoder_backward_loss.item(),
                        'avg_decoder_c_loss':
                        decoder_c_loss.item()
                    })

                step_time = time.time() - start_time
                epoch_time += step_time

                # compute alignment score
                align_score = alignment_diagonal_score(alignments)
                keep_avg.update_value('avg_align_score', align_score)

                # aggregate losses from processes
                if num_gpus > 1:
                    postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
                    decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
                    if c.stopnet:
                        stop_loss = reduce_tensor(stop_loss.data, num_gpus)

                keep_avg.update_values({
                    'avg_postnet_loss':
                    float(postnet_loss.item()),
                    'avg_decoder_loss':
                    float(decoder_loss.item()),
                    'avg_stop_loss':
                    float(stop_loss.item()),
                })

                if num_iter % c.print_step == 0:
                    print(
                        "   | > TotalLoss: {:.5f}   PostnetLoss: {:.5f} - {:.5f}  DecoderLoss:{:.5f} - {:.5f} "
                        "StopLoss: {:.5f} - {:.5f}  AlignScore: {:.4f} : {:.4f}"
                        .format(loss.item(), postnet_loss.item(),
                                keep_avg['avg_postnet_loss'],
                                decoder_loss.item(),
                                keep_avg['avg_decoder_loss'], stop_loss.item(),
                                keep_avg['avg_stop_loss'], align_score,
                                keep_avg['avg_align_score']),
                        flush=True)

            if args.rank == 0:
                # Diagnostic visualizations
                idx = np.random.randint(mel_input.shape[0])
                const_spec = postnet_output[idx].data.cpu().numpy()
                gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [
                    "Tacotron", "TacotronGST"
                ] else mel_input[idx].data.cpu().numpy()
                align_img = alignments[idx].data.cpu().numpy()

                eval_figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img)
                }

                # Sample audio
                if c.model in ["Tacotron", "TacotronGST"]:
                    eval_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    eval_audio = ap.inv_mel_spectrogram(const_spec.T)
                tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
                                         c.audio["sample_rate"])

                # Plot Validation Stats
                epoch_stats = {
                    "loss_postnet": keep_avg['avg_postnet_loss'],
                    "loss_decoder": keep_avg['avg_decoder_loss'],
                    "stop_loss": keep_avg['avg_stop_loss'],
                    "alignment_score": keep_avg['avg_align_score']
                }

                if c.bidirectional_decoder:
                    epoch_stats['loss_decoder_backward'] = keep_avg[
                        'avg_decoder_b_loss']
                    align_b_img = alignments_backward[idx].data.cpu().numpy()
                    eval_figures['alignment_backward'] = plot_alignment(
                        align_b_img)
                tb_logger.tb_eval_stats(global_step, epoch_stats)
                tb_logger.tb_eval_figures(global_step, eval_figures)

    if args.rank == 0 and epoch > c.test_delay_epochs:
        if c.test_sentences_file is None:
            test_sentences = [
                "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                "Be a voice, not an echo.",
                "I'm sorry Dave. I'm afraid I can't do that.",
                "This cake is great. It's so delicious and moist."
            ]
        else:
            with open(c.test_sentences_file, "r") as f:
                test_sentences = [s.strip() for s in f.readlines()]

        # test sentences
        test_audios = {}
        test_figures = {}
        print(" | > Synthesizing test sentences")
        speaker_id = 0 if c.use_speaker_embedding else None
        style_wav = c.get("style_wav_for_test")
        for idx, test_sentence in enumerate(test_sentences):
            try:
                wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis(
                    model,
                    test_sentence,
                    c,
                    use_cuda,
                    ap,
                    speaker_id=speaker_id,
                    style_wav=style_wav)
                file_path = os.path.join(AUDIO_PATH, str(global_step))
                os.makedirs(file_path, exist_ok=True)
                file_path = os.path.join(file_path,
                                         "TestSentence_{}.wav".format(idx))
                ap.save_wav(wav, file_path)
                test_audios['{}-audio'.format(idx)] = wav
                test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
                    postnet_output, ap)
                test_figures['{}-alignment'.format(idx)] = plot_alignment(
                    alignment)
            except Exception:
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
        tb_logger.tb_test_audios(global_step, test_audios,
                                 c.audio['sample_rate'])
        tb_logger.tb_test_figures(global_step, test_figures)
    return keep_avg['avg_postnet_loss']
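
Compared with Example No. 7, this version replaces the hand-rolled running sums with a `KeepAverage` helper. A minimal sketch matching the interface used above (`add_values`, `update_value`, `update_values`, dict-style reads); the repo's own implementation may differ in detail:

class KeepAverage:
    def __init__(self):
        self.avg_values = {}
        self.iters = {}

    def __getitem__(self, key):
        return self.avg_values[key]

    def add_value(self, name, init_val=0):
        self.avg_values[name] = init_val
        self.iters[name] = 0

    def add_values(self, name_dict):
        for name, value in name_dict.items():
            self.add_value(name, init_val=value)

    def update_value(self, name, value):
        # incremental running mean: avg += (x - avg) / n
        self.iters[name] += 1
        self.avg_values[name] += (value - self.avg_values[name]) / self.iters[name]

    def update_values(self, value_dict):
        for name, value in value_dict.items():
            self.update_value(name, value)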
Example No. 9
def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global meta_data_train, meta_data_eval, symbols, phonemes
    # Audio processor
    ap = AudioProcessor(**c.audio)
    if 'characters' in c.keys():
        symbols, phonemes = make_symbols(**c.characters)

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])
    num_chars = len(phonemes) if c.use_phonemes else len(symbols)

    # load data instances
    meta_data_train, meta_data_eval = load_meta_data(c.datasets)

    # parse speakers
    if c.use_speaker_embedding:
        speakers = get_speakers(meta_data_train)
        if args.restore_path:
            prev_out_path = os.path.dirname(args.restore_path)
            speaker_mapping = load_speaker_mapping(prev_out_path)
            assert all(speaker in speaker_mapping
                       for speaker in speakers), \
                "As of now, you cannot introduce new speakers to " \
                "a previously trained model."
        else:
            speaker_mapping = {name: i for i, name in enumerate(speakers)}
        save_speaker_mapping(OUT_PATH, speaker_mapping)
        num_speakers = len(speaker_mapping)
        print("Training with {} speakers: {}".format(num_speakers,
                                                     ", ".join(speakers)))
    else:
        num_speakers = 0

    model = setup_model(num_chars, num_speakers, c)

    print(" | > Num output units : {}".format(ap.num_freq), flush=True)

    params = set_weight_decay(model, c.wd)
    optimizer = RAdam(params, lr=c.lr, weight_decay=0)
    if c.stopnet and c.separate_stopnet:
        optimizer_st = RAdam(model.decoder.stopnet.parameters(),
                             lr=c.lr,
                             weight_decay=0)
    else:
        optimizer_st = None

    # setup criterion
    criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)

    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            # TODO: fix optimizer init, model.cuda() needs to be called before
            # optimizer restore
            # optimizer.load_state_dict(checkpoint['optimizer'])
            if c.reinit_layers:
                raise RuntimeError
            model.load_state_dict(checkpoint['model'])
        except Exception:
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint, c)
            model.load_state_dict(model_dict)
            del model_dict
        for group in optimizer.param_groups:
            group['lr'] = c.lr
        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model.cuda()
        criterion.cuda()

    # DISTRIBUTED
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)

    if c.noam_schedule:
        scheduler = NoamLR(optimizer,
                           warmup_steps=c.warmup_steps,
                           last_epoch=args.restore_step - 1)
    else:
        scheduler = None

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    best_loss = float('inf')  # best target loss seen so far

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        # set gradual training
        if c.gradual_training is not None:
            r, c.batch_size = gradual_training_scheduler(global_step, c)
            c.r = r
            model.decoder.set_r(r)
            if c.bidirectional_decoder:
                model.decoder_backward.set_r(r)
            print("\n > Number of output frames:", model.decoder.r)

        train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                 optimizer_st, scheduler, ap,
                                                 global_step, epoch)
        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = train_avg_loss_dict['avg_postnet_loss']
        if c.run_eval:
            target_loss = eval_avg_loss_dict['avg_postnet_loss']
        best_loss = save_best_model(target_loss, best_loss, model, optimizer,
                                    global_step, epoch, c.r, OUT_PATH)
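
Unlike Example No. 6, this version builds per-parameter groups with `set_weight_decay`, so the `weight_decay=0` passed to RAdam is overridden per group. A sketch of the usual pattern (biases, and in practice normalization weights, are exempted from decay; the exact skip rules here are assumptions):

def set_weight_decay(model, weight_decay, skip_names=("bias",)):
    # split parameters into a decayed group and an exempt group
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if any(skip in name for skip in skip_names):
            no_decay.append(param)
        else:
            decay.append(param)
    return [{'params': no_decay, 'weight_decay': 0.0},
            {'params': decay, 'weight_decay': weight_decay}]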
Example No. 10
C.separate_stopnet = True

preprocessor = importlib.import_module('TTS.datasets.preprocess')
preprocessor = getattr(preprocessor, DATASET.lower())

meta_data = preprocessor(DATA_PATH)
loader = setup_loader(C, ap, meta_data)

from TTS.utils.text.symbols import symbols, phonemes
from TTS.utils.generic_utils import sequence_mask
from TTS.layers.losses import L1LossMasked

# load the model
num_chars = len(phonemes) if C.use_phonemes else len(symbols)
speaker_mapping = load_speaker_mapping(TTS_MODEL_PATH)
num_speakers = len(speaker_mapping)

model = setup_model(num_chars, num_speakers, C)
checkpoint = torch.load(MODEL_FILE)
model.load_state_dict(checkpoint['model'])
print(checkpoint['step'])
model.eval()
if use_cuda:
    model = model.cuda()

# ### Generate model outputs
import pickle

file_idxs = []
losses = []
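
The snippet stops just before the extraction loop. A plausible continuation, sketched under the assumption that the loader's batch layout matches `format_data` above and that the masked loss takes `(prediction, target, lengths)`:

with torch.no_grad():
    criterion = L1LossMasked()
    for data in loader:
        # batch layout assumed: text, text_lengths, ..., mel, mel_lengths
        text_input, text_lengths = data[0], data[1]
        mel_input, mel_lengths = data[4], data[5]
        if use_cuda:
            text_input, text_lengths = text_input.cuda(), text_lengths.cuda()
            mel_input, mel_lengths = mel_input.cuda(), mel_lengths.cuda()
        _, postnet_output, _, _ = model(text_input, text_lengths, mel_input)
        loss = criterion(postnet_output, mel_input, mel_lengths)
        losses.append(loss.item())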