Example #1
def test_melgan_trainable_with_melgan_discriminator(dict_g, dict_d, dict_loss):
    # setup
    batch_size = 4
    batch_length = 4096
    args_g = make_melgan_generator_args(**dict_g)
    args_d = make_melgan_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    y = torch.randn(batch_size, 1, batch_length)
    c = torch.randn(batch_size, args_g["in_channels"],
                    batch_length // np.prod(
                        args_g["upsample_scales"]))
    model_g = MelGANGenerator(**args_g)
    model_d = MelGANMultiScaleDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # check generator trainable
    y_hat = model_g(c)
    p_hat = model_d(y_hat)
    y, y_hat = y.squeeze(1), y_hat.squeeze(1)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    adv_loss = 0.0
    for i in range(len(p_hat)):
        adv_loss += F.mse_loss(
            p_hat[i][-1], p_hat[i][-1].new_ones(p_hat[i][-1].size()))
    adv_loss /= (i + 1)
    with torch.no_grad():
        p = model_d(y.unsqueeze(1))
    fm_loss = 0.0
    for i in range(len(p_hat)):
        for j in range(len(p_hat[i]) - 1):
            fm_loss += F.l1_loss(p_hat[i][j], p[i][j].detach())
    fm_loss /= (i + 1) * (j + 1)  # average over discriminators and layers
    loss_g = adv_loss + aux_loss + fm_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    y, y_hat = y.unsqueeze(1), y_hat.unsqueeze(1).detach()
    p = model_d(y)
    p_hat = model_d(y_hat)
    real_loss = 0.0
    fake_loss = 0.0
    for i in range(len(p)):
        real_loss += F.mse_loss(
            p[i][-1], p[i][-1].new_ones(p[i][-1].size()))
        fake_loss += F.mse_loss(
            p_hat[i][-1], p_hat[i][-1].new_zeros(p_hat[i][-1].size()))
    real_loss /= (i + 1)
    fake_loss /= (i + 1)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
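The manual loop above is what later examples wrap in a FeatureMatchLoss criterion. A minimal standalone sketch of such a helper, assuming the discriminator returns one list of feature maps per sub-discriminator with the final score as the last entry (the structure that p_hat[i][-1] relies on); this is an illustration, not the library's implementation:

import torch.nn.functional as F

def feature_match_loss(feats_hat, feats):
    """Average L1 distance between fake and real feature maps,
    skipping the last entry of each list (the discriminator score)."""
    loss, n_terms = 0.0, 0
    for fmaps_hat, fmaps in zip(feats_hat, feats):
        for f_hat, f in zip(fmaps_hat[:-1], fmaps[:-1]):
            loss += F.l1_loss(f_hat, f.detach())
            n_terms += 1
    return loss / max(n_terms, 1)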
Example #2
def test_hifigan_trainable(dict_g, dict_d, dict_loss):
    # setup
    batch_size = 4
    batch_length = 2**13
    args_g = make_hifigan_generator_args(**dict_g)
    args_d = make_hifigan_multi_scale_multi_period_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    y = torch.randn(batch_size, 1, batch_length)
    c = torch.randn(
        batch_size,
        args_g["in_channels"],
        batch_length // np.prod(args_g["upsample_scales"]),
    )
    model_g = HiFiGANGenerator(**args_g)
    model_d = HiFiGANMultiScaleMultiPeriodDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    feat_match_criterion = FeatureMatchLoss(
        average_by_layers=False,
        average_by_discriminators=False,
        include_final_outputs=True,
    )
    gen_adv_criterion = GeneratorAdversarialLoss(average_by_discriminators=False)
    dis_adv_criterion = DiscriminatorAdversarialLoss(average_by_discriminators=False)
    optimizer_g = torch.optim.AdamW(model_g.parameters())
    optimizer_d = torch.optim.AdamW(model_d.parameters())

    # check generator trainable
    y_hat = model_g(c)
    p_hat = model_d(y_hat)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    adv_loss = gen_adv_criterion(p_hat)
    with torch.no_grad():
        p = model_d(y)
    fm_loss = feat_match_criterion(p_hat, p)
    loss_g = adv_loss + aux_loss + fm_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    p = model_d(y)
    p_hat = model_d(y_hat.detach())
    real_loss, fake_loss = dis_adv_criterion(p_hat, p)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()

    print(model_d)
    print(model_g)
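Both discriminators here return nested lists rather than a single tensor: one entry per sub-discriminator, each entry a list of intermediate feature maps ending with the final score. A toy stand-in showing the structure the criteria above consume (all shapes are made up for illustration):

import torch

p_hat = [
    [torch.randn(4, 16, 512), torch.randn(4, 32, 256), torch.randn(4, 1, 128)],
    [torch.randn(4, 16, 256), torch.randn(4, 32, 128), torch.randn(4, 1, 64)],
]
scores = [fmaps[-1] for fmaps in p_hat]     # consumed by the adversarial losses
features = [fmaps[:-1] for fmaps in p_hat]  # consumed by the feature-matching loss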
Example #3
def test_melgan_trainable_with_melgan_discriminator(dict_g, dict_d, dict_loss):
    # setup
    batch_size = 4
    batch_length = 4096
    args_g = make_melgan_generator_args(**dict_g)
    args_d = make_melgan_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    y = torch.randn(batch_size, 1, batch_length)
    c = torch.randn(
        batch_size,
        args_g["in_channels"],
        batch_length // np.prod(args_g["upsample_scales"]),
    )
    model_g = MelGANGenerator(**args_g)
    model_d = MelGANMultiScaleDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    feat_match_criterion = FeatureMatchLoss()
    gen_adv_criterion = GeneratorAdversarialLoss()
    dis_adv_criterion = DiscriminatorAdversarialLoss()
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # check generator trainable
    y_hat = model_g(c)
    p_hat = model_d(y_hat)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    adv_loss = gen_adv_criterion(p_hat)
    with torch.no_grad():
        p = model_d(y)
    fm_loss = feat_match_criterion(p_hat, p)
    loss_g = adv_loss + aux_loss + fm_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    p = model_d(y)
    p_hat = model_d(y_hat.detach())
    real_loss, fake_loss = dis_adv_criterion(p_hat, p)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
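A detail shared by every example: the discriminator update runs on y_hat.detach(), so loss_d.backward() cannot propagate into the generator. A self-contained sketch of that effect with stand-in linear models:

import torch

g = torch.nn.Linear(8, 8)  # stand-in generator
d = torch.nn.Linear(8, 1)  # stand-in discriminator
y_hat = g(torch.randn(2, 8))
d(y_hat.detach()).mean().backward()  # discriminator-style update on detached output
assert all(p.grad is None for p in g.parameters())  # generator gradients untouched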
Example #4
def test_parallel_wavegan_with_residual_discriminator_trainable(
        dict_g, dict_d, dict_loss):
    # setup
    batch_size = 4
    batch_length = 4096
    args_g = make_generator_args(**dict_g)
    args_d = make_residual_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    z = torch.randn(batch_size, 1, batch_length)
    y = torch.randn(batch_size, 1, batch_length)
    c = torch.randn(
        batch_size,
        args_g["aux_channels"],
        batch_length // np.prod(args_g["upsample_params"]["upsample_scales"]) +
        2 * args_g["aux_context_window"],
    )
    model_g = ParallelWaveGANGenerator(**args_g)
    model_d = ResidualParallelWaveGANDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    gen_adv_criterion = GeneratorAdversarialLoss()
    dis_adv_criterion = DiscriminatorAdversarialLoss()
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # check generator trainable
    y_hat = model_g(z, c)
    p_hat = model_d(y_hat)
    adv_loss = gen_adv_criterion(p_hat)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    loss_g = adv_loss + aux_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    p = model_d(y)
    p_hat = model_d(y_hat.detach())
    real_loss, fake_loss = dis_adv_criterion(p_hat, p)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
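The conditioning length above follows from the frame/sample relationship: each mel frame is upsampled by prod(upsample_scales), and aux_context_window extra frames are required on each side. A worked example with assumed values (the real defaults come from make_generator_args):

import numpy as np

batch_length = 4096
upsample_scales = [4, 4, 4, 4]  # assumption: the product must equal the hop size
aux_context_window = 2          # assumption

frames = batch_length // np.prod(upsample_scales)  # 4096 // 256 = 16 frames
frames += 2 * aux_context_window                   # context padding -> 20 frames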
Example #5
def test_style_melgan_trainable(dict_g, dict_d, dict_loss, loss_type):
    # setup
    args_g = make_style_melgan_generator_args(**dict_g)
    args_d = make_style_melgan_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    batch_size = 4
    batch_length = np.prod(args_g["noise_upsample_scales"]) * np.prod(
        args_g["upsample_scales"]
    )
    y = torch.randn(batch_size, 1, batch_length)
    c = torch.randn(
        batch_size,
        args_g["aux_channels"],
        batch_length // np.prod(args_g["upsample_scales"]),
    )
    model_g = StyleMelGANGenerator(**args_g)
    model_d = StyleMelGANDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    gen_adv_criterion = GeneratorAdversarialLoss(loss_type=loss_type)
    dis_adv_criterion = DiscriminatorAdversarialLoss(loss_type=loss_type)
    optimizer_g = torch.optim.Adam(model_g.parameters())
    optimizer_d = torch.optim.Adam(model_d.parameters())

    # check generator trainable
    y_hat = model_g(c)
    p_hat = model_d(y_hat)
    adv_loss = gen_adv_criterion(p_hat)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    loss_g = adv_loss + aux_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    p = model_d(y)
    p_hat = model_d(y_hat.detach())
    real_loss, fake_loss = dis_adv_criterion(p_hat, p)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
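The loss_type argument switches the adversarial objective for both criteria. A minimal sketch of the two usual choices, "mse" (LSGAN) and "hinge"; the helper below is illustrative and assumes a single score tensor rather than the library's list handling:

import torch
import torch.nn.functional as F

def generator_adv_loss(score: torch.Tensor, loss_type: str = "mse") -> torch.Tensor:
    if loss_type == "mse":  # LSGAN: push fake scores toward 1
        return F.mse_loss(score, score.new_ones(score.size()))
    if loss_type == "hinge":  # hinge: maximize the raw fake score
        return -score.mean()
    raise ValueError(f"unsupported loss_type: {loss_type}")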
Example #6
def test_parallel_wavegan_trainable(dict_g, dict_d, dict_loss):
    # setup
    batch_size = 4
    batch_length = 4096
    args_g = make_generator_args(**dict_g)
    args_d = make_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    z = torch.randn(batch_size, 1, batch_length)
    y = torch.randn(batch_size, 1, batch_length)
    c = torch.randn(
        batch_size, args_g["aux_channels"],
        batch_length // np.prod(args_g["upsample_params"]["upsample_scales"]) +
        2 * args_g["aux_context_window"])
    model_g = ParallelWaveGANGenerator(**args_g)
    model_d = ParallelWaveGANDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # check generator trainable
    y_hat = model_g(z, c)
    p_hat = model_d(y_hat)
    y, y_hat, p_hat = y.squeeze(1), y_hat.squeeze(1), p_hat.squeeze(1)
    adv_loss = F.mse_loss(p_hat, p_hat.new_ones(p_hat.size()))
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    loss_g = adv_loss + aux_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    y, y_hat = y.unsqueeze(1), y_hat.unsqueeze(1).detach()
    p = model_d(y)
    p_hat = model_d(y_hat)
    p, p_hat = p.squeeze(1), p_hat.squeeze(1)
    loss_d = F.mse_loss(p, p.new_ones(p.size())) + F.mse_loss(
        p_hat, p_hat.new_zeros(p_hat.size()))
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
Example #7
def main():
    """Run training process."""
    parser = argparse.ArgumentParser(
        description=
        "Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py)."
    )
    parser.add_argument("--train-dumpdir",
                        type=str,
                        required=True,
                        help="directory including training data.")
    parser.add_argument("--dev-dumpdir",
                        type=str,
                        required=True,
                        help="directory including development data.")
    parser.add_argument("--outdir",
                        type=str,
                        required=True,
                        help="directory to save checkpoints.")
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="yaml format configuration file.")
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help="checkpoint file path to resume training. (default=\"\")")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    parser.add_argument(
        "--rank",
        "--local_rank",
        default=0,
        type=int,
        help="rank for distributed training. no need to explictly specify.")
    args = parser.parse_args()

    args.distributed = False
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        # effective when using fixed size inputs
        # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        torch.backends.cudnn.benchmark = True
        torch.cuda.set_device(args.rank)
        # setup for distributed training
        # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed
        if "WORLD_SIZE" in os.environ:
            args.world_size = int(os.environ["WORLD_SIZE"])
            args.distributed = args.world_size > 1
        if args.distributed:
            torch.distributed.init_process_group(backend="nccl",
                                                 init_method="env://")

    # suppress logging for distributed training
    if args.rank != 0:
        sys.stdout = open(os.devnull, "w")

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # load and save config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = parallel_wavegan.__version__  # add version info
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \
            2 * config["generator_params"].get("aux_context_window", 0)
    else:
        mel_length_threshold = None
    if config["format"] == "hdf5":
        audio_query, mel_query = "*.h5", "*.h5"
        audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
        mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
    elif config["format"] == "npy":
        audio_query, mel_query = "*-wave.npy", "*-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("support only hdf5 or npy format.")
    dataset = {
        "train":
        AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        ),
        "dev":
        AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        ),
    }

    # get data loader
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        # keep compatibility
        aux_context_window=config["generator_params"].get(
            "aux_context_window", 0),
        # keep compatibility
        use_noise_input=config.get(
            "generator_type", "ParallelWaveGANGenerator") != "MelGANGenerator",
    )
    train_sampler, dev_sampler = None, None
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler
        train_sampler = DistributedSampler(
            dataset=dataset["train"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=True,
        )
        dev_sampler = DistributedSampler(
            dataset=dataset["dev"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=False,
        )
    data_loader = {
        "train":
        DataLoader(
            dataset=dataset["train"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=train_sampler,
            pin_memory=config["pin_memory"],
        ),
        "dev":
        DataLoader(
            dataset=dataset["dev"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=dev_sampler,
            pin_memory=config["pin_memory"],
        ),
    }

    # define models and optimizers
    generator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("generator_type", "ParallelWaveGANGenerator"),
    )
    discriminator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("discriminator_type", "ParallelWaveGANDiscriminator"),
    )
    model = {
        "generator":
        generator_class(**config["generator_params"]).to(device),
        "discriminator":
        discriminator_class(**config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft":
        MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
    if config.get("use_feat_match_loss", False):  # keep compatibility
        criterion["l1"] = torch.nn.L1Loss().to(device)
    optimizer = {
        "generator":
        RAdam(model["generator"].parameters(),
              **config["generator_optimizer_params"]),
        "discriminator":
        RAdam(model["discriminator"].parameters(),
              **config["discriminator_optimizer_params"]),
    }
    scheduler = {
        "generator":
        torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"]),
        "discriminator":
        torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"]),
    }
    if args.distributed:
        # wrap model for distributed training
        try:
            from apex.parallel import DistributedDataParallel
        except ImportError:
            raise ImportError(
                "apex is not installed. please check https://github.com/NVIDIA/apex."
            )
        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(
            model["discriminator"])
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # resume from checkpoint
    if len(args.resume) != 0:
        trainer.load_checkpoint(args.resume)
        logging.info(f"Successfully resumed from {args.resume}.")

    # run training loop
    try:
        trainer.run()
    except KeyboardInterrupt:
        trainer.save_checkpoint(
            os.path.join(config["outdir"],
                         f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
Example #8
def main():
    """Run training process."""
    parser = argparse.ArgumentParser(
        description=
        "Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py)."
    )
    parser.add_argument("--train-dumpdir",
                        type=str,
                        required=True,
                        help="directory including trainning data.")
    parser.add_argument("--dev-dumpdir",
                        type=str,
                        required=True,
                        help="directory including development data.")
    parser.add_argument("--outdir",
                        type=str,
                        required=True,
                        help="directory to save checkpoints.")
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="yaml format configuration file.")
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help="checkpoint file path to resume training. (default=\"\")")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    else:
        logging.basicConfig(
            level=logging.WARN,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # load and save config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \
            2 * config["generator_params"]["aux_context_window"]
    else:
        mel_length_threshold = None
    if config["format"] == "hdf5":
        audio_query, mel_query = "*.h5", "*.h5"
        audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
        mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
    elif config["format"] == "npy":
        audio_query, mel_query = "*-wave.npy", "*-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("support only hdf5 or npy format.")
    dataset = {
        "train":
        AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibilty
        ),
        "dev":
        AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibilty
        ),
    }

    # get data loader
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        aux_context_window=config["generator_params"]["aux_context_window"],
    )
    data_loader = {
        "train":
        DataLoader(dataset=dataset["train"],
                   shuffle=True,
                   collate_fn=collater,
                   batch_size=config["batch_size"],
                   num_workers=config["num_workers"],
                   pin_memory=config["pin_memory"]),
        "dev":
        DataLoader(dataset=dataset["dev"],
                   shuffle=True,
                   collate_fn=collater,
                   batch_size=config["batch_size"],
                   num_workers=config["num_workers"],
                   pin_memory=config["pin_memory"]),
    }

    # define models and optimizers
    model = {
        "generator":
        ParallelWaveGANGenerator(**config["generator_params"]).to(device),
        "discriminator":
        ParallelWaveGANDiscriminator(
            **config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft":
        MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
    optimizer = {
        "generator":
        RAdam(model["generator"].parameters(),
              **config["generator_optimizer_params"]),
        "discriminator":
        RAdam(model["discriminator"].parameters(),
              **config["discriminator_optimizer_params"]),
    }
    scheduler = {
        "generator":
        torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"]),
        "discriminator":
        torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"]),
    }
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # resume from checkpoint
    if len(args.resume) != 0:
        trainer.load_checkpoint(args.resume)
        logging.info(f"resumed from {args.resume}.")

    # run training loop
    try:
        trainer.run()
    finally:
        trainer.save_checkpoint(
            os.path.join(config["outdir"],
                         f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"successfully saved checkpoint @ {trainer.steps}steps.")
Example #9
def main():
    """The main function that runs training process."""
    # initialize the argument parser
    parser = argparse.ArgumentParser(description="Train Parallel WaveGAN.") # just a description of the job that the parser is used to support.

    # Add arguments to the parser:
    #   first item: the name of the argument
    #   default: the value produced if the argument is absent from the command line
    #   type: the type to which the command-line argument should be converted
    #   help: the hint shown when the user does not know what an argument is [-h]
    #   required: whether the command-line option may be omitted (optionals only)
    #   nargs: the number of command-line arguments that should be consumed
    #     "?": one argument is consumed from the command line if possible and
    #     produced as a single item; if no command-line argument is present,
    #     the value from default is produced.
    #     Note that for optional arguments there is an additional case: the option
    #     string is present but not followed by a command-line argument, in which
    #     case the value from const is produced (see the demo after this example).

    parser.add_argument("--train-wav-scp", default=None, type=str,
                        help="kaldi-style wav.scp file for training. "
                             "you need to specify either train-*-scp or train-dumpdir.")

    parser.add_argument("--train-feats-scp", default=None, type=str,
                        help="kaldi-style feats.scp file for training. "
                             "you need to specify either train-*-scp or train-dumpdir.")

    parser.add_argument("--train-segments", default=None, type=str,
                        help="kaldi-style segments file for training.")

    parser.add_argument("--train-dumpdir", default=None, type=str,
                        help="directory including training data. "
                             "you need to specify either train-*-scp or train-dumpdir.")

    parser.add_argument("--dev-wav-scp", default=None, type=str,
                        help="kaldi-style wav.scp file for validation. "
                             "you need to specify either dev-*-scp or dev-dumpdir.")

    parser.add_argument("--dev-feats-scp", default=None, type=str,
                        help="kaldi-style feats.scp file for vaidation. "
                             "you need to specify either dev-*-scp or dev-dumpdir.")

    parser.add_argument("--dev-segments", default=None, type=str,
                        help="kaldi-style segments file for validation.")

    parser.add_argument("--dev-dumpdir", default=None, type=str,
                        help="directory including development data. "
                             "you need to specify either dev-*-scp or dev-dumpdir.")

    parser.add_argument("--outdir", type=str, required=True,
                        help="directory to save checkpoints.")

    parser.add_argument("--config", type=str, required=True,
                        help="yaml format configuration file.")

    parser.add_argument("--pretrain", default="", type=str, nargs="?",
                        help="checkpoint file path to load pretrained params. (default=\"\")")

    parser.add_argument("--resume", default="", type=str, nargs="?",
                        help="checkpoint file path to resume training. (default=\"\")")

    parser.add_argument("--verbose", type=int, default=1,
                        help="logging level. higher is more logging. (default=1)")

    parser.add_argument("--rank", "--local_rank", default=0, type=int,
                        help="rank for distributed training. no need to explictly specify.")

    # parse all the input arguments 
    args = parser.parse_args()
    args.distributed = False

    if not torch.cuda.is_available():  # if no GPU is available
        device = torch.device("cpu")  # train on CPU
    else:  # GPU
        device = torch.device("cuda")  # train on GPU
        torch.backends.cudnn.benchmark = True  # cudnn benchmark mode: faster runtime, effective with fixed-size inputs (no conditional layers or layers inside loops)
        torch.cuda.set_device(args.rank)  # set the default GPU for distributed training
        if "WORLD_SIZE" in os.environ:  # determine the max number of parallel processes (distributed)
            args.world_size = int(os.environ["WORLD_SIZE"])  # get the world size from the environment
            args.distributed = args.world_size > 1  # set distributed if world size > 1
        if args.distributed:
            torch.distributed.init_process_group(backend="nccl", init_method="env://")  # use the NCCL backend for distributed GPU training (rule of thumb),
                # since it currently provides the best distributed GPU training performance, especially for multiprocess single-node or multi-node training

    # suppress logging for distributed training
    if args.rank != 0:  # if this process is not rank 0
        sys.stdout = open(os.devnull, "w")  # os.devnull discards anything written to it, silencing this process's output

    # set logger
    if args.verbose > 1:  # if the logging level is higher than 1
        logging.basicConfig(  # configure the logging
            level=logging.DEBUG, stream=sys.stdout,  # highest logging level: detailed information, typically of interest only when diagnosing problems
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")  # format includes time, module, line number, level, and message
    elif args.verbose > 0:  # if the logging level is 1
        logging.basicConfig(  # configure the logging
            level=logging.INFO, stream=sys.stdout,  # moderate logging level: confirmation that things are working as expected
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")  # format includes time, module, line number, level, and message
    else:  # if the logging level is 0
        logging.basicConfig(  # configure the logging
            level=logging.WARN, stream=sys.stdout,  # low logging level: an indication that something unexpected happened, or of some problem in the near future (e.g. 'disk space low')
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")  # format includes time, module, line number, level, and message
        logging.warning("Skip DEBUG/INFO messages")  # tell the user that DEBUG/INFO messages are skipped at this level

    # check directory existence
    if not os.path.exists(args.outdir): #directory to save checkpoints
        os.makedirs(args.outdir)

    # check arguments
    if (args.train_feats_scp is not None and args.train_dumpdir is not None) or \
            (args.train_feats_scp is None and args.train_dumpdir is None):
            # if the user specifies both kinds of training data source, or
            # does not specify any training data source
        raise ValueError("Please specify either --train-dumpdir or --train-*-scp.")  # tell the user to choose exactly one training data source
    if (args.dev_feats_scp is not None and args.dev_dumpdir is not None) or \
            (args.dev_feats_scp is None and args.dev_dumpdir is None):
            # if the user specifies both kinds of validation data source, or
            # does not specify any validation data source
        raise ValueError("Please specify either --dev-dumpdir or --dev-*-scp.")  # tell the user to choose exactly one validation data source

    # load config
    with open(args.config) as f:#open configuration file (yaml format)
        config = yaml.load(f, Loader=yaml.Loader) #load configuration file (yaml format to python object)
    # update config
    config.update(vars(args))#update arguments in configuration file
    config["version"] = parallel_wavegan.__version__  # add parallel wavegan version info
    # save config
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:#open outdir/config.yml
        yaml.dump(config, f, Dumper=yaml.Dumper) #dump function accepts a Python object and produces a YAML document.
    # add config info to the high level logger.
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:#if configuration tells to remove short samples from training.
        mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \
            2 * config["generator_params"].get("aux_context_window", 0)#th of length = floor(batch_max_steps/hop_size) + 2 * (generator_params.aux_context_window)
    else:
        mel_length_threshold = None  # no threshold
    if args.train_wav_scp is None or args.dev_wav_scp is None:  # if at least one split is loaded from a dump directory, set up queries and load functions
        if config["format"] == "hdf5":# format of data = hdf5
            audio_query, mel_query = "*.h5", "*.h5" # audio and text queries = "...".h5
            #lambda example:
            #x = lambda a, b: a * b
            #x(5, 6)-->x(a=5,b=6)=a*b=5*6=30
            audio_load_fn = lambda x: read_hdf5(x, "wave")  # The function to load data,NOQA
            mel_load_fn = lambda x: read_hdf5(x, "feats")  # The function to load data,NOQA
        elif config["format"] == "npy":# format of data = npy
            audio_query, mel_query = "*-wave.npy", "*-feats.npy" #audio query = "..."-wave.npy and text query = "..."-feats.h5
            audio_load_fn = np.load#The function to load data.
            mel_load_fn = np.load#The function to load data.
        else:#if any other data format
            raise ValueError("support only hdf5 or npy format.") #raise error to tell the user the data format is not supported.

    if args.train_dumpdir is not None: # if training ds is not None
        train_dataset = AudioMelDataset( # define the training dataset
            root_dir=args.train_dumpdir,#the directory of ds.
            audio_query=audio_query,#audio query according to format above.
            mel_query=mel_query,#mel query according to format above.
            audio_load_fn=audio_load_fn,#load the function that loads the audio data according to format above.
            mel_load_fn=mel_load_fn,#load the function that loads the mel data according to format above.
            mel_length_threshold=mel_length_threshold,#th to remove short samples -calculated above-.
            allow_cache=config.get("allow_cache", False),  # keep compatibility.
        )
    else:# if training ds is None
        train_dataset = AudioMelSCPDataset(# define the training dataset
            wav_scp=args.train_wav_scp,
            feats_scp=args.train_feats_scp,
            segments=args.train_segments, #segments of dataset
            mel_length_threshold=mel_length_threshold,#th to remove short samples -calculated above-.
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    logging.info(f"The number of training files = {len(train_dataset)}.") # add length of trainning data set to the logger.
    if args.dev_dumpdir is not None: #if evaluating ds is not None
        dev_dataset = AudioMelDataset( # define the evaluating dataset
            root_dir=args.dev_dumpdir,#the directory of ds.
            audio_query=audio_query,#audio query according to format above.
            mel_query=mel_query,#mel query according to format above.
            audio_load_fn=audio_load_fn,#load the function that loads the audio data according to format above.
            mel_load_fn=mel_load_fn,#load the function that loads the mel data according to format above.
            mel_length_threshold=mel_length_threshold,#th to remove short samples -calculated above-.
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    else:# if evaluating ds is None
        dev_dataset = AudioMelSCPDataset(
            wav_scp=args.dev_wav_scp,
            feats_scp=args.dev_feats_scp,
            segments=args.dev_segments,#segments of dataset
            mel_length_threshold=mel_length_threshold,#th to remove short samples -calculated above-.
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    logging.info(f"The number of development files = {len(dev_dataset)}.") # add length of evaluating data set to the logger.
    dataset = {
        "train": train_dataset,
        "dev": dev_dataset,
    } #define the whole dataset used which is divided into training and evaluating datasets
    # get data loader
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        # keep compatibility
        aux_context_window=config["generator_params"].get("aux_context_window", 0),
        # keep compatibility
        use_noise_input=config.get(
            "generator_type", "ParallelWaveGANGenerator") != "MelGANGenerator",
    )
    train_sampler, dev_sampler = None, None
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler
        train_sampler = DistributedSampler(
            dataset=dataset["train"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=True,
        )
        dev_sampler = DistributedSampler(
            dataset=dataset["dev"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=False,
        )
    data_loader = {
        "train": DataLoader(
            dataset=dataset["train"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=train_sampler,
            pin_memory=config["pin_memory"],
        ),
        "dev": DataLoader(
            dataset=dataset["dev"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=dev_sampler,
            pin_memory=config["pin_memory"],
        ),
    }

    # define models and optimizers
    generator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("generator_type", "ParallelWaveGANGenerator"),
    )
    discriminator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("discriminator_type", "ParallelWaveGANDiscriminator"),
    )
    model = {
        "generator": generator_class(
            **config["generator_params"]).to(device),
        "discriminator": discriminator_class(
            **config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft": MultiResolutionSTFTLoss(
            **config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
    if config.get("use_feat_match_loss", False):  # keep compatibility
        criterion["l1"] = torch.nn.L1Loss().to(device)
    optimizer = {
        "generator": RAdam(
            model["generator"].parameters(),
            **config["generator_optimizer_params"]),
        "discriminator": RAdam(
            model["discriminator"].parameters(),
            **config["discriminator_optimizer_params"]),
    }
    scheduler = {
        "generator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"]),
        "discriminator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"]),
    }
    if args.distributed:
        # wrap model for distributed training
        try:
            from apex.parallel import DistributedDataParallel
        except ImportError:
            raise ImportError("apex is not installed. please check https://github.com/NVIDIA/apex.")
        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(model["discriminator"])
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # load pretrained parameters from checkpoint
    if len(args.pretrain) != 0:
        trainer.load_checkpoint(args.pretrain, load_only_params=True)
        logging.info(f"Successfully load parameters from {args.pretrain}.")

    # resume from checkpoint
    if len(args.resume) != 0:
        trainer.load_checkpoint(args.resume)
        logging.info(f"Successfully resumed from {args.resume}.")

    # run training loop
    try:
        trainer.run()
    except KeyboardInterrupt:
        trainer.save_checkpoint(
            os.path.join(config["outdir"], f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
Example #10
def main():
    """Run training process."""
    parser = argparse.ArgumentParser(description=(
        "Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py)."
    ))
    parser.add_argument(
        "--train-wav-scp",
        default=None,
        type=str,
        help=("kaldi-style wav.scp file for training. "
              "you need to specify either train-*-scp or train-dumpdir."),
    )
    parser.add_argument(
        "--train-feats-scp",
        default=None,
        type=str,
        help=("kaldi-style feats.scp file for training. "
              "you need to specify either train-*-scp or train-dumpdir."),
    )
    parser.add_argument(
        "--train-segments",
        default=None,
        type=str,
        help="kaldi-style segments file for training.",
    )
    parser.add_argument(
        "--train-dumpdir",
        default=None,
        type=str,
        help=("directory including training data. "
              "you need to specify either train-*-scp or train-dumpdir."),
    )
    parser.add_argument(
        "--dev-wav-scp",
        default=None,
        type=str,
        help=("kaldi-style wav.scp file for validation. "
              "you need to specify either dev-*-scp or dev-dumpdir."),
    )
    parser.add_argument(
        "--dev-feats-scp",
        default=None,
        type=str,
        help=("kaldi-style feats.scp file for vaidation. "
              "you need to specify either dev-*-scp or dev-dumpdir."),
    )
    parser.add_argument(
        "--dev-segments",
        default=None,
        type=str,
        help="kaldi-style segments file for validation.",
    )
    parser.add_argument(
        "--dev-dumpdir",
        default=None,
        type=str,
        help=("directory including development data. "
              "you need to specify either dev-*-scp or dev-dumpdir."),
    )
    parser.add_argument(
        "--outdir",
        type=str,
        required=True,
        help="directory to save checkpoints.",
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="yaml format configuration file.",
    )
    parser.add_argument(
        "--pretrain",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to load pretrained params. (default="")',
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--rank",
        "--local_rank",
        default=0,
        type=int,
        help="rank for distributed training. no need to explictly specify.",
    )
    args = parser.parse_args()

    args.distributed = False
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        # effective when using fixed size inputs
        # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        torch.backends.cudnn.benchmark = True
        torch.cuda.set_device(args.rank)
        # setup for distributed training
        # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed
        if "WORLD_SIZE" in os.environ:
            args.world_size = int(os.environ["WORLD_SIZE"])
            args.distributed = args.world_size > 1
        if args.distributed:
            torch.distributed.init_process_group(backend="nccl",
                                                 init_method="env://")

    # suppress logging for distributed training
    if args.rank != 0:
        sys.stdout = open(os.devnull, "w")

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # check arguments
    if (args.train_feats_scp is not None and args.train_dumpdir
            is not None) or (args.train_feats_scp is None
                             and args.train_dumpdir is None):
        raise ValueError(
            "Please specify either --train-dumpdir or --train-*-scp.")
    if (args.dev_feats_scp is not None and args.dev_dumpdir is not None) or (
            args.dev_feats_scp is None and args.dev_dumpdir is None):
        raise ValueError("Please specify either --dev-dumpdir or --dev-*-scp.")

    # load and save config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = parallel_wavegan.__version__  # add version info
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["batch_max_steps"] // config[
            "hop_size"] + 2 * config["generator_params"].get(
                "aux_context_window", 0)
    else:
        mel_length_threshold = None
    if args.train_wav_scp is None or args.dev_wav_scp is None:
        if config["format"] == "hdf5":
            audio_query, mel_query = "*.h5", "*.h5"
            audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
            mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
        elif config["format"] == "npy":
            audio_query, mel_query = "*-wave.npy", "*-feats.npy"
            audio_load_fn = np.load
            mel_load_fn = np.load
        else:
            raise ValueError("support only hdf5 or npy format.")
    if args.train_dumpdir is not None:
        train_dataset = AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    else:
        train_dataset = AudioMelSCPDataset(
            wav_scp=args.train_wav_scp,
            feats_scp=args.train_feats_scp,
            segments=args.train_segments,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    logging.info(f"The number of training files = {len(train_dataset)}.")
    if args.dev_dumpdir is not None:
        dev_dataset = AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    else:
        dev_dataset = AudioMelSCPDataset(
            wav_scp=args.dev_wav_scp,
            feats_scp=args.dev_feats_scp,
            segments=args.dev_segments,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    logging.info(f"The number of development files = {len(dev_dataset)}.")
    dataset = {
        "train": train_dataset,
        "dev": dev_dataset,
    }

    # get data loader
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        # keep compatibility
        aux_context_window=config["generator_params"].get(
            "aux_context_window", 0),
        # keep compatibility
        use_noise_input=config.get("generator_type",
                                   "ParallelWaveGANGenerator")
        in ["ParallelWaveGANGenerator"],
    )
    sampler = {"train": None, "dev": None}
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler

        sampler["train"] = DistributedSampler(
            dataset=dataset["train"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=True,
        )
        sampler["dev"] = DistributedSampler(
            dataset=dataset["dev"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=False,
        )
    data_loader = {
        "train":
        DataLoader(
            dataset=dataset["train"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=sampler["train"],
            pin_memory=config["pin_memory"],
        ),
        "dev":
        DataLoader(
            dataset=dataset["dev"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=sampler["dev"],
            pin_memory=config["pin_memory"],
        ),
    }

    # define models
    generator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("generator_type", "ParallelWaveGANGenerator"),
    )
    discriminator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("discriminator_type", "ParallelWaveGANDiscriminator"),
    )
    model = {
        "generator":
        generator_class(**config["generator_params"], ).to(device),
        "discriminator":
        discriminator_class(**config["discriminator_params"], ).to(device),
    }

    # define criterions
    criterion = {
        "gen_adv":
        GeneratorAdversarialLoss(
            # keep compatibility
            **config.get("generator_adv_loss_params", {})).to(device),
        "dis_adv":
        DiscriminatorAdversarialLoss(
            # keep compatibility
            **config.get("discriminator_adv_loss_params", {})).to(device),
    }
    if config.get("use_stft_loss", True):  # keep compatibility
        config["use_stft_loss"] = True
        criterion["stft"] = MultiResolutionSTFTLoss(
            **config["stft_loss_params"], ).to(device)
    if config.get("use_subband_stft_loss", False):  # keep compatibility
        assert config["generator_params"]["out_channels"] > 1
        criterion["sub_stft"] = MultiResolutionSTFTLoss(
            **config["subband_stft_loss_params"], ).to(device)
    else:
        config["use_subband_stft_loss"] = False
    if config.get("use_feat_match_loss", False):  # keep compatibility
        criterion["feat_match"] = FeatureMatchLoss(
            # keep compatibility
            **config.get("feat_match_loss_params", {}), ).to(device)
    else:
        config["use_feat_match_loss"] = False
    if config.get("use_mel_loss", False):  # keep compatibility
        if config.get("mel_loss_params", None) is None:
            criterion["mel"] = MelSpectrogramLoss(
                fs=config["sampling_rate"],
                fft_size=config["fft_size"],
                hop_size=config["hop_size"],
                win_length=config["win_length"],
                window=config["window"],
                num_mels=config["num_mels"],
                fmin=config["fmin"],
                fmax=config["fmax"],
            ).to(device)
        else:
            criterion["mel"] = MelSpectrogramLoss(**config["mel_loss_params"],
                                                  ).to(device)
    else:
        config["use_mel_loss"] = False

    # define special module for subband processing
    if config["generator_params"]["out_channels"] > 1:
        criterion["pqmf"] = PQMF(
            subbands=config["generator_params"]["out_channels"],
            # keep compatibility
            **config.get("pqmf_params", {}),
        ).to(device)

    # define optimizers and schedulers
    generator_optimizer_class = getattr(
        parallel_wavegan.optimizers,
        # keep compatibility
        config.get("generator_optimizer_type", "RAdam"),
    )
    discriminator_optimizer_class = getattr(
        parallel_wavegan.optimizers,
        # keep compatibility
        config.get("discriminator_optimizer_type", "RAdam"),
    )
    optimizer = {
        "generator":
        generator_optimizer_class(
            model["generator"].parameters(),
            **config["generator_optimizer_params"],
        ),
        "discriminator":
        discriminator_optimizer_class(
            model["discriminator"].parameters(),
            **config["discriminator_optimizer_params"],
        ),
    }
    generator_scheduler_class = getattr(
        torch.optim.lr_scheduler,
        # keep compatibility
        config.get("generator_scheduler_type", "StepLR"),
    )
    discriminator_scheduler_class = getattr(
        torch.optim.lr_scheduler,
        # keep compatibility
        config.get("discriminator_scheduler_type", "StepLR"),
    )
    scheduler = {
        "generator":
        generator_scheduler_class(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"],
        ),
        "discriminator":
        discriminator_scheduler_class(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"],
        ),
    }
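    # NOTE: scheduler params are forwarded verbatim; for the default StepLR an
    # illustrative YAML section would be:
    #   generator_scheduler_params:
    #       step_size: 200000
    #       gamma: 0.5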
    if args.distributed:
        # wrap model for distributed training
        try:
            from apex.parallel import DistributedDataParallel
        except ImportError:
            raise ImportError(
                "apex is not installed. please check https://github.com/NVIDIA/apex."
            )
        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(
            model["discriminator"])

    # show settings
    logging.info(model["generator"])
    logging.info(model["discriminator"])
    logging.info(optimizer["generator"])
    logging.info(optimizer["discriminator"])
    logging.info(scheduler["generator"])
    logging.info(scheduler["discriminator"])
    for criterion_ in criterion.values():
        logging.info(criterion_)

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        sampler=sampler,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # load pretrained parameters from checkpoint
    if len(args.pretrain) != 0:
        trainer.load_checkpoint(args.pretrain, load_only_params=True)
        logging.info(f"Successfully load parameters from {args.pretrain}.")

    # resume from checkpoint
    if len(args.resume) != 0:
        trainer.load_checkpoint(args.resume)
        logging.info(f"Successfully resumed from {args.resume}.")

    # run training loop
    try:
        trainer.run()
    finally:
        trainer.save_checkpoint(
            os.path.join(config["outdir"],
                         f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
Example #11
0
def main():
    """Run training process."""
    parser = argparse.ArgumentParser(
        description=
        "Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py)."
    )
    parser.add_argument('--input_training_file', default='filelists/train.txt')
    parser.add_argument('--input_validation_file', default='filelists/val.txt')
    parser.add_argument('--checkpoint_path', default='parallel_wavegan/outdir')
    parser.add_argument(
        '--config',
        default='parallel_wavegan/config/multi_band_melgan.v2.yaml')
    parser.add_argument('--checkpoint_interval', default=0, type=int)
    parser.add_argument('--max_steps', default=1000000, type=int)
    parser.add_argument('--n_models_to_keep', default=1, type=int)
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    parser.add_argument(
        "--rank",
        "--local_rank",
        default=0,
        type=int,
        help="rank for distributed training. no need to explictly specify.")
    args = parser.parse_args()

    args.distributed = False
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        # effective when using fixed size inputs
        # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        torch.backends.cudnn.benchmark = True
        torch.cuda.set_device(args.rank)
        # setup for distributed training
        # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed
        if "WORLD_SIZE" in os.environ:
            args.world_size = int(os.environ["WORLD_SIZE"])
            args.distributed = args.world_size > 1
        if args.distributed:
            torch.distributed.init_process_group(backend="nccl",
                                                 init_method="env://")

    # suppress logging for distributed training
    if args.rank != 0:
        sys.stdout = open(os.devnull, "w")

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.checkpoint_path):
        os.makedirs(args.checkpoint_path)

    # load and save config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = parallel_wavegan.__version__  # add version info
    if config["checkpoint_interval"] > 0:
        config["save_interval_steps"] = config["checkpoint_interval"]
    if config["max_steps"] and config["max_steps"] > 0:
        config["train_max_steps"] = config["max_steps"]
    with open(os.path.join(args.checkpoint_path, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    data_loader = {}
    data_loader["train"], sampler = from_path(config["input_training_file"],
                                              config)
    data_loader["dev"] = from_path(config["input_validation_file"], config)[0]

    # define models and optimizers
    generator_class = getattr(
        parallel_wavegan.models,
        config.get("generator_type", "ParallelWaveGANGenerator"),
    )
    discriminator_class = getattr(
        parallel_wavegan.models,
        config.get("discriminator_type", "ParallelWaveGANDiscriminator"),
    )
    model = {
        "generator":
        generator_class(**config["generator_params"]).to(device),
        "discriminator":
        discriminator_class(**config["discriminator_params"]).to(device),
    }
    # define criterions
    criterion = {
        # reconstruction loss
        "stft": MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
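    # NOTE: the multi-resolution STFT term is the reconstruction loss, while
    # the MSE criterion presumably serves the LSGAN-style adversarial terms
    # inside the trainer (real -> 1, fake -> 0).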
    if config.get("use_feat_match_loss", False):  # keep compatibility
        criterion["l1"] = torch.nn.L1Loss().to(device)
    if config["generator_params"]["out_channels"] > 1:
        criterion["pqmf"] = PQMF(
            subbands=config["generator_params"]["out_channels"],
            # keep compatibility
            **config.get("pqmf_params", {})).to(device)
    if config.get("use_subband_stft_loss", False):  # keep compatibility
        assert config["generator_params"]["out_channels"] > 1
        criterion["sub_stft"] = MultiResolutionSTFTLoss(
            **config["subband_stft_loss_params"]).to(device)
    generator_optimizer_class = getattr(
        parallel_wavegan.optimizers,
        config.get("generator_optimizer_type", "RAdam"),
    )
    discriminator_optimizer_class = getattr(
        parallel_wavegan.optimizers,
        config.get("discriminator_optimizer_type", "RAdam"),
    )
    optimizer = {
        "generator":
        generator_optimizer_class(
            model["generator"].parameters(),
            **config["generator_optimizer_params"],
        ),
        "discriminator":
        discriminator_optimizer_class(
            model["discriminator"].parameters(),
            **config["discriminator_optimizer_params"],
        ),
    }
    generator_scheduler_class = getattr(
        torch.optim.lr_scheduler,
        config.get("generator_scheduler_type", "StepLR"),
    )
    discriminator_scheduler_class = getattr(
        torch.optim.lr_scheduler,
        config.get("discriminator_scheduler_type", "StepLR"),
    )
    scheduler = {
        "generator":
        generator_scheduler_class(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"],
        ),
        "discriminator":
        discriminator_scheduler_class(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"],
        ),
    }
    if args.distributed:
        # wrap model for distributed training
        try:
            from apex.parallel import DistributedDataParallel
        except ImportError:
            raise ImportError(
                "apex is not installed. please check https://github.com/NVIDIA/apex."
            )
        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(
            model["discriminator"])
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        sampler=sampler,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # load pretrained parameters from checkpoint
    trainer.load_checkpoint()

    # run training loop
    try:
        trainer.run()
    except KeyboardInterrupt:
        import time  # local import; the snippet's module-level imports are not shown

        logging.info("Saving checkpoint in 10...")
        for i in range(9, -1, -1):
            print(f"{i}...")
            time.sleep(1)  # pause so the countdown actually counts down
        trainer.save_checkpoint(
            os.path.join(config["outdir"],
                         f"checkpoint-{trainer.steps:08}steps.pkl"))
        logging.info(
            f"Successfully saved checkpoint @ {trainer.steps:08}steps.")