Example #1
def main():
    """Run `tts` model training directly by a `config.json` file."""
    # init trainer args
    train_args = TrainVocoderArgs()
    parser = train_args.init_argparse(arg_prefix="")

    # override trainer args from command-line args
    args, config_overrides = parser.parse_known_args()
    train_args.parse_args(args)

    # load config.json and register
    if args.config_path or args.continue_path:
        if args.config_path:
            # init from a file
            config = load_config(args.config_path)
            if len(config_overrides) > 0:
                config.parse_known_args(config_overrides, relaxed_parser=True)
        elif args.continue_path:
            # continue from a prev experiment
            config = load_config(
                os.path.join(args.continue_path, "config.json"))
            if len(config_overrides) > 0:
                config.parse_known_args(config_overrides, relaxed_parser=True)
    else:
        # init from console args
        from TTS.config.shared_configs import BaseTrainingConfig  # pylint: disable=import-outside-toplevel

        config_base = BaseTrainingConfig()
        config_base.parse_known_args(config_overrides)
        config = register_config(config_base.model)()

    # load training samples
    if "feature_path" in config and config.feature_path:
        # load pre-computed features
        print(f" > Loading features from: {config.feature_path}")
        eval_samples, train_samples = load_wav_feat_data(
            config.data_path, config.feature_path, config.eval_split_size)
    else:
        # load raw wav files
        eval_samples, train_samples = load_wav_data(config.data_path,
                                                    config.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**config.audio)

    # init the model from config
    model = setup_model(config)

    # init the trainer and 🚀
    trainer = Trainer(
        train_args,
        config,
        config.output_path,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
        training_assets={"audio_processor": ap},
        parse_command_line_args=False,
    )
    trainer.fit()
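
# For orientation: every example on this page revolves around load_wav_data /
# load_wav_feat_data. A minimal, hypothetical sketch of the call on its own,
# assuming the import path used by the TTS vocoder scripts and only what
# Example #1 already relies on: the function takes a wav directory and an eval
# split size and returns (eval_samples, train_samples). The path is a placeholder.
from TTS.vocoder.datasets.preprocess import load_wav_data

data_path = "/data/LJSpeech-1.1/wavs/"  # placeholder dataset path
eval_samples, train_samples = load_wav_data(data_path, 10)  # hold out 10 files for eval
print(f"eval: {len(eval_samples)} files, train: {len(train_samples)} files")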
Example #2
def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global train_data, eval_data
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio)

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])

    # setup models
    model = setup_generator(c)

    # scaler for mixed_precision
    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None

    # setup optimizers
    optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)

    # schedulers
    scheduler = None
    if 'lr_scheduler' in c:
        scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
        scheduler = scheduler(optimizer, **c.lr_scheduler_params)

    # setup criterion
    criterion = torch.nn.L1Loss().cuda()

    if args.restore_path:
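        # try a strict load of the full training state; fall back to a partial init below if keys or shapes differ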
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            print(" > Restoring Model...")
            model.load_state_dict(checkpoint['model'])
            print(" > Restoring Optimizer...")
            optimizer.load_state_dict(checkpoint['optimizer'])
            if 'scheduler' in checkpoint:
                print(" > Restoring LR Scheduler...")
                scheduler.load_state_dict(checkpoint['scheduler'])
                # NOTE: Not sure if necessary
                scheduler.optimizer = optimizer
            if "scaler" in checkpoint and c.mixed_precision:
                print(" > Restoring AMP Scaler...")
                scaler.load_state_dict(checkpoint["scaler"])
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            model.load_state_dict(model_dict)
            del model_dict

        # reset lr if not continuing training.
        for group in optimizer.param_groups:
            group['lr'] = c.lr

        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model.cuda()
        criterion.cuda()

    # DISTRIBUTED
    if num_gpus > 1:
        model = DDP_th(model, device_ids=[args.rank])

    num_params = count_parameters(model)
    print(" > WaveGrad has {} parameters".format(num_params), flush=True)

    if 'best_loss' not in locals():
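        # 'best_loss' is never bound in this scope at this point, so this always starts from inf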
        best_loss = float('inf')

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(model, criterion, optimizer,
                               scheduler, scaler, ap, global_step,
                               epoch)
        eval_avg_loss_dict = evaluate(model, criterion, ap,
                                      global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict[c.target_loss]
        best_loss = save_best_model(target_loss,
                                    best_loss,
                                    model,
                                    optimizer,
                                    scheduler,
                                    None,
                                    None,
                                    None,
                                    global_step,
                                    epoch,
                                    OUT_PATH,
                                    model_losses=eval_avg_loss_dict,
                                    scaler=scaler.state_dict() if c.mixed_precision else None)
Example #3
def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global train_data, eval_data
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path,
                                                   c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio)

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])

    # setup models
    model_gen = setup_generator(c)
    model_disc = setup_discriminator(c)

    # setup optimizers
    optimizer_gen = RAdam(model_gen.parameters(), lr=c.lr_gen, weight_decay=0)
    optimizer_disc = RAdam(model_disc.parameters(),
                           lr=c.lr_disc,
                           weight_decay=0)

    # schedulers
    scheduler_gen = None
    scheduler_disc = None
    if 'lr_scheduler_gen' in c:
        scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
        scheduler_gen = scheduler_gen(optimizer_gen,
                                      **c.lr_scheduler_gen_params)
    if 'lr_scheduler_disc' in c:
        scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
        scheduler_disc = scheduler_disc(optimizer_disc,
                                        **c.lr_scheduler_disc_params)

    # setup criterion
    criterion_gen = GeneratorLoss(c)
    criterion_disc = DiscriminatorLoss(c)

    if args.restore_path:
        print(f" > Restoring from {os.path.basename(args.restore_path)}...")
        checkpoint = torch.load(args.restore_path, map_location='cpu')
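        # the checkpoint is expected to hold 'model', 'optimizer', 'model_disc', 'optimizer_disc' and optional scheduler states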
        try:
            print(" > Restoring Generator Model...")
            model_gen.load_state_dict(checkpoint['model'])
            print(" > Restoring Generator Optimizer...")
            optimizer_gen.load_state_dict(checkpoint['optimizer'])
            print(" > Restoring Discriminator Model...")
            model_disc.load_state_dict(checkpoint['model_disc'])
            print(" > Restoring Discriminator Optimizer...")
            optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
            if 'scheduler' in checkpoint:
                print(" > Restoring Generator LR Scheduler...")
                scheduler_gen.load_state_dict(checkpoint['scheduler'])
                # NOTE: Not sure if necessary
                scheduler_gen.optimizer = optimizer_gen
            if 'scheduler_disc' in checkpoint:
                print(" > Restoring Discriminator LR Scheduler...")
                scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
                scheduler_disc.optimizer = optimizer_disc
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model_gen.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            model_gen.load_state_dict(model_dict)

            model_dict = model_disc.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model_disc'], c)
            model_disc.load_state_dict(model_dict)
            del model_dict

        # reset lr if not continuing training.
        for group in optimizer_gen.param_groups:
            group['lr'] = c.lr_gen

        for group in optimizer_disc.param_groups:
            group['lr'] = c.lr_disc

        print(f" > Model restored from step {checkpoint['step']:d}",
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model_gen.cuda()
        criterion_gen.cuda()
        model_disc.cuda()
        criterion_disc.cuda()

    # DISTRIBUTED
    if num_gpus > 1:
        model_gen = DDP_th(model_gen, device_ids=[args.rank])
        model_disc = DDP_th(model_disc, device_ids=[args.rank])

    num_params = count_parameters(model_gen)
    print(" > Generator has {} parameters".format(num_params), flush=True)
    num_params = count_parameters(model_disc)
    print(" > Discriminator has {} parameters".format(num_params), flush=True)

    if args.restore_step == 0 or not args.best_path:
        best_loss = float('inf')
        print(" > Starting with inf best loss.")
    else:
        print(" > Restoring best loss from "
              f"{os.path.basename(args.best_path)} ...")
        best_loss = torch.load(args.best_path,
                               map_location='cpu')['model_loss']
        print(f" > Starting with best loss of {best_loss}.")
    keep_all_best = c.get('keep_all_best', False)
    keep_after = c.get('keep_after', 10000)  # has no effect unless keep_all_best is True

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(model_gen, criterion_gen, optimizer_gen,
                               model_disc, criterion_disc, optimizer_disc,
                               scheduler_gen, scheduler_disc, ap, global_step,
                               epoch)
        eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc,
                                      criterion_disc, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict[c.target_loss]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model_gen,
            optimizer_gen,
            scheduler_gen,
            model_disc,
            optimizer_disc,
            scheduler_disc,
            global_step,
            epoch,
            OUT_PATH,
            keep_all_best=keep_all_best,
            keep_after=keep_after,
            model_losses=eval_avg_loss_dict,
        )
Example #4
    use_noise_augment=True,
    eval_split_size=10,
    print_step=25,
    print_eval=False,
    mixed_precision=False,
    lr_gen=1e-4,
    lr_disc=1e-4,
    data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
    output_path=output_path,
)

# init audio processor
ap = AudioProcessor(**config.audio.to_dict())

# load training samples
eval_samples, train_samples = load_wav_data(config.data_path,
                                            config.eval_split_size)

# init model
model = GAN(config)

# init the trainer and 🚀
trainer = Trainer(
    TrainingArgs(),
    config,
    output_path,
    model=model,
    train_samples=train_samples,
    eval_samples=eval_samples,
    training_assets={"audio_processor": ap},
)
trainer.fit()
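
# Example #4 begins mid-way through its config constructor. For illustration
# only, a hypothetical way such a config could be built for this Trainer
# recipe; HifiganConfig is an assumed GAN vocoder config class and is not
# taken from the original snippet.
import os

from TTS.vocoder.configs import HifiganConfig  # assumption: any GAN vocoder config would do

output_path = os.path.dirname(os.path.abspath(__file__))
config = HifiganConfig(
    use_noise_augment=True,
    eval_split_size=10,
    print_step=25,
    print_eval=False,
    mixed_precision=False,
    lr_gen=1e-4,
    lr_disc=1e-4,
    data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
    output_path=output_path,
)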
Example #5
def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global train_data, eval_data

    # setup audio processor
    ap = AudioProcessor(**c.audio)

    # print(f" > Loading wavs from: {c.data_path}")
    # if c.feature_path is not None:
    #     print(f" > Loading features from: {c.feature_path}")
    #     eval_data, train_data = load_wav_feat_data(
    #         c.data_path, c.feature_path, c.eval_split_size
    #     )
    # else:
    #     mel_feat_path = os.path.join(OUT_PATH, "mel")
    #     feat_data = find_feat_files(mel_feat_path)
    #     if feat_data:
    #         print(f" > Loading features from: {mel_feat_path}")
    #         eval_data, train_data = load_wav_feat_data(
    #             c.data_path, mel_feat_path, c.eval_split_size
    #         )
    #     else:
    #         print(" > No feature data found. Preprocessing...")
    #         # preprocessing feature data from given wav files
    #         preprocess_wav_files(OUT_PATH, CONFIG, ap)
    #         eval_data, train_data = load_wav_feat_data(
    #             c.data_path, mel_feat_path, c.eval_split_size
    #         )

    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path,
                                                   c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
    # setup model
    model_wavernn = setup_wavernn(c)

    # setup amp scaler
    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None

    # pick the training criterion to match the model's output mode
    if c.mode == "mold":
        criterion = discretized_mix_logistic_loss
    elif c.mode == "gauss":
        criterion = gaussian_loss
    elif isinstance(c.mode, int):
        criterion = torch.nn.CrossEntropyLoss()

    if use_cuda:
        model_wavernn.cuda()
        if isinstance(c.mode, int):
            criterion.cuda()

    optimizer = RAdam(model_wavernn.parameters(), lr=c.lr, weight_decay=0)

    scheduler = None
    if "lr_scheduler" in c:
        scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
        scheduler = scheduler(optimizer, **c.lr_scheduler_params)
    # slow start for the first 5 epochs
    # lr_lambda = lambda epoch: min(epoch / c.warmup_steps, 1)
    # scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # restore any checkpoint
    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location="cpu")
        try:
            print(" > Restoring Model...")
            model_wavernn.load_state_dict(checkpoint["model"])
            print(" > Restoring Optimizer...")
            optimizer.load_state_dict(checkpoint["optimizer"])
            if "scheduler" in checkpoint:
                print(" > Restoring Generator LR Scheduler...")
                scheduler.load_state_dict(checkpoint["scheduler"])
                scheduler.optimizer = optimizer
            if "scaler" in checkpoint and c.mixed_precision:
                print(" > Restoring AMP Scaler...")
                scaler.load_state_dict(checkpoint["scaler"])
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model_wavernn.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
            model_wavernn.load_state_dict(model_dict)

        print(" > Model restored from step %d" % checkpoint["step"],
              flush=True)
        args.restore_step = checkpoint["step"]
    else:
        args.restore_step = 0

    # DISTRIBUTED
    # if num_gpus > 1:
    #     model = apply_gradient_allreduce(model)

    num_parameters = count_parameters(model_wavernn)
    print(" > Model has {} parameters".format(num_parameters), flush=True)

    if "best_loss" not in locals():
        best_loss = float("inf")

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(model_wavernn, optimizer, criterion, scheduler,
                               scaler, ap, global_step, epoch)
        eval_avg_loss_dict = evaluate(model_wavernn, criterion, ap,
                                      global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict["avg_model_loss"]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model_wavernn,
            optimizer,
            scheduler,
            None,
            None,
            None,
            global_step,
            epoch,
            OUT_PATH,
            model_losses=eval_avg_loss_dict,
            scaler=scaler.state_dict() if c.mixed_precision else None)
Example #6
def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments,
                     use_noise_augment, use_cache, num_workers):
    """Run dataloader with given parameters and check conditions."""
    ap = AudioProcessor(**C.audio)
    _, train_items = load_wav_data(test_data_path, 10)
    dataset = GANDataset(ap,
                         train_items,
                         seq_len=seq_len,
                         hop_len=hop_len,
                         pad_short=2000,
                         conv_pad=conv_pad,
                         return_segments=return_segments,
                         use_noise_augment=use_noise_augment,
                         use_cache=use_cache)
    loader = DataLoader(dataset=dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=num_workers,
                        pin_memory=True,
                        drop_last=True)

    max_iter = 10
    count_iter = 0

    # return random segments or return the whole audio
    if return_segments:
        for item1, _ in loader:
            feat1, wav1 = item1
            # feat2, wav2 = item2
            expected_feat_shape = (batch_size, ap.num_mels,
                                   seq_len // hop_len + conv_pad * 2)

            # check shapes
            assert np.all(feat1.shape == expected_feat_shape
                          ), f" [!] {feat1.shape} vs {expected_feat_shape}"
            assert (feat1.shape[2] - conv_pad * 2) * hop_len == wav1.shape[2]

            # check feature vs audio match
            if not use_noise_augment:
                for idx in range(batch_size):
                    audio = wav1[idx].squeeze()
                    feat = feat1[idx]
                    mel = ap.melspectrogram(audio)
                    # the first 2 and the last frames are skipped due to the
                    # padding applied in the spectrogram computation.
                    assert (feat - mel[:, :feat1.shape[-1]])[:, 2:-1].sum(
                    ) == 0, f' [!] {(feat - mel[:, :feat1.shape[-1]])[:, 2:-1].sum()}'

            count_iter += 1
            # if count_iter == max_iter:
            #     break
    else:
        for item in loader:
            feat, wav = item
            expected_feat_shape = (batch_size, ap.num_mels,
                                   (wav.shape[-1] // hop_len) + (conv_pad * 2))
            assert np.all(feat.shape == expected_feat_shape
                          ), f" [!] {feat.shape} vs {expected_feat_shape}"
            assert (feat.shape[2] - conv_pad * 2) * hop_len == wav.shape[2]
            count_iter += 1
            if count_iter == max_iter:
                break
Example #7
    '--search_depth',
    type=int,
    default=3,
    help=
    'Search granularity. Increasing this increases the run-time exponentially.'
)

# load config
args = parser.parse_args()
config = load_config(args.config_path)

# setup audio processor
ap = AudioProcessor(**config.audio)

# load dataset
_, train_data = load_wav_data(args.data_path, 0)
train_data = train_data[:args.num_samples]
dataset = WaveGradDataset(ap=ap,
                          items=train_data,
                          seq_len=-1,
                          hop_len=ap.hop_length,
                          pad_short=config.pad_short,
                          conv_pad=config.conv_pad,
                          is_training=True,
                          return_segments=False,
                          use_noise_augment=False,
                          use_cache=False,
                          verbose=True)
loader = DataLoader(dataset,
                    batch_size=1,
                    shuffle=False,

Example #8
def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_pairs,
                     return_segments, use_noise_augment, use_cache,
                     num_workers):
    """Run dataloader with given parameters and check conditions"""
    ap = AudioProcessor(**C.audio)
    _, train_items = load_wav_data(test_data_path, 10)
    dataset = GANDataset(
        ap,
        train_items,
        seq_len=seq_len,
        hop_len=hop_len,
        pad_short=2000,
        conv_pad=conv_pad,
        return_pairs=return_pairs,
        return_segments=return_segments,
        use_noise_augment=use_noise_augment,
        use_cache=use_cache,
    )
    loader = DataLoader(dataset=dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=num_workers,
                        pin_memory=True,
                        drop_last=True)

    max_iter = 10
    count_iter = 0

    def check_item(feat, wav):
        """Pass a single pair of features and waveform"""
        feat = feat.numpy()
        wav = wav.numpy()
        expected_feat_shape = (batch_size, ap.num_mels,
                               seq_len // hop_len + conv_pad * 2)

        # check shapes
        assert np.all(feat.shape == expected_feat_shape
                      ), f" [!] {feat.shape} vs {expected_feat_shape}"
        assert (feat.shape[2] - conv_pad * 2) * hop_len == wav.shape[2]

        # check feature vs audio match
        if not use_noise_augment:
            for idx in range(batch_size):
                audio = wav[idx].squeeze()
                feat = feat[idx]
                mel = ap.melspectrogram(audio)
                # the first 2 and the last 2 frames are skipped due to the padding
                # differences in stft
                max_diff = abs((feat - mel[:, :feat.shape[-1]])[:, 2:-2]).max()
                assert max_diff <= 1e-6, f" [!] {max_diff}"

    # return random segments or return the whole audio
    if return_segments:
        if return_pairs:
            for item1, item2 in loader:
                feat1, wav1 = item1
                feat2, wav2 = item2
                check_item(feat1, wav1)
                check_item(feat2, wav2)
                count_iter += 1
        else:
            for item1 in loader:
                feat1, wav1 = item1
                check_item(feat1, wav1)
                count_iter += 1
    else:
        for item in loader:
            feat, wav = item
            expected_feat_shape = (batch_size, ap.num_mels,
                                   (wav.shape[-1] // hop_len) + (conv_pad * 2))
            assert np.all(feat.shape == expected_feat_shape
                          ), f" [!] {feat.shape} vs {expected_feat_shape}"
            assert (feat.shape[2] - conv_pad * 2) * hop_len == wav.shape[2]
            count_iter += 1
            if count_iter == max_iter:
                break
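
# A hypothetical invocation of the test helper above. The parameter values are
# illustrative assumptions (not taken from the original test suite) and they
# presuppose that C.audio and test_data_path are defined by the surrounding module.
gan_dataset_case(
    batch_size=1,
    seq_len=2048,
    hop_len=256,
    conv_pad=0,
    return_pairs=True,
    return_segments=True,
    use_noise_augment=False,
    use_cache=False,
    num_workers=0,
)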
Example #9
def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global train_data, eval_data
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path,
                                                   c.eval_split_size)
    else:
        #eval_data, train_data = load_file_data(c.data_path, c.eval_split_size)
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio)

    # DISTRIBUTED
    # if num_gpus > 1:
    # init_distributed(args.rank, num_gpus, args.group_id,
    #  c.distributed["backend"], c.distributed["url"])

    # setup models
    model_gen = setup_generator(c)
    model_disc = setup_discriminator(c)

    # setup optimizers
    optimizer_gen = RAdam(model_gen.parameters(), lr=c.lr_gen, weight_decay=0)
    optimizer_disc = RAdam(model_disc.parameters(),
                           lr=c.lr_disc,
                           weight_decay=0)

    scaler_G = GradScaler()
    scaler_D = GradScaler()

    # schedulers
    scheduler_gen = None
    scheduler_disc = None
    if 'lr_scheduler_gen' in c:
        scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
        scheduler_gen = scheduler_gen(optimizer_gen,
                                      **c.lr_scheduler_gen_params)
    if 'lr_scheduler_disc' in c:
        scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
        scheduler_disc = scheduler_disc(optimizer_disc,
                                        **c.lr_scheduler_disc_params)

    # setup criterion
    criterion_gen = GeneratorLoss(c)
    criterion_disc = DiscriminatorLoss(c)

    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            print(" > Restoring Generator Model...")
            model_gen.load_state_dict(checkpoint['model'])
            print(" > Restoring Generator Optimizer...")
            optimizer_gen.load_state_dict(checkpoint['optimizer'])
            print(" > Restoring Discriminator Model...")
            model_disc.load_state_dict(checkpoint['model_disc'])
            print(" > Restoring Discriminator Optimizer...")
            optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
            if 'scheduler' in checkpoint:
                print(" > Restoring Generator LR Scheduler...")
                scheduler_gen.load_state_dict(checkpoint['scheduler'])
                # NOTE: Not sure if necessary
                scheduler_gen.optimizer = optimizer_gen
            if 'scheduler_disc' in checkpoint:
                print(" > Restoring Discriminator LR Scheduler...")
                scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
                scheduler_disc.optimizer = optimizer_disc
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model_gen.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            model_gen.load_state_dict(model_dict)

            model_dict = model_disc.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model_disc'], c)
            model_disc.load_state_dict(model_dict)
            del model_dict

        # reset lr if not continuing training.
        for group in optimizer_gen.param_groups:
            group['lr'] = c.lr_gen

        for group in optimizer_disc.param_groups:
            group['lr'] = c.lr_disc

        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model_gen.cuda()
        criterion_gen.cuda()
        model_disc.cuda()
        criterion_disc.cuda()

    # DISTRIBUTED
    # if num_gpus > 1:
    #     model = apply_gradient_allreduce(model)

    num_params = count_parameters(model_gen)
    print(" > Generator has {} parameters".format(num_params), flush=True)
    num_params = count_parameters(model_disc)
    print(" > Discriminator has {} parameters".format(num_params), flush=True)

    if 'best_loss' not in locals():
        best_loss = float('inf')

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(model_gen, criterion_gen, optimizer_gen,
                               model_disc, criterion_disc, optimizer_disc,
                               scaler_G, scaler_D, scheduler_gen,
                               scheduler_disc, ap, global_step, epoch)
        eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc,
                                      criterion_disc, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict[c.target_loss]
        best_loss = save_best_model(target_loss,
                                    best_loss,
                                    model_gen,
                                    optimizer_gen,
                                    scheduler_gen,
                                    model_disc,
                                    optimizer_disc,
                                    scheduler_disc,
                                    global_step,
                                    epoch,
                                    OUT_PATH,
                                    model_losses=eval_avg_loss_dict)
Example #10
def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global train_data, eval_data
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path,
                                                   c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio.to_dict())

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])

    # setup models
    model_gen = setup_generator(c)
    model_disc = setup_discriminator(c)

    # setup criterion
    criterion_gen = GeneratorLoss(c)
    criterion_disc = DiscriminatorLoss(c)

    if use_cuda:
        model_gen.cuda()
        criterion_gen.cuda()
        model_disc.cuda()
        criterion_disc.cuda()

    # setup optimizers
    # TODO: allow loading custom optimizers
    optimizer_gen = None
    optimizer_disc = None
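    # resolve the optimizer class by its config name (e.g. "AdamW") before instantiating it for each model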
    optimizer_gen = getattr(torch.optim, c.optimizer)
    optimizer_gen = optimizer_gen(model_gen.parameters(),
                                  lr=c.lr_gen,
                                  **c.optimizer_params)
    optimizer_disc = getattr(torch.optim, c.optimizer)

    if c.discriminator_model == "hifigan_discriminator":
        optimizer_disc = optimizer_disc(
            itertools.chain(model_disc.msd.parameters(),
                            model_disc.mpd.parameters()),
            lr=c.lr_disc,
            **c.optimizer_params,
        )
    else:
        optimizer_disc = optimizer_disc(model_disc.parameters(),
                                        lr=c.lr_disc,
                                        **c.optimizer_params)

    # schedulers
    scheduler_gen = None
    scheduler_disc = None
    if "lr_scheduler_gen" in c:
        scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
        scheduler_gen = scheduler_gen(optimizer_gen,
                                      **c.lr_scheduler_gen_params)
    if "lr_scheduler_disc" in c:
        scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
        scheduler_disc = scheduler_disc(optimizer_disc,
                                        **c.lr_scheduler_disc_params)

    if args.restore_path:
        print(f" > Restoring from {os.path.basename(args.restore_path)}...")
        checkpoint = torch.load(args.restore_path, map_location="cpu")
        try:
            print(" > Restoring Generator Model...")
            model_gen.load_state_dict(checkpoint["model"])
            print(" > Restoring Generator Optimizer...")
            optimizer_gen.load_state_dict(checkpoint["optimizer"])
            print(" > Restoring Discriminator Model...")
            model_disc.load_state_dict(checkpoint["model_disc"])
            print(" > Restoring Discriminator Optimizer...")
            optimizer_disc.load_state_dict(checkpoint["optimizer_disc"])
            # restore the schedulers only when continuing a training run.
            if args.continue_path != "":
                if "scheduler" in checkpoint and scheduler_gen is not None:
                    print(" > Restoring Generator LR Scheduler...")
                    scheduler_gen.load_state_dict(checkpoint["scheduler"])
                    # NOTE: Not sure if necessary
                    scheduler_gen.optimizer = optimizer_gen
                if "scheduler_disc" in checkpoint and scheduler_disc is not None:
                    print(" > Restoring Discriminator LR Scheduler...")
                    scheduler_disc.load_state_dict(
                        checkpoint["scheduler_disc"])
                    scheduler_disc.optimizer = optimizer_disc
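                    # ExponentialLR decays per epoch, so realign its internal counter with the restored epoch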
                    if c.lr_scheduler_disc == "ExponentialLR":
                        scheduler_disc.last_epoch = checkpoint["epoch"]
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model_gen.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
            model_gen.load_state_dict(model_dict)

            model_dict = model_disc.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model_disc"], c)
            model_disc.load_state_dict(model_dict)
            del model_dict

        # reset lr if not continuing training.
        if args.continue_path == "":
            for group in optimizer_gen.param_groups:
                group["lr"] = c.lr_gen

            for group in optimizer_disc.param_groups:
                group["lr"] = c.lr_disc

        print(f" > Model restored from step {checkpoint['step']:d}",
              flush=True)
        args.restore_step = checkpoint["step"]
    else:
        args.restore_step = 0

    # DISTRIBUTED
    if num_gpus > 1:
        model_gen = DDP_th(model_gen, device_ids=[args.rank])
        model_disc = DDP_th(model_disc, device_ids=[args.rank])

    num_params = count_parameters(model_gen)
    print(" > Generator has {} parameters".format(num_params), flush=True)
    num_params = count_parameters(model_disc)
    print(" > Discriminator has {} parameters".format(num_params), flush=True)

    if args.restore_step == 0 or not args.best_path:
        best_loss = float("inf")
        print(" > Starting with inf best loss.")
    else:
        print(" > Restoring best loss from "
              f"{os.path.basename(args.best_path)} ...")
        best_loss = torch.load(args.best_path,
                               map_location="cpu")["model_loss"]
        print(f" > Starting with best loss of {best_loss}.")
    keep_all_best = c.get("keep_all_best", False)
    keep_after = c.get("keep_after", 10000)  # void if keep_all_best False

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(
            model_gen,
            criterion_gen,
            optimizer_gen,
            model_disc,
            criterion_disc,
            optimizer_disc,
            scheduler_gen,
            scheduler_disc,
            ap,
            global_step,
            epoch,
        )
        eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc,
                                      criterion_disc, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict[c.target_loss]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model_gen,
            optimizer_gen,
            scheduler_gen,
            model_disc,
            optimizer_disc,
            scheduler_disc,
            global_step,
            epoch,
            OUT_PATH,
            keep_all_best=keep_all_best,
            keep_after=keep_after,
            model_losses=eval_avg_loss_dict,
        )