Example #1
    def setup(self):
        # load configs
        self.TTS_CONFIG = load_config(self.TTS_CONFIG)
        self.VOCODER_CONFIG = load_config(self.VOCODER_CONFIG)

        # load the audio processor
        self.ap = AudioProcessor(**self.TTS_CONFIG.audio)

        # load the model
        num_chars = len(phonemes) if self.TTS_CONFIG.use_phonemes else len(
            symbols)
        self.model = setup_model(num_chars, len(self.speakers),
                                 self.TTS_CONFIG)

        self.model, _ = load_checkpoint(self.model,
                                        self.TTS_MODEL,
                                        use_cuda=self.use_cuda)
        self.model.eval()

        # LOAD VOCODER MODEL
        self.vocoder_model = setup_generator(self.VOCODER_CONFIG)
        self.vocoder_model, _ = load_vocoder_checkpoint(
            self.vocoder_model, checkpoint_path=self.VOCODER_MODEL)
        self.vocoder_model.remove_weight_norm()
        self.vocoder_model.inference_padding = 0

        self.ap_vocoder = AudioProcessor(**self.VOCODER_CONFIG['audio'])
        if self.use_cuda:
            self.vocoder_model.cuda()
        self.vocoder_model.eval()
Example #2
def load_vocoder(use_cuda):
    """
    Loads the Vocoder model

    Parameters
    ----------
    use_cuda : bool
        whether to use the gpu

    Returns
    -------
    model, audio processor, model config
    """
    VOCODER_MODEL = model_path / 'vocoder_model.pth.tar'
    VOCODER_CONFIG = model_path / 'vocoder_config.json'

    VOCODER_CONFIG = load_config(VOCODER_CONFIG)
    VOCODER_CONFIG.audio['stats_path'] = str(model_path /
                                             'vocoder_scale_stats.npy')

    ap_vocoder = AudioProcessor(**VOCODER_CONFIG['audio'])

    vocoder_model = setup_generator(VOCODER_CONFIG)
    cp = torch.load(VOCODER_MODEL, map_location=torch.device('cpu'))
    vocoder_model.load_state_dict(cp['model'])
    vocoder_model.remove_weight_norm()
    vocoder_model.inference_padding = 0

    if use_cuda:
        vocoder_model.cuda()
    vocoder_model.eval()

    return vocoder_model, ap_vocoder, VOCODER_CONFIG
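A minimal usage sketch for the loader above (it assumes model_path is a pathlib.Path defined elsewhere; the mel shape convention, the num_mels config key, and the generator's inference method are assumptions based on MelGAN-style generators, not guaranteed by this snippet):

import torch

vocoder_model, ap_vocoder, vocoder_config = load_vocoder(use_cuda=False)

# dummy mel batch: [batch, num_mels, frames] (shape convention assumed)
mel = torch.rand(1, vocoder_config.audio['num_mels'], 100)
with torch.no_grad():
    waveform = vocoder_model.inference(mel)  # assumed generator API
print(waveform.shape)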
Example #3
 def load_vocoder(self, model_file, model_config, use_cuda):
     self.vocoder_config = load_config(model_config)
     self.vocoder_ap = AudioProcessor(**self.vocoder_config['audio'])
     self.vocoder_model = setup_generator(self.vocoder_config)
     self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True)
     if use_cuda:
         self.vocoder_model.cuda()
Example #4
def setup():
    use_cuda = True

    # model paths
    TTS_MODEL = "tts_model.pth.tar"
    TTS_CONFIG = "config.json"
    VOCODER_MODEL = "vocoder_model.pth.tar"
    VOCODER_CONFIG = "config_vocoder.json"

    # Load configs
    TTS_CONFIG = load_config(TTS_CONFIG)
    VOCODER_CONFIG = load_config(VOCODER_CONFIG)

    ap = AudioProcessor(**TTS_CONFIG.audio)


    # LOAD TTS MODEL
    # multi speaker 
    speaker_id = None
    speakers = []

    # load the model
    num_chars = len(phonemes) if TTS_CONFIG.use_phonemes else len(symbols)
    model = setup_model(num_chars, len(speakers), TTS_CONFIG)

    # load model state
    cp = torch.load(TTS_MODEL, map_location=torch.device('cpu'))

    # load the model
    model.load_state_dict(cp['model'])
    if use_cuda:
        model.cuda()
    model.eval()

    # set model stepsize (decoder reduction factor r stored in the checkpoint)
    if 'r' in cp:
        model.decoder.set_r(cp['r'])


    from TTS.vocoder.utils.generic_utils import setup_generator

    # LOAD VOCODER MODEL
    vocoder_model = setup_generator(VOCODER_CONFIG)
    vocoder_model.load_state_dict(torch.load(VOCODER_MODEL, map_location="cpu")["model"])
    vocoder_model.remove_weight_norm()
    vocoder_model.inference_padding = 0

    ap_vocoder = AudioProcessor(**VOCODER_CONFIG['audio'])    
    if use_cuda:
        vocoder_model.cuda()
    vocoder_model.eval()

    return model, vocoder_model, speaker_id, TTS_CONFIG, use_cuda, ap
Example #5
    def load_vocoder(self, model_file, model_config, use_cuda):
        self.vocoder_config = load_config(model_config)
        self.vocoder_model = setup_generator(self.vocoder_config)
        self.vocoder_model.load_state_dict(
            torch.load(model_file, map_location="cpu")["model"])
        self.vocoder_model.remove_weight_norm()
        self.vocoder_model.inference_padding = 0

        if use_cuda:
            self.vocoder_model.cuda()
        self.vocoder_model.eval()
Example #6
def load_vocoder(lib_path, model_file, model_config, use_cuda):
    sys.path.append(lib_path) # set this if ParallelWaveGAN is not installed globally
    #pylint: disable=import-outside-toplevel
    vocoder_config = load_config(model_config)
    vocoder_model = setup_generator(vocoder_config)
    checkpoint = torch.load(model_file, map_location='cpu')
    print(' > Model step:', checkpoint['step'])
    vocoder_model.load_state_dict(checkpoint['model'])
    vocoder_model.remove_weight_norm()
    vocoder_model.inference_padding = 0
    ap_vocoder = AudioProcessor(**vocoder_config['audio'])

    if use_cuda:
        vocoder_model.cuda()
    return vocoder_model.eval(), ap_vocoder
Example #7
    def _load_vocoder(self, model_file: str, model_config: str,
                      use_cuda: bool) -> None:
        """Load the vocoder model.

        Args:
            model_file (str): path to the model checkpoint.
            model_config (str): path to the model config file.
            use_cuda (bool): enable/disable CUDA use.
        """
        self.vocoder_config = load_config(model_config)
        self.vocoder_ap = AudioProcessor(verbose=False,
                                         **self.vocoder_config["audio"])
        self.vocoder_model = setup_generator(self.vocoder_config)
        self.vocoder_model.load_checkpoint(self.vocoder_config,
                                           model_file,
                                           eval=True)
        if use_cuda:
            self.vocoder_model.cuda()
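For comparison with the manual loading pattern used in other examples on this page, load_checkpoint(..., eval=True) is presumably a convenience for roughly the following steps (a sketch of assumed behavior, not the library's actual source):

# assumed rough equivalent of vocoder_model.load_checkpoint(config, model_file, eval=True)
state = torch.load(model_file, map_location=torch.device('cpu'))
vocoder_model.load_state_dict(state['model'])
vocoder_model.remove_weight_norm()  # fold weight norm for faster inference
vocoder_model.eval()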
Example #8
    def __init__(self, use_cuda=False, verbose=False):
        self.use_cuda = use_cuda
        self.verbose = verbose

        # load configs
        self.TTS_CONFIG = load_config(TTS_CONFIG)
        self.VOCODER_CONFIG = load_config(VOCODER_CONFIG)

        # load the audio processor
        self.ap = AudioProcessor(**self.TTS_CONFIG.audio)

        # LOAD TTS MODEL
        self.speaker_id = None
        speakers = []

        # load the model
        num_chars = len(phonemes) if self.TTS_CONFIG.use_phonemes else len(symbols)
        self.model = setup_model(num_chars, len(speakers), self.TTS_CONFIG)

        # load model state
        cp = torch.load(TTS_MODEL, map_location=torch.device('cpu'))

        # load the model
        self.model.load_state_dict(cp['model'])
        if self.use_cuda:
            self.model.cuda()
        self.model.eval()

        # set model stepsize
        if 'r' in cp:
            self.model.decoder.set_r(cp['r'])

        # LOAD VOCODER MODEL
        self.vocoder_model = setup_generator(self.VOCODER_CONFIG)
        self.vocoder_model.load_state_dict(torch.load(VOCODER_MODEL, map_location="cpu")["model"])
        self.vocoder_model.remove_weight_norm()
        self.vocoder_model.inference_padding = 0

        ap_vocoder = AudioProcessor(**self.VOCODER_CONFIG['audio'])
        if self.use_cuda:
            self.vocoder_model.cuda()

        self.vocoder_model.eval()
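Several examples call remove_weight_norm() before inference. The underlying mechanism is standard PyTorch: weight normalization is a training-time reparameterization that can be folded back into a plain weight tensor once training is done. A self-contained demonstration:

import torch
from torch.nn.utils import weight_norm, remove_weight_norm

conv = weight_norm(torch.nn.Conv1d(80, 80, kernel_size=3))
print(hasattr(conv, 'weight_g'))  # True: weight split into magnitude g and direction v
remove_weight_norm(conv)          # folds g * v / ||v|| back into conv.weight
print(hasattr(conv, 'weight_g'))  # False: a plain weight again, slightly faster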
Example #9
 def __init__(self, TTS_MODEL, TTS_CONFIG, VOCODER_MODEL, VOCODER_CONFIG,
              use_cuda, use_gl):
     self.use_cuda = use_cuda
     self.use_gl = use_gl
     # model paths
     self.tts_config = load_config(TTS_CONFIG)
     vocoder_config = load_config(VOCODER_CONFIG)
     # load audio processor
     self.ap = AudioProcessor(**self.tts_config.audio)
     # LOAD TTS MODEL
     # multi speaker
     self.speaker_id = None
     speakers = []
     # load the model
     num_chars = len(phonemes) if self.tts_config.use_phonemes else len(
         symbols)
     self.model = setup_model(num_chars, len(speakers), self.tts_config)
     # load model state
     self.cp = torch.load(TTS_MODEL, map_location=torch.device('cpu'))
     # load the model
     self.model.load_state_dict(self.cp['model'])
     if self.use_cuda:
         self.model.cuda()
     self.model.eval()
     # set model stepsize
     if 'r' in self.cp:
         self.model.decoder.set_r(self.cp['r'])
     # LOAD VOCODER MODEL
     self.vocoder_model = setup_generator(vocoder_config)
     self.vocoder_model.load_state_dict(
         torch.load(VOCODER_MODEL, map_location="cpu")["model"])
     self.vocoder_model.remove_weight_norm()
     self.vocoder_model.inference_padding = 0
     #ap_vocoder = AudioProcessor(**vocoder_config['audio'])
     if use_cuda:
         self.vocoder_model.cuda()
     self.vocoder_model.eval()
     # get the sample rate
     self.sample_rate = self.ap.sample_rate
     gc.collect(2)
Example #10
    def __init__(self):
        # load the audio processor
        self.audio_processor = AudioProcessor(**TTS_CONFIG.audio)

        # LOAD TTS MODEL
        # multi speaker
        speakers = []
        speaker_id = None

        # load the model
        num_chars = len(phonemes) if TTS_CONFIG.use_phonemes else len(symbols)
        self.model = setup_model(num_chars, len(speakers), TTS_CONFIG)

        # load model state
        self.model, _ = load_checkpoint(self.model,
                                        TTS_MODEL,
                                        use_cuda=USE_CUDA)
        self.model.eval()
        self.model.store_inverse()

        # LOAD VOCODER MODEL
        self.vocoder_model = setup_generator(VOCODER_CONFIG)
        self.vocoder_model.load_state_dict(
            torch.load(VOCODER_MODEL, map_location="cpu")["model"])
        self.vocoder_model.remove_weight_norm()
        self.vocoder_model.inference_padding = 0

        # scale factor for sampling rate difference
        self.scale_factor = [
            1, VOCODER_CONFIG['audio']['sample_rate'] /
            self.audio_processor.sample_rate
        ]
        print(f"scale_factor: {self.scale_factor}")

        self.ap_vocoder = AudioProcessor(**VOCODER_CONFIG['audio'])
        if USE_CUDA:
            self.vocoder_model.cuda()
        self.vocoder_model.eval()
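When the TTS and vocoder sample rates differ, the mel spectrogram is stretched along the time axis before vocoding, which is what a two-element scale factor like the one printed above is for. A sketch of how such a factor is typically applied (the unsqueeze/bilinear details mirror how Mozilla TTS handles this, but treat them as an assumption here):

import torch
import torch.nn.functional as F

mel = torch.rand(1, 80, 200)       # [batch, num_mels, frames]
scale_factor = [1, 24000 / 22050]  # e.g. 22.05 kHz TTS -> 24 kHz vocoder
stretched = F.interpolate(mel.unsqueeze(1),  # -> [batch, 1, num_mels, frames]
                          scale_factor=scale_factor,
                          mode='bilinear',
                          align_corners=True).squeeze(1)
print(stretched.shape)  # time axis scaled, mel axis unchanged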
Example #11
# load model state
cp = torch.load(TTS_MODEL, map_location=torch.device('cpu'))

# load the model
model.load_state_dict(cp['model'])
if use_cuda:
    model.cuda()
model.eval()

# set model stepsize
if 'r' in cp:
    model.decoder.set_r(cp['r'])

# LOAD VOCODER MODEL
vocoder_model = setup_generator(VOCODER_CONFIG)
vocoder_model.load_state_dict(
    torch.load(VOCODER_MODEL, map_location="cpu")["model"])
vocoder_model.remove_weight_norm()
vocoder_model.inference_padding = 0

ap_vocoder = AudioProcessor(**VOCODER_CONFIG['audio'])
if use_cuda:
    vocoder_model.cuda()
vocoder_model.eval()

playsound('./wavs/DoorCracking.wav')
playsound('./wavs/on.wav')
text2speech("hey!")

mocking = False
Example #12
    def __init__(self, text, expected_output_audio_format, file_name):
        # set a pysbd segmenter to be used later to divide the input into segments
        self.seg = pysbd.Segmenter(language="en", clean=True)
        # runtime settings
        use_cuda = False

        # model paths - models and config files are taken from Mozilla TTS's github page
        TTS_MODEL = "/path/to/checkpoint_130000.pth.tar"
        TTS_CONFIG = "server/config/config.json"
        VOCODER_MODEL = "/path/to/checkpoint_1450000.pth.tar"
        VOCODER_CONFIG = "server/config/config_vocoder.json"

        # load configs
        TTS_CONFIG = load_config(TTS_CONFIG)
        self.TTS_CONFIG = TTS_CONFIG  # set it as a class variable to be later used by convert_audio_to()
        VOCODER_CONFIG = load_config(VOCODER_CONFIG)

        # load the audio processor
        ap = AudioProcessor(**TTS_CONFIG.audio)

        # LOAD TTS MODEL
        # multi speaker
        self.speaker_id = None
        self.speakers = []

        # use the imported symbols and phonemes
        global symbols, phonemes

        use_phonemes = TTS_CONFIG.use_phonemes

        if 'characters' in TTS_CONFIG.keys():
            symbols, phonemes = make_symbols(**TTS_CONFIG.characters)

        if use_phonemes:
            num_chars = len(phonemes)
        else:
            num_chars = len(symbols)

        # load the model
        model = setup_model(num_chars, len(self.speakers), TTS_CONFIG)

        # load model state
        cp = torch.load(TTS_MODEL, map_location=torch.device('cpu'))

        # load the model
        model.load_state_dict(cp['model'])
        if use_cuda:
            model.cuda()
        model.eval()

        model.decoder.max_decoder_steps = 3000

        # set model stepsize
        if 'r' in cp:
            model.decoder.set_r(cp['r'])

        # LOAD VOCODER MODEL
        self.vocoder_model = setup_generator(VOCODER_CONFIG)
        self.vocoder_model.load_state_dict(
            torch.load(VOCODER_MODEL, map_location="cpu")["model"])
        self.vocoder_model.remove_weight_norm()
        self.vocoder_model.inference_padding = 0

        # ap_vocoder = AudioProcessor(**VOCODER_CONFIG['audio'])
        if use_cuda:
            self.vocoder_model.cuda()
        self.vocoder_model.eval()

        # TODO: need to train a model?
        wav = self.tts(model,
                       text,
                       TTS_CONFIG,
                       use_cuda,
                       ap,
                       use_gl=False,
                       figures=True)
        print(len(wav.tobytes()))

        # save the generated .wav file as (file_name + "_audio.wav")
        wavfile.write(file_name + "_audio.wav",
                      TTS_CONFIG.audio["sample_rate"], wav)

        # convert the generated audio file to the specified audio format
        self.convert_audio_to(expected_output_audio_format, file_name)
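The pysbd segmenter created at the top of __init__ is meant to split long inputs into sentences before synthesis; a minimal illustration of that API:

import pysbd

seg = pysbd.Segmenter(language="en", clean=True)
print(seg.segment("Hello there. This is a long input! It gets split."))
# ['Hello there.', 'This is a long input!', 'It gets split.']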
Example #13
def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global train_data, eval_data
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path,
                                                   c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio)

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])

    # setup models
    model_gen = setup_generator(c)
    model_disc = setup_discriminator(c)

    # setup optimizers
    optimizer_gen = RAdam(model_gen.parameters(), lr=c.lr_gen, weight_decay=0)
    optimizer_disc = RAdam(model_disc.parameters(),
                           lr=c.lr_disc,
                           weight_decay=0)

    # schedulers
    scheduler_gen = None
    scheduler_disc = None
    if 'lr_scheduler_gen' in c:
        scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
        scheduler_gen = scheduler_gen(optimizer_gen,
                                      **c.lr_scheduler_gen_params)
    if 'lr_scheduler_disc' in c:
        scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
        scheduler_disc = scheduler_disc(optimizer_disc,
                                        **c.lr_scheduler_disc_params)

    # setup criterion
    criterion_gen = GeneratorLoss(c)
    criterion_disc = DiscriminatorLoss(c)

    if args.restore_path:
        print(f" > Restoring from {os.path.basename(args.restore_path)}...")
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            print(" > Restoring Generator Model...")
            model_gen.load_state_dict(checkpoint['model'])
            print(" > Restoring Generator Optimizer...")
            optimizer_gen.load_state_dict(checkpoint['optimizer'])
            print(" > Restoring Discriminator Model...")
            model_disc.load_state_dict(checkpoint['model_disc'])
            print(" > Restoring Discriminator Optimizer...")
            optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
            if 'scheduler' in checkpoint:
                print(" > Restoring Generator LR Scheduler...")
                scheduler_gen.load_state_dict(checkpoint['scheduler'])
                # NOTE: Not sure if necessary
                scheduler_gen.optimizer = optimizer_gen
            if 'scheduler_disc' in checkpoint:
                print(" > Restoring Discriminator LR Scheduler...")
                scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
                scheduler_disc.optimizer = optimizer_disc
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model_gen.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            model_gen.load_state_dict(model_dict)

            model_dict = model_disc.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model_disc'], c)
            model_disc.load_state_dict(model_dict)
            del model_dict

        # reset lr if not continuing training.
        for group in optimizer_gen.param_groups:
            group['lr'] = c.lr_gen

        for group in optimizer_disc.param_groups:
            group['lr'] = c.lr_disc

        print(f" > Model restored from step {checkpoint['step']:d}",
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model_gen.cuda()
        criterion_gen.cuda()
        model_disc.cuda()
        criterion_disc.cuda()

    # DISTRIBUTED
    if num_gpus > 1:
        model_gen = DDP_th(model_gen, device_ids=[args.rank])
        model_disc = DDP_th(model_disc, device_ids=[args.rank])

    num_params = count_parameters(model_gen)
    print(" > Generator has {} parameters".format(num_params), flush=True)
    num_params = count_parameters(model_disc)
    print(" > Discriminator has {} parameters".format(num_params), flush=True)

    if args.restore_step == 0 or not args.best_path:
        best_loss = float('inf')
        print(" > Starting with inf best loss.")
    else:
        print(" > Restoring best loss from "
              f"{os.path.basename(args.best_path)} ...")
        best_loss = torch.load(args.best_path,
                               map_location='cpu')['model_loss']
        print(f" > Starting with best loss of {best_loss}.")
    keep_all_best = c.get('keep_all_best', False)
    keep_after = c.get('keep_after', 10000)  # ignored unless keep_all_best is True

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(model_gen, criterion_gen, optimizer_gen,
                               model_disc, criterion_disc, optimizer_disc,
                               scheduler_gen, scheduler_disc, ap, global_step,
                               epoch)
        eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc,
                                      criterion_disc, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict[c.target_loss]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model_gen,
            optimizer_gen,
            scheduler_gen,
            model_disc,
            optimizer_disc,
            scheduler_disc,
            global_step,
            epoch,
            OUT_PATH,
            keep_all_best=keep_all_best,
            keep_after=keep_after,
            model_losses=eval_avg_loss_dict,
        )
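The scheduler construction in these training scripts resolves a class by its config-provided name via getattr; a minimal, generic illustration of the same pattern (the name and params here are made up, standing in for c.lr_scheduler_gen and c.lr_scheduler_gen_params):

import torch

optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler_cls = getattr(torch.optim.lr_scheduler, "ExponentialLR")
scheduler = scheduler_cls(optimizer, gamma=0.999)
scheduler.step()
print(optimizer.param_groups[0]["lr"])  # 0.1 * 0.999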
Example #14
def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global train_data, eval_data

    # setup audio processor
    ap = AudioProcessor(**c.audio)

    # print(f" > Loading wavs from: {c.data_path}")
    # if c.feature_path is not None:
    #     print(f" > Loading features from: {c.feature_path}")
    #     eval_data, train_data = load_wav_feat_data(
    #         c.data_path, c.feature_path, c.eval_split_size
    #     )
    # else:
    #     mel_feat_path = os.path.join(OUT_PATH, "mel")
    #     feat_data = find_feat_files(mel_feat_path)
    #     if feat_data:
    #         print(f" > Loading features from: {mel_feat_path}")
    #         eval_data, train_data = load_wav_feat_data(
    #             c.data_path, mel_feat_path, c.eval_split_size
    #         )
    #     else:
    #         print(" > No feature data found. Preprocessing...")
    #         # preprocessing feature data from given wav files
    #         preprocess_wav_files(OUT_PATH, CONFIG, ap)
    #         eval_data, train_data = load_wav_feat_data(
    #             c.data_path, mel_feat_path, c.eval_split_size
    #         )

    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
    # setup model
    model_wavernn = setup_generator(c)

    # setup amp scaler
    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None

    # define train functions
    if c.mode == "mold":
        criterion = discretized_mix_logistic_loss
    elif c.mode == "gauss":
        criterion = gaussian_loss
    elif isinstance(c.mode, int):
        criterion = torch.nn.CrossEntropyLoss()

    if use_cuda:
        model_wavernn.cuda()
        if isinstance(c.mode, int):
            criterion.cuda()

    optimizer = RAdam(model_wavernn.parameters(), lr=c.lr, weight_decay=0)

    scheduler = None
    if "lr_scheduler" in c:
        scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
        scheduler = scheduler(optimizer, **c.lr_scheduler_params)
    # slow start for the first 5 epochs
    # lr_lambda = lambda epoch: min(epoch / c.warmup_steps, 1)
    # scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # restore any checkpoint
    if args.restore_path:
        print(f" > Restoring from {os.path.basename(args.restore_path)}...")
        checkpoint = torch.load(args.restore_path, map_location="cpu")
        try:
            print(" > Restoring Model...")
            model_wavernn.load_state_dict(checkpoint["model"])
            print(" > Restoring Optimizer...")
            optimizer.load_state_dict(checkpoint["optimizer"])
            if "scheduler" in checkpoint:
                print(" > Restoring Generator LR Scheduler...")
                scheduler.load_state_dict(checkpoint["scheduler"])
                scheduler.optimizer = optimizer
            if "scaler" in checkpoint and c.mixed_precision:
                print(" > Restoring AMP Scaler...")
                scaler.load_state_dict(checkpoint["scaler"])
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model_wavernn.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
            model_wavernn.load_state_dict(model_dict)

        print(" > Model restored from step %d" % checkpoint["step"], flush=True)
        args.restore_step = checkpoint["step"]
    else:
        args.restore_step = 0

    # DISTRIBUTED
    # if num_gpus > 1:
    #     model = apply_gradient_allreduce(model)

    num_parameters = count_parameters(model_wavernn)
    print(" > Model has {} parameters".format(num_parameters), flush=True)

    if args.restore_step == 0 or not args.best_path:
        best_loss = float("inf")
        print(" > Starting with inf best loss.")
    else:
        print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
        best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
        print(f" > Starting with loaded last best loss {best_loss}.")
    keep_all_best = c.get("keep_all_best", False)
    keep_after = c.get("keep_after", 10000)  # ignored unless keep_all_best is True

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(model_wavernn, optimizer, criterion, scheduler, scaler, ap, global_step, epoch)
        eval_avg_loss_dict = evaluate(model_wavernn, criterion, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict["avg_model_loss"]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model_wavernn,
            optimizer,
            scheduler,
            None,
            None,
            None,
            global_step,
            epoch,
            OUT_PATH,
            keep_all_best=keep_all_best,
            keep_after=keep_after,
            model_losses=eval_avg_loss_dict,
            scaler=scaler.state_dict() if c.mixed_precision else None,
        )
Example #15
def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global train_data, eval_data
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path,
                                                   c.eval_split_size)
    else:
        #eval_data, train_data = load_file_data(c.data_path, c.eval_split_size)
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio)

    # DISTRIBUTED
    # if num_gpus > 1:
    # init_distributed(args.rank, num_gpus, args.group_id,
    #  c.distributed["backend"], c.distributed["url"])

    # setup models
    model_gen = setup_generator(c)
    model_disc = setup_discriminator(c)

    # setup optimizers
    optimizer_gen = RAdam(model_gen.parameters(), lr=c.lr_gen, weight_decay=0)
    optimizer_disc = RAdam(model_disc.parameters(),
                           lr=c.lr_disc,
                           weight_decay=0)

    scaler_G = GradScaler()
    scaler_D = GradScaler()

    # schedulers
    scheduler_gen = None
    scheduler_disc = None
    if 'lr_scheduler_gen' in c:
        scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
        scheduler_gen = scheduler_gen(optimizer_gen,
                                      **c.lr_scheduler_gen_params)
    if 'lr_scheduler_disc' in c:
        scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
        scheduler_disc = scheduler_disc(optimizer_disc,
                                        **c.lr_scheduler_disc_params)

    # setup criterion
    criterion_gen = GeneratorLoss(c)
    criterion_disc = DiscriminatorLoss(c)

    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            print(" > Restoring Generator Model...")
            model_gen.load_state_dict(checkpoint['model'])
            print(" > Restoring Generator Optimizer...")
            optimizer_gen.load_state_dict(checkpoint['optimizer'])
            print(" > Restoring Discriminator Model...")
            model_disc.load_state_dict(checkpoint['model_disc'])
            print(" > Restoring Discriminator Optimizer...")
            optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
            if 'scheduler' in checkpoint:
                print(" > Restoring Generator LR Scheduler...")
                scheduler_gen.load_state_dict(checkpoint['scheduler'])
                # NOTE: Not sure if necessary
                scheduler_gen.optimizer = optimizer_gen
            if 'scheduler_disc' in checkpoint:
                print(" > Restoring Discriminator LR Scheduler...")
                scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
                scheduler_disc.optimizer = optimizer_disc
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model_gen.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            model_gen.load_state_dict(model_dict)

            model_dict = model_disc.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model_disc'], c)
            model_disc.load_state_dict(model_dict)
            del model_dict

        # reset lr if not continuing training.
        for group in optimizer_gen.param_groups:
            group['lr'] = c.lr_gen

        for group in optimizer_disc.param_groups:
            group['lr'] = c.lr_disc

        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model_gen.cuda()
        criterion_gen.cuda()
        model_disc.cuda()
        criterion_disc.cuda()

    # DISTRIBUTED
    # if num_gpus > 1:
    #     model = apply_gradient_allreduce(model)

    num_params = count_parameters(model_gen)
    print(" > Generator has {} parameters".format(num_params), flush=True)
    num_params = count_parameters(model_disc)
    print(" > Discriminator has {} parameters".format(num_params), flush=True)

    if 'best_loss' not in locals():
        best_loss = float('inf')

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(model_gen, criterion_gen, optimizer_gen,
                               model_disc, criterion_disc, optimizer_disc,
                               scaler_G, scaler_D, scheduler_gen,
                               scheduler_disc, ap, global_step, epoch)
        eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc,
                                      criterion_disc, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict[c.target_loss]
        best_loss = save_best_model(target_loss,
                                    best_loss,
                                    model_gen,
                                    optimizer_gen,
                                    scheduler_gen,
                                    model_disc,
                                    optimizer_disc,
                                    scheduler_disc,
                                    global_step,
                                    epoch,
                                    OUT_PATH,
                                    model_losses=eval_avg_loss_dict)
Example #16
def setup(USE_CUDA):
    TEXT = ''
    OUT_PATH = 'tests-audios/'
    # create output path
    os.makedirs(OUT_PATH, exist_ok=True)

    SPEAKER_FILEID = None  # if None use the first embedding from speakers.json

    # model vars
    MODEL_PATH = 'best_model.pth.tar'
    CONFIG_PATH = 'config.json'

    # vocoder vars
    VOCODER_PATH = ''
    VOCODER_CONFIG_PATH = ''

    # load the config
    C = load_config(CONFIG_PATH)
    C.forward_attn_mask = True

    # load the audio processor
    ap = AudioProcessor(**C.audio)

    # if the vocabulary was passed, replace the default
    if 'characters' in C.keys():
        symbols, phonemes = make_symbols(**C.characters)

    speaker_embedding = None
    speaker_embedding_dim = None
    num_speakers = 0
    # load speakers
    if SPEAKER_JSON != '':
        speaker_mapping = json.load(open(SPEAKER_JSON, 'r'))
        num_speakers = len(speaker_mapping)
        if C.use_external_speaker_embedding_file:
            if SPEAKER_FILEID is not None:
                speaker_embedding = speaker_mapping[SPEAKER_FILEID][
                    'embedding']
            else:  # if speaker_fileid is not specified, use the first sample in speakers.json
                chosen_speaker = list(speaker_mapping.keys())[0]
                print(" Speaker: ",
                      chosen_speaker.split('_')[0], 'was chosen automatically',
                      "(this speaker was seen in training)")
                speaker_embedding = speaker_mapping[chosen_speaker][
                    'embedding']
            speaker_embedding_dim = len(speaker_embedding)
            print(speaker_embedding_dim)

    # load the model
    num_chars = len(phonemes) if C.use_phonemes else len(symbols)
    model = setup_model(num_chars, num_speakers, C, speaker_embedding_dim)
    cp = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
    model.load_state_dict(cp['model'])
    model.eval()

    if USE_CUDA:
        model.cuda()

    model.decoder.set_r(cp['r'])

    # load vocoder model
    if VOCODER_PATH != "":
        VC = load_config(VOCODER_CONFIG_PATH)
        vocoder_model = setup_generator(VC)
        vocoder_model.load_state_dict(
            torch.load(VOCODER_PATH, map_location="cpu")["model"])
        vocoder_model.remove_weight_norm()
        if USE_CUDA:
            vocoder_model.cuda()
        vocoder_model.eval()
    else:
        vocoder_model = None
        VC = None
    # synthesize voice
    use_griffin_lim = VOCODER_PATH == ""

    if not C.use_external_speaker_embedding_file:
        if SPEAKER_FILEID is not None and SPEAKER_FILEID.isdigit():
            SPEAKER_FILEID = int(SPEAKER_FILEID)
        else:
            SPEAKER_FILEID = None
    else:
        SPEAKER_FILEID = None

    print("Using vocoder:", vocoder_model)
    return model, vocoder_model, C, ap, SPEAKER_FILEID, speaker_embedding
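The speaker lookup above implies a speakers.json of roughly the following shape when use_external_speaker_embedding_file is enabled (an illustrative sketch; the file IDs are made up and real embeddings are much longer, e.g. 256 floats per clip):

speaker_mapping = {
    "p225_001.wav": {"name": "p225", "embedding": [0.021, -0.113, 0.087]},
    "p226_004.wav": {"name": "p226", "embedding": [0.007, 0.054, -0.190]},
}
first_id = list(speaker_mapping.keys())[0]
print(first_id.split('_')[0], len(speaker_mapping[first_id]["embedding"]))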
Example #17
parser.add_argument('--config_path',
                    type=str,
                    help='Path to config file of torch model.')
parser.add_argument(
    '--output_path',
    type=str,
    help='path to output file including file name to save TF model.')
args = parser.parse_args()

# load model config
config_path = args.config_path
c = load_config(config_path)
num_speakers = 0

# init torch model
model = setup_generator(c)
checkpoint = torch.load(args.torch_model_path,
                        map_location=torch.device('cpu'))
state_dict = checkpoint['model']
model.load_state_dict(state_dict)
model.remove_weight_norm()
state_dict = model.state_dict()

# init tf model
model_tf = setup_tf_generator(c)

common_suffix = '/.ATTRIBUTES/VARIABLE_VALUE'
# get tf_model graph by passing an input
# B x D x T
dummy_input = tf.random.uniform((7, 80, 64), dtype=tf.float32)
mel_pred = model_tf(dummy_input, training=False)
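After the weights are transferred (that part of the script is not shown here), a typical sanity check runs both models on the same input and compares outputs; a hedged sketch (output layouts may need transposing depending on the generator):

import numpy as np
import torch

dummy_np = dummy_input.numpy()  # reuse the TF tensor from above
out_tf = model_tf(dummy_input, training=False).numpy().squeeze()
with torch.no_grad():
    out_torch = model(torch.from_numpy(dummy_np)).numpy().squeeze()
print("max abs diff:", np.abs(out_tf - out_torch).max())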
Example #18
    def load(self):
        # load the config
        C = load_config(self.config_path)
        self.config = C

        # Resolve scale_stats path
        stats_path = C.audio.get("stats_path")
        if stats_path and not os.path.isfile(stats_path):
            # Look for stats next to config
            model_stats_path = os.path.join(os.path.dirname(self.config_path),
                                            "scale_stats.npy")
            if os.path.isfile(model_stats_path):
                # Patch config
                C.audio["stats_path"] = model_stats_path
            else:
                _LOGGER.warning("No scale stats found at %s",
                                C.audio["stats_path"])
                C.audio["stats_path"] = ""

        C.forward_attn_mask = True

        if "gst" not in C.keys():
            # Patch config
            gst = {
                "gst_use_speaker_embedding": False,
                "gst_style_input": None,
                "gst_embedding_dim": 512,
                "gst_num_heads": 4,
                "gst_style_tokens": 10,
            }

            C["gst"] = gst
            setattr(C, "gst", gst)

        if "use_external_speaker_embedding_file" not in C.keys():
            C["use_external_speaker_embedding_file"] = False
            setattr(C, "use_external_speaker_embedding_file", False)

        if "gst_use_speaker_embedding" not in C.gst:
            C.gst["gst_use_speaker_embedding"] = False

        # load the audio processor
        ap = AudioProcessor(**C.audio)
        self.ap = ap

        # if the vocabulary was passed, replace the default
        if "characters" in C.keys():
            symbols, phonemes = make_symbols(**C.characters)
        else:
            from TTS.tts.utils.text.symbols import phonemes, symbols

        speaker_embedding = None
        speaker_embedding_dim = None
        num_speakers = 0

        # load speakers
        if self.speakers_json != "":
            speaker_mapping = json.load(open(self.speakers_json, "r"))
            num_speakers = len(speaker_mapping)
            if C.use_external_speaker_embedding_file:
                if self.speaker_fileid is not None:
                    speaker_embedding = speaker_mapping[
                        self.speaker_fileid]["embedding"]
                else:  # if speaker_fileid is not specified, use the first sample in speakers.json
                    speaker_embedding = speaker_mapping[list(
                        speaker_mapping.keys())[0]]["embedding"]
                speaker_embedding_dim = len(speaker_embedding)

        self.speaker_embedding = speaker_embedding

        # load the model
        num_chars = len(phonemes) if C.use_phonemes else len(symbols)
        model = setup_model(num_chars, num_speakers, C, speaker_embedding_dim)
        cp = torch.load(self.model_path, map_location=torch.device("cpu"))
        model.load_state_dict(cp["model"])
        model.eval()
        if self.use_cuda:
            model.cuda()

        if hasattr(model.decoder, "set_r"):
            model.decoder.set_r(cp["r"])

        self.model = model

        # load vocoder model
        if self.vocoder_path:
            VC = load_config(self.vocoder_config_path)
            # Resolve scale_stats path
            stats_path = VC.audio.get("stats_path")
            if stats_path and not os.path.isfile(stats_path):
                # Look for stats next to config
                vocoder_stats_path = os.path.join(
                    os.path.dirname(self.vocoder_config_path),
                    "scale_stats.npy")
                if os.path.isfile(vocoder_stats_path):
                    # Patch config
                    VC.audio["stats_path"] = vocoder_stats_path
                else:
                    # Try next to TTS config
                    vocoder_stats_path = os.path.join(
                        os.path.dirname(self.config_path), "scale_stats.npy")
                    if os.path.isfile(vocoder_stats_path):
                        # Patch config
                        VC.audio["stats_path"] = vocoder_stats_path
                    else:
                        _LOGGER.warning("No vocoder scale stats found at %s",
                                        VC.audio["stats_path"])
                        VC.audio["stats_path"] = ""

            self.ap_vocoder = AudioProcessor(**VC.audio)

            vocoder_model = setup_generator(VC)
            vocoder_model.load_state_dict(
                torch.load(self.vocoder_path, map_location="cpu")["model"])
            vocoder_model.remove_weight_norm()
            vocoder_model.inference_padding = 0
            if self.use_cuda:
                vocoder_model.cuda()
            vocoder_model.eval()

            if hasattr(vocoder_model, "compute_noise_level"):
                noise_schedule_path = os.path.join(
                    os.path.dirname(self.vocoder_path), "noise_schedule.npy")
                if os.path.isfile(noise_schedule_path):
                    _LOGGER.debug("Loading noise schedule from %s",
                                  noise_schedule_path)
                    beta = np.load(noise_schedule_path,
                                   allow_pickle=True).tolist()["beta"]
                else:
                    # fall back to a default schedule when none was computed with tune_wavegrad
                    _LOGGER.debug("Using default noise schedule")
                    beta = np.linspace(1e-6, 0.01, self.wavegrad_iters)

                vocoder_model.compute_noise_level(beta)
        else:
            vocoder_model = None
            VC = None
            self.ap_vocoder = None

        self.vocoder_model = vocoder_model
        self.vocoder_config = VC

        # synthesize voice
        self.use_griffin_lim = self.vocoder_model is None

        if not C.use_external_speaker_embedding_file:
            if self.speaker_fileid and self.speaker_fileid.isdigit():
                self.speaker_fileid = int(self.speaker_fileid)
            else:
                self.speaker_fileid = None
        else:
            self.speaker_fileid = None

        if (self.gst_style is None) and ("gst" in C.keys()):
            gst_style = C.gst.get("gst_style_input", None)
        else:
            # if gst_style is a JSON dict string, parse it; otherwise use it as a plain string
            try:
                gst_style = json.loads(self.gst_style)
                if max(map(int,
                           gst_style.keys())) >= C.gst["gst_style_tokens"]:
                    raise RuntimeError(
                        "The highest value of the gst_style dictionary key must be less than the number of GST Tokens, \n Highest dictionary key value: {} \n Number of GST tokens: {}"
                        .format(max(map(int, gst_style.keys())),
                                C.gst["gst_style_tokens"]))
            except ValueError:
                gst_style = self.gst_style

        self.gst_style = gst_style

        # Pre-load language
        if C.get("phoneme_backend") == "gruut":
            load_gruut_language(C["phoneme_language"])

        # Compute scale factors in case TTS/vocoder sample rates differ
        # See: https://github.com/mozilla/TTS/issues/520
        self.scale_factors = None

        if self.ap_vocoder and (self.ap.sample_rate !=
                                self.ap_vocoder.sample_rate):
            self.scale_factors = (1, self.ap_vocoder.sample_rate /
                                  self.ap.sample_rate)
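For reference, a gst_style value taken through the json.loads branch above is a token-index-to-weight mapping; an illustrative (made-up) example that passes the validation check:

import json

gst_style = json.loads('{"0": 0.25, "3": -0.1, "7": 0.4}')
# keys must stay below C.gst["gst_style_tokens"] (10 in the patched default)
print(max(map(int, gst_style.keys())))  # 7 < 10, so no RuntimeError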
Example #19
    # load the model
    num_chars = len(phonemes) if C.use_phonemes else len(symbols)
    model = setup_model(num_chars, num_speakers, C, speaker_embedding_dim)
    cp = torch.load(args.model_path, map_location=torch.device('cpu'))
    model.load_state_dict(cp['model'])
    model.eval()
    if args.use_cuda:
        model.cuda()
    if is_tacotron(C):
        model.decoder.set_r(cp['r'])

    # load vocoder model
    if args.vocoder_path != "":
        VC = load_config(args.vocoder_config_path)
        vocoder_model = setup_generator(VC)
        vocoder_model.load_state_dict(
            torch.load(args.vocoder_path, map_location="cpu")["model"])
        vocoder_model.remove_weight_norm()
        if args.use_cuda:
            vocoder_model.cuda()
        vocoder_model.eval()
    else:
        vocoder_model = None
        VC = None

    # synthesize voice
    use_griffin_lim = args.vocoder_path == ""
    print(" > Text: {}".format(args.text))

    if not C.use_external_speaker_embedding_file:
Example #20
                          conv_pad=config.conv_pad,
                          is_training=True,
                          return_segments=False,
                          use_noise_augment=False,
                          use_cache=False,
                          verbose=True)
loader = DataLoader(dataset,
                    batch_size=1,
                    shuffle=False,
                    collate_fn=dataset.collate_full_clips,
                    drop_last=False,
                    num_workers=config.num_loader_workers,
                    pin_memory=False)

# setup the model
model = setup_generator(config)
if args.use_cuda:
    model.cuda()

# setup optimization parameters
base_values = sorted(10 * np.random.uniform(size=args.search_depth))
print(base_values)
exponents = 10**np.linspace(-6, -1, num=args.num_iter)
best_error = float('inf')
best_schedule = None
total_search_iter = len(base_values)**args.num_iter
for base in tqdm(cartesian_product(base_values, repeat=args.num_iter),
                 total=total_search_iter):
    beta = exponents * base
    model.compute_noise_level(beta)
    for data in loader:
Example #21
def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global train_data, eval_data
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio)

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])

    # setup models
    model = setup_generator(c)

    # scaler for mixed_precision
    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None

    # setup optimizers
    optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)

    # schedulers
    scheduler = None
    if 'lr_scheduler' in c:
        scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
        scheduler = scheduler(optimizer, **c.lr_scheduler_params)

    # setup criterion
    criterion = torch.nn.L1Loss().cuda()

    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            print(" > Restoring Model...")
            model.load_state_dict(checkpoint['model'])
            print(" > Restoring Optimizer...")
            optimizer.load_state_dict(checkpoint['optimizer'])
            if 'scheduler' in checkpoint:
                print(" > Restoring LR Scheduler...")
                scheduler.load_state_dict(checkpoint['scheduler'])
                # NOTE: Not sure if necessary
                scheduler.optimizer = optimizer
            if "scaler" in checkpoint and c.mixed_precision:
                print(" > Restoring AMP Scaler...")
                scaler.load_state_dict(checkpoint["scaler"])
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            model.load_state_dict(model_dict)
            del model_dict

        # reset lr if not continuing training.
        for group in optimizer.param_groups:
            group['lr'] = c.lr

        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model.cuda()
        criterion.cuda()

    # DISTRIBUTED
    if num_gpus > 1:
        model = DDP_th(model, device_ids=[args.rank])

    num_params = count_parameters(model)
    print(" > WaveGrad has {} parameters".format(num_params), flush=True)

    if 'best_loss' not in locals():
        best_loss = float('inf')

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(model, criterion, optimizer,
                               scheduler, scaler, ap, global_step,
                               epoch)
        eval_avg_loss_dict = evaluate(model, criterion, ap,
                                      global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict[c.target_loss]
        best_loss = save_best_model(target_loss,
                                    best_loss,
                                    model,
                                    optimizer,
                                    scheduler,
                                    None,
                                    None,
                                    None,
                                    global_step,
                                    epoch,
                                    OUT_PATH,
                                    model_losses=eval_avg_loss_dict,
                                    scaler=scaler.state_dict() if c.mixed_precision else None)
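The scaler threaded through train() above follows the standard PyTorch AMP recipe; a minimal, generic sketch of one mixed-precision step (not this repo's actual train loop):

import torch

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.L1Loss().cuda()
scaler = torch.cuda.amp.GradScaler()

x, y = torch.rand(8, 10).cuda(), torch.rand(8, 1).cuda()
optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = criterion(model(x), y)
scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
scaler.step(optimizer)         # unscales grads, skips step on inf/nan
scaler.update()                # adjust the scale for the next iteration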
Example #22
def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global train_data, eval_data
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path,
                                                   c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio.to_dict())

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])

    # setup models
    model_gen = setup_generator(c)
    model_disc = setup_discriminator(c)

    # setup criterion
    criterion_gen = GeneratorLoss(c)
    criterion_disc = DiscriminatorLoss(c)

    if use_cuda:
        model_gen.cuda()
        criterion_gen.cuda()
        model_disc.cuda()
        criterion_disc.cuda()

    # setup optimizers
    # TODO: allow loading custom optimizers
    optimizer_gen = None
    optimizer_disc = None
    optimizer_gen = getattr(torch.optim, c.optimizer)
    optimizer_gen = optimizer_gen(model_gen.parameters(),
                                  lr=c.lr_gen,
                                  **c.optimizer_params)
    optimizer_disc = getattr(torch.optim, c.optimizer)

    if c.discriminator_model == "hifigan_discriminator":
        optimizer_disc = optimizer_disc(
            itertools.chain(model_disc.msd.parameters(),
                            model_disc.mpd.parameters()),
            lr=c.lr_disc,
            **c.optimizer_params,
        )
    else:
        optimizer_disc = optimizer_disc(model_disc.parameters(),
                                        lr=c.lr_disc,
                                        **c.optimizer_params)

    # schedulers
    scheduler_gen = None
    scheduler_disc = None
    if "lr_scheduler_gen" in c:
        scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
        scheduler_gen = scheduler_gen(optimizer_gen,
                                      **c.lr_scheduler_gen_params)
    if "lr_scheduler_disc" in c:
        scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
        scheduler_disc = scheduler_disc(optimizer_disc,
                                        **c.lr_scheduler_disc_params)

    if args.restore_path:
        print(f" > Restoring from {os.path.basename(args.restore_path)}...")
        checkpoint = torch.load(args.restore_path, map_location="cpu")
        try:
            print(" > Restoring Generator Model...")
            model_gen.load_state_dict(checkpoint["model"])
            print(" > Restoring Generator Optimizer...")
            optimizer_gen.load_state_dict(checkpoint["optimizer"])
            print(" > Restoring Discriminator Model...")
            model_disc.load_state_dict(checkpoint["model_disc"])
            print(" > Restoring Discriminator Optimizer...")
            optimizer_disc.load_state_dict(checkpoint["optimizer_disc"])
            # restore schedulers when continuing training.
            if args.continue_path != "":
                if "scheduler" in checkpoint and scheduler_gen is not None:
                    print(" > Restoring Generator LR Scheduler...")
                    scheduler_gen.load_state_dict(checkpoint["scheduler"])
                    # NOTE: Not sure if necessary
                    scheduler_gen.optimizer = optimizer_gen
                if "scheduler_disc" in checkpoint and scheduler_disc is not None:
                    print(" > Restoring Discriminator LR Scheduler...")
                    scheduler_disc.load_state_dict(
                        checkpoint["scheduler_disc"])
                    scheduler_disc.optimizer = optimizer_disc
                    if c.lr_scheduler_disc == "ExponentialLR":
                        scheduler_disc.last_epoch = checkpoint["epoch"]
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model_gen.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
            model_gen.load_state_dict(model_dict)

            model_dict = model_disc.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model_disc"], c)
            model_disc.load_state_dict(model_dict)
            del model_dict

        # reset lr if not continuing training.
        if args.continue_path == "":
            for group in optimizer_gen.param_groups:
                group["lr"] = c.lr_gen

            for group in optimizer_disc.param_groups:
                group["lr"] = c.lr_disc

        print(f" > Model restored from step {checkpoint['step']:d}",
              flush=True)
        args.restore_step = checkpoint["step"]
    else:
        args.restore_step = 0

    # DISTRIBUTED
    if num_gpus > 1:
        model_gen = DDP_th(model_gen, device_ids=[args.rank])
        model_disc = DDP_th(model_disc, device_ids=[args.rank])

    num_params = count_parameters(model_gen)
    print(" > Generator has {} parameters".format(num_params), flush=True)
    num_params = count_parameters(model_disc)
    print(" > Discriminator has {} parameters".format(num_params), flush=True)

    if args.restore_step == 0 or not args.best_path:
        best_loss = float("inf")
        print(" > Starting with inf best loss.")
    else:
        print(" > Restoring best loss from "
              f"{os.path.basename(args.best_path)} ...")
        best_loss = torch.load(args.best_path,
                               map_location="cpu")["model_loss"]
        print(f" > Starting with best loss of {best_loss}.")
    keep_all_best = c.get("keep_all_best", False)
    keep_after = c.get("keep_after", 10000)  # ignored unless keep_all_best is True

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(
            model_gen,
            criterion_gen,
            optimizer_gen,
            model_disc,
            criterion_disc,
            optimizer_disc,
            scheduler_gen,
            scheduler_disc,
            ap,
            global_step,
            epoch,
        )
        eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc,
                                      criterion_disc, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict[c.target_loss]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model_gen,
            optimizer_gen,
            scheduler_gen,
            model_disc,
            optimizer_disc,
            scheduler_disc,
            global_step,
            epoch,
            OUT_PATH,
            keep_all_best=keep_all_best,
            keep_after=keep_after,
            model_losses=eval_avg_loss_dict,
        )