예제 #1
0
def main():
    """Run preprocessing process.

    Normalize the dumped (audio, mel) pairs found under ``--rootdir`` with the
    statistics stored in ``--stats`` and write them to ``--dumpdir``, keeping
    each file's basename and the storage format selected by the config's
    ``format`` key (``"hdf5"`` or ``"npy"``). Files are processed in parallel
    with joblib using ``--n_jobs`` workers.

    Raises:
        ValueError: If ``config["format"]`` is neither ``"hdf5"`` nor ``"npy"``.

    """
    parser = argparse.ArgumentParser(
        description=
        "Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
    )
    parser.add_argument(
        "--rootdir",
        type=str,
        required=True,
        help="directory including feature files to be normalized.")
    parser.add_argument("--dumpdir",
                        type=str,
                        required=True,
                        help="directory to dump normalized feature files.")
    parser.add_argument("--stats",
                        type=str,
                        required=True,
                        help="statistics file.")
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="yaml format configuration file.")
    parser.add_argument("--n_jobs",
                        type=int,
                        default=16,
                        help="number of parallel jobs. (default=16)")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger: verbose > 1 -> DEBUG, verbose == 1 -> INFO, otherwise WARN
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        format=
        "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    if log_level == logging.WARN:
        logging.warning('Skip DEBUG/INFO messages')

    # load config; command-line arguments override keys of the same name
    # NOTE: yaml.Loader can instantiate arbitrary Python objects; only pass
    # trusted configuration files (yaml.safe_load is the safe alternative).
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    # create the output directory; exist_ok avoids the race of a separate
    # exists()/makedirs() pair when another process creates it concurrently
    os.makedirs(args.dumpdir, exist_ok=True)

    # get dataset
    if config["format"] == "hdf5":
        # wave and feats live in the same .h5 file under different keys
        audio_query, mel_query = "*.h5", "*.h5"
        audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
        mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
    elif config["format"] == "npy":
        audio_query, mel_query = "*-wave.npy", "*-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("support only hdf5 or npy format.")
    dataset = AudioMelDataset(
        root_dir=args.rootdir,
        audio_query=audio_query,
        mel_query=mel_query,
        audio_load_fn=audio_load_fn,
        mel_load_fn=mel_load_fn,
        return_filename=True,
    )
    logging.info(f"The number of files = {len(dataset)}.")

    # restore scaler from the precomputed statistics (no fitting here)
    scaler = StandardScaler()
    if config["format"] == "hdf5":
        scaler.mean_ = read_hdf5(args.stats, "mean")
        scaler.scale_ = read_hdf5(args.stats, "scale")
    elif config["format"] == "npy":
        # index 0 holds the mean and index 1 the scale; read the file once
        stats = np.load(args.stats)
        scaler.mean_ = stats[0]
        scaler.scale_ = stats[1]
    else:
        raise ValueError("support only hdf5 or npy format.")

    def _process_single_file(data):
        """Normalize one (audio, mel) pair and write both into dumpdir."""
        # parse inputs for each audio (dataset was built with return_filename)
        audio_name, mel_name, audio, mel = data

        # standardize the features: StandardScaler.transform computes
        # (mel - mean_) / scale_ using the restored statistics
        mel = scaler.transform(mel)

        # save under the same basename inside dumpdir
        audio_path = os.path.join(args.dumpdir, os.path.basename(audio_name))
        mel_path = os.path.join(args.dumpdir, os.path.basename(mel_name))
        if config["format"] == "hdf5":
            write_hdf5(audio_path, "wave", audio.astype(np.float32))
            write_hdf5(mel_path, "feats", mel.astype(np.float32))
        elif config["format"] == "npy":
            np.save(audio_path, audio.astype(np.float32), allow_pickle=False)
            np.save(mel_path, mel.astype(np.float32), allow_pickle=False)
        else:
            raise ValueError("support only hdf5 or npy format.")

    # process in parallel; delayed() captures each call's arguments so that
    # Parallel can dispatch it to a worker
    Parallel(n_jobs=args.n_jobs, verbose=args.verbose)(
        [delayed(_process_single_file)(data) for data in tqdm(dataset)])
예제 #2
0
def main():
    """Run training process.

    Parse the command line, select the device (enabling distributed training
    when CUDA is available and the ``WORLD_SIZE`` environment variable is
    greater than 1), configure logging, load the YAML config and save it (with
    the command-line arguments and package version merged in) to
    ``outdir/config.yml``, then build datasets, data loaders, models, losses,
    optimizers and schedulers and run the ``Trainer``. On
    ``KeyboardInterrupt`` the current state is saved as
    ``checkpoint-<steps>steps.pkl`` inside ``outdir``.

    Raises:
        ValueError: If ``config["format"]`` is neither ``"hdf5"`` nor ``"npy"``.
        ImportError: If distributed training is requested but apex is missing.

    """
    parser = argparse.ArgumentParser(
        description=
        "Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py)."
    )
    parser.add_argument("--train-dumpdir",
                        type=str,
                        required=True,
                        help="directory including training data.")
    parser.add_argument("--dev-dumpdir",
                        type=str,
                        required=True,
                        help="directory including development data.")
    parser.add_argument("--outdir",
                        type=str,
                        required=True,
                        help="directory to save checkpoints.")
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="yaml format configuration file.")
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help="checkpoint file path to resume training. (default=\"\")")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    parser.add_argument(
        "--rank",
        "--local_rank",
        default=0,
        type=int,
        help="rank for distributed training. no need to explictly specify.")
    args = parser.parse_args()

    args.distributed = False
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        # effective when using fixed size inputs
        # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        torch.backends.cudnn.benchmark = True
        torch.cuda.set_device(args.rank)
        # setup for distributed training
        # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed
        if "WORLD_SIZE" in os.environ:
            args.world_size = int(os.environ["WORLD_SIZE"])
            args.distributed = args.world_size > 1
        if args.distributed:
            torch.distributed.init_process_group(backend="nccl",
                                                 init_method="env://")

    # suppress logging for distributed training: only rank 0 writes to stdout
    # (must happen before basicConfig below, which binds sys.stdout)
    if args.rank != 0:
        sys.stdout = open(os.devnull, "w")

    # set logger: verbose > 1 -> DEBUG, verbose == 1 -> INFO, otherwise WARN
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        stream=sys.stdout,
        format=
        "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    if log_level == logging.WARN:
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence; every distributed rank runs this line, so a
    # separate exists()/makedirs() pair can race and crash the losing rank
    # with FileExistsError -- exist_ok makes the creation race-free
    os.makedirs(args.outdir, exist_ok=True)

    # load and save config; command-line arguments override config keys
    # NOTE: yaml.Loader can instantiate arbitrary Python objects; only pass
    # trusted configuration files.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = parallel_wavegan.__version__  # add version info
    # NOTE(review): in distributed runs every rank writes this same file
    # (each with its own "rank" value); consider restricting it to rank 0.
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        # minimum mel length that still yields one full training crop
        mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \
            2 * config["generator_params"].get("aux_context_window", 0)
    else:
        mel_length_threshold = None
    if config["format"] == "hdf5":
        audio_query, mel_query = "*.h5", "*.h5"
        audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
        mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
    elif config["format"] == "npy":
        audio_query, mel_query = "*-wave.npy", "*-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("support only hdf5 or npy format.")
    dataset = {
        "train":
        AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        ),
        "dev":
        AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        ),
    }

    # get data loader
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        # keep compatibility
        aux_context_window=config["generator_params"].get(
            "aux_context_window", 0),
        # keep compatibility
        use_noise_input=config.get(
            "generator_type", "ParallelWaveGANGenerator") != "MelGANGenerator",
    )
    train_sampler, dev_sampler = None, None
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler
        train_sampler = DistributedSampler(
            dataset=dataset["train"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=True,
        )
        dev_sampler = DistributedSampler(
            dataset=dataset["dev"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=False,
        )
    # DataLoader must not get shuffle=True together with a sampler, so the
    # loader only shuffles itself in the non-distributed case
    data_loader = {
        "train":
        DataLoader(
            dataset=dataset["train"],
            shuffle=not args.distributed,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=train_sampler,
            pin_memory=config["pin_memory"],
        ),
        "dev":
        DataLoader(
            dataset=dataset["dev"],
            shuffle=not args.distributed,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=dev_sampler,
            pin_memory=config["pin_memory"],
        ),
    }

    # define models and optimizers
    generator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("generator_type", "ParallelWaveGANGenerator"),
    )
    discriminator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("discriminator_type", "ParallelWaveGANDiscriminator"),
    )
    model = {
        "generator":
        generator_class(**config["generator_params"]).to(device),
        "discriminator":
        discriminator_class(**config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft":
        MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
    if config.get("use_feat_match_loss", False):  # keep compatibility
        criterion["l1"] = torch.nn.L1Loss().to(device)
    optimizer = {
        "generator":
        RAdam(model["generator"].parameters(),
              **config["generator_optimizer_params"]),
        "discriminator":
        RAdam(model["discriminator"].parameters(),
              **config["discriminator_optimizer_params"]),
    }
    scheduler = {
        "generator":
        torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"]),
        "discriminator":
        torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"]),
    }
    if args.distributed:
        # wrap model for distributed training
        try:
            from apex.parallel import DistributedDataParallel
        except ImportError as err:
            raise ImportError(
                "apex is not installed. please check https://github.com/NVIDIA/apex."
            ) from err
        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(
            model["discriminator"])
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # resume from checkpoint
    if len(args.resume) != 0:
        trainer.load_checkpoint(args.resume)
        logging.info(f"Successfully resumed from {args.resume}.")

    # run training loop; on Ctrl-C keep the progress made so far
    try:
        trainer.run()
    except KeyboardInterrupt:
        trainer.save_checkpoint(
            os.path.join(config["outdir"],
                         f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
예제 #3
0
def main():
    """Run training process.

    Parse the command line, configure logging, load the YAML config and save
    it (with the command-line arguments merged in) to ``outdir/config.yml``,
    build the train/dev datasets and data loaders, the Parallel WaveGAN
    generator/discriminator, losses, optimizers and schedulers, then run the
    ``Trainer``. Whenever the training loop exits (normally or via an
    exception) a checkpoint ``checkpoint-<steps>steps.pkl`` is saved inside
    ``outdir``.

    Raises:
        ValueError: If ``config["format"]`` is neither ``"hdf5"`` nor ``"npy"``.
        KeyError: If ``config["generator_params"]`` has no
            ``"aux_context_window"`` entry.

    """
    parser = argparse.ArgumentParser(
        description=
        "Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py)."
    )
    parser.add_argument("--train-dumpdir",
                        type=str,
                        required=True,
                        help="directory including trainning data.")
    parser.add_argument("--dev-dumpdir",
                        type=str,
                        required=True,
                        help="directory including development data.")
    parser.add_argument("--outdir",
                        type=str,
                        required=True,
                        help="directory to save checkpoints.")
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="yaml format configuration file.")
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help="checkpoint file path to resume training. (default=\"\")")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger: verbose > 1 -> DEBUG, verbose == 1 -> INFO, otherwise WARN
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        format=
        "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    if log_level == logging.WARN:
        logging.warning('skip DEBUG/INFO messages')

    # check directory existence; exist_ok avoids the race of a separate
    # exists()/makedirs() pair if another process creates it concurrently
    os.makedirs(args.outdir, exist_ok=True)

    # load and save config; command-line arguments override config keys
    # NOTE: yaml.Loader can instantiate arbitrary Python objects; only pass
    # trusted configuration files.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        # minimum mel length that still yields one full training crop
        mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \
            2 * config["generator_params"]["aux_context_window"]
    else:
        mel_length_threshold = None
    if config["format"] == "hdf5":
        audio_query, mel_query = "*.h5", "*.h5"
        audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
        mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
    elif config["format"] == "npy":
        audio_query, mel_query = "*-wave.npy", "*-feats.npy"
        audio_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("support only hdf5 or npy format.")
    dataset = {
        "train":
        AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibilty
        ),
        "dev":
        AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibilty
        ),
    }

    # select device (GPU when available)
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # get data loader
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        aux_context_window=config["generator_params"]["aux_context_window"],
    )
    data_loader = {
        "train":
        DataLoader(dataset=dataset["train"],
                   shuffle=True,
                   collate_fn=collater,
                   batch_size=config["batch_size"],
                   num_workers=config["num_workers"],
                   pin_memory=config["pin_memory"]),
        "dev":
        DataLoader(dataset=dataset["dev"],
                   shuffle=True,
                   collate_fn=collater,
                   batch_size=config["batch_size"],
                   num_workers=config["num_workers"],
                   pin_memory=config["pin_memory"]),
    }

    # define models and optimizers
    model = {
        "generator":
        ParallelWaveGANGenerator(**config["generator_params"]).to(device),
        "discriminator":
        ParallelWaveGANDiscriminator(
            **config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft":
        MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
    optimizer = {
        "generator":
        RAdam(model["generator"].parameters(),
              **config["generator_optimizer_params"]),
        "discriminator":
        RAdam(model["discriminator"].parameters(),
              **config["discriminator_optimizer_params"]),
    }
    scheduler = {
        "generator":
        torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"]),
        "discriminator":
        torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"]),
    }
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # resume from checkpoint
    if len(args.resume) != 0:
        trainer.load_checkpoint(args.resume)
        logging.info(f"resumed from {args.resume}.")

    # run training loop; the finally clause saves a checkpoint however the
    # loop exits (completion, Ctrl-C, or any other exception)
    try:
        trainer.run()
    finally:
        trainer.save_checkpoint(
            os.path.join(config["outdir"],
                         f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"successfully saved checkpoint @ {trainer.steps}steps.")
예제 #4
0
def main():
    """The main function that runs training process."""
    # initialize the argument parser
    parser = argparse.ArgumentParser(description="Train Parallel WaveGAN.") # just a description of the job that the parser is used to support.

    # Add arguments to the parser
        #first is name of the argument
        #default: The value produced if the argument is absent from the command line
        #type: The type to which the command-line argument should be converted
        #help: hint that appears when the user doesnot know what is this argument [-h]
        #required: Whether or not the command-line option may be omitted (optionals only)
        #nargs:The number of command-line arguments that should be consumed
            # "?" One argument will be consumed from the command line if possible, and produced as a single item. If no command-line argument is present, 
            # the value from default will be produced. 
            # Note that for optional arguments, there is an additional case - the option string is present but not followed by a command-line argument. In this case the value from const will be produced.

    parser.add_argument("--train-wav-scp", default=None, type=str,
                        help="kaldi-style wav.scp file for training. "
                             "you need to specify either train-*-scp or train-dumpdir.")

    parser.add_argument("--train-feats-scp", default=None, type=str,
                        help="kaldi-style feats.scp file for training. "
                             "you need to specify either train-*-scp or train-dumpdir.")

    parser.add_argument("--train-segments", default=None, type=str,
                        help="kaldi-style segments file for training.")

    parser.add_argument("--train-dumpdir", default=None, type=str,
                        help="directory including training data. "
                             "you need to specify either train-*-scp or train-dumpdir.")

    parser.add_argument("--dev-wav-scp", default=None, type=str,
                        help="kaldi-style wav.scp file for validation. "
                             "you need to specify either dev-*-scp or dev-dumpdir.")

    parser.add_argument("--dev-feats-scp", default=None, type=str,
                        help="kaldi-style feats.scp file for vaidation. "
                             "you need to specify either dev-*-scp or dev-dumpdir.")

    parser.add_argument("--dev-segments", default=None, type=str,
                        help="kaldi-style segments file for validation.")

    parser.add_argument("--dev-dumpdir", default=None, type=str,
                        help="directory including development data. "
                             "you need to specify either dev-*-scp or dev-dumpdir.")

    parser.add_argument("--outdir", type=str, required=True,
                        help="directory to save checkpoints.")

    parser.add_argument("--config", type=str, required=True,
                        help="yaml format configuration file.")

    parser.add_argument("--pretrain", default="", type=str, nargs="?",
                        help="checkpoint file path to load pretrained params. (default=\"\")")

    parser.add_argument("--resume", default="", type=str, nargs="?",
                        help="checkpoint file path to resume training. (default=\"\")")

    parser.add_argument("--verbose", type=int, default=1,
                        help="logging level. higher is more logging. (default=1)")

    parser.add_argument("--rank", "--local_rank", default=0, type=int,
                        help="rank for distributed training. no need to explictly specify.")

    # parse all the input arguments 
    args = parser.parse_args()
    args.distributed = False

    if not torch.cuda.is_available(): #if gpu is not available 
        device = torch.device("cpu") #train on cpu
    else: #GPU
        device = torch.device("cuda")#train on gpu
        torch.backends.cudnn.benchmark = True # effective when using fixed size inputs (no conditional layers or layers inside loops),benchmark mode in cudnn,faster runtime
        torch.cuda.set_device(args.rank) # sets the default GPU for distributed training
        if "WORLD_SIZE" in os.environ:#determine max number of parallel processes (distributed)
            args.world_size = int(os.environ["WORLD_SIZE"]) #get the world size from the os
            args.distributed = args.world_size > 1 #set distributed if woldsize > 1 
        if args.distributed: 
            torch.distributed.init_process_group(backend="nccl", init_method="env://") #Use the NCCL backend for distributed GPU training (Rule of thumb)
                #NCCL:since it currently provides the best distributed GPU training performance, especially for multiprocess single-node or multi-node distributed training

    # suppress logging for distributed training
    if args.rank != 0: #if process is not p0
        sys.stdout = open(os.devnull, "w")#DEVNULL is Special value that can be used as the stdin, stdout or stderr argument to

    # set logger
    if args.verbose > 1: #if level of logging is heigher then 1
        logging.basicConfig( #configure the logging
            level=logging.DEBUG, stream=sys.stdout, #heigh logging level,detailed information, typically of interest only when diagnosing problems.
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") #format includes Time,module,line#,level,and message.
    elif args.verbose > 0:#if level of logging is between 0,1
        logging.basicConfig(#configure the logging
            level=logging.INFO, stream=sys.stdout,#moderate logging level,Confirmation that things are working as expected.
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")#format includes Time,module,line#,level,and message.
    else:#if level of logging is 0
        logging.basicConfig(#configure the logging
            level=logging.WARN, stream=sys.stdout,#low logging level,An indication that something unexpected happened, or indicative of some problem in the near future (e.g. ‘disk space low’).
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")#format includes Time,module,line#,level,and message.
        logging.warning("Skip DEBUG/INFO messages")#tell the user that he will skip logging DEBUG/INFO messages by choosing this level.

    # check directory existence
    if not os.path.exists(args.outdir): #directory to save checkpoints
        os.makedirs(args.outdir)

    # check arguments
    if (args.train_feats_scp is not None and args.train_dumpdir is not None) or \ 
            (args.train_feats_scp is None and args.train_dumpdir is None):
            # if the user chooses both training data files (examples) or
            # the user doesnot choose any training data file
        raise ValueError("Please specify either --train-dumpdir or --train-*-scp.") #raise an error to tell the user to choose one training file
    if (args.dev_feats_scp is not None and args.dev_dumpdir is not None) or \
            (args.dev_feats_scp is None and args.dev_dumpdir is None):
            # if the user chooses both validatation data files (examples) or
            # the user doesnot choose any validatation data file
        raise ValueError("Please specify either --dev-dumpdir or --dev-*-scp.") #raise an error to tell the user to choose one validation data file

    # load config
    with open(args.config) as f:#open configuration file (yaml format)
        config = yaml.load(f, Loader=yaml.Loader) #load configuration file (yaml format to python object)
    # update config
    config.update(vars(args))#update arguments in configuration file
    config["version"] = parallel_wavegan.__version__  # add parallel wavegan version info
    # save config
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:#open outdir/config.yml
        yaml.dump(config, f, Dumper=yaml.Dumper) #dump function accepts a Python object and produces a YAML document.
    # add config info to the high level logger.
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:#if configuration tells to remove short samples from training.
        mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \
            2 * config["generator_params"].get("aux_context_window", 0)#th of length = floor(batch_max_steps/hop_size) + 2 * (generator_params.aux_context_window)
    else:
        mel_length_threshold = None # No th.
    if args.train_wav_scp is None or args.dev_wav_scp is None: #if at least one of training or evaluating datasets = None
        if config["format"] == "hdf5":# format of data = hdf5
            audio_query, mel_query = "*.h5", "*.h5" # audio and text queries = "...".h5
            #lambda example:
            #x = lambda a, b: a * b
            #x(5, 6)-->x(a=5,b=6)=a*b=5*6=30
            audio_load_fn = lambda x: read_hdf5(x, "wave")  # The function to load data,NOQA
            mel_load_fn = lambda x: read_hdf5(x, "feats")  # The function to load data,NOQA
        elif config["format"] == "npy":# format of data = npy
            audio_query, mel_query = "*-wave.npy", "*-feats.npy" #audio query = "..."-wave.npy and text query = "..."-feats.h5
            audio_load_fn = np.load#The function to load data.
            mel_load_fn = np.load#The function to load data.
        else:#if any other data format
            raise ValueError("support only hdf5 or npy format.") #raise error to tell the user the data format is not supported.

    if args.train_dumpdir is not None: # if training ds is not None
        train_dataset = AudioMelDataset( # define the training dataset
            root_dir=args.train_dumpdir,#the directory of ds.
            audio_query=audio_query,#audio query according to format above.
            mel_query=mel_query,#mel query according to format above.
            audio_load_fn=audio_load_fn,#load the function that loads the audio data according to format above.
            mel_load_fn=mel_load_fn,#load the function that loads the mel data according to format above.
            mel_length_threshold=mel_length_threshold,#th to remove short samples -calculated above-.
            allow_cache=config.get("allow_cache", False),  # keep compatibility.
        )
    else:# if training ds is None
        train_dataset = AudioMelSCPDataset(# define the training dataset
            wav_scp=args.train_wav_scp,
            feats_scp=args.train_feats_scp,
            segments=args.train_segments, #segments of dataset
            mel_length_threshold=mel_length_threshold,#th to remove short samples -calculated above-.
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    logging.info(f"The number of training files = {len(train_dataset)}.") # add length of trainning data set to the logger.
    if args.dev_dumpdir is not None: #if evaluating ds is not None
        dev_dataset = AudioMelDataset( # define the evaluating dataset
            root_dir=args.dev_dumpdir,#the directory of ds.
            audio_query=audio_query,#audio query according to format above.
            mel_query=mel_query,#mel query according to format above.
            audio_load_fn=audio_load_fn,#load the function that loads the audio data according to format above.
            mel_load_fn=mel_load_fn,#load the function that loads the mel data according to format above.
            mel_length_threshold=mel_length_threshold,#th to remove short samples -calculated above-.
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    else:# if evaluating ds is None
        dev_dataset = AudioMelSCPDataset(
            wav_scp=args.dev_wav_scp,
            feats_scp=args.dev_feats_scp,
            segments=args.dev_segments,#segments of dataset
            mel_length_threshold=mel_length_threshold,#th to remove short samples -calculated above-.
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    logging.info(f"The number of development files = {len(dev_dataset)}.") # add length of evaluating data set to the logger.
    dataset = {
        "train": train_dataset,
        "dev": dev_dataset,
    } #define the whole dataset used which is divided into training and evaluating datasets
    # get data loader
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        # keep compatibility
        aux_context_window=config["generator_params"].get("aux_context_window", 0),
        # keep compatibility
        use_noise_input=config.get(
            "generator_type", "ParallelWaveGANGenerator") != "MelGANGenerator",
    )
    train_sampler, dev_sampler = None, None
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler
        train_sampler = DistributedSampler(
            dataset=dataset["train"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=True,
        )
        dev_sampler = DistributedSampler(
            dataset=dataset["dev"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=False,
        )
    data_loader = {
        "train": DataLoader(
            dataset=dataset["train"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=train_sampler,
            pin_memory=config["pin_memory"],
        ),
        "dev": DataLoader(
            dataset=dataset["dev"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=dev_sampler,
            pin_memory=config["pin_memory"],
        ),
    }

    # define models and optimizers
    generator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("generator_type", "ParallelWaveGANGenerator"),
    )
    discriminator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("discriminator_type", "ParallelWaveGANDiscriminator"),
    )
    model = {
        "generator": generator_class(
            **config["generator_params"]).to(device),
        "discriminator": discriminator_class(
            **config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft": MultiResolutionSTFTLoss(
            **config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
    if config.get("use_feat_match_loss", False):  # keep compatibility
        criterion["l1"] = torch.nn.L1Loss().to(device)
    optimizer = {
        "generator": RAdam(
            model["generator"].parameters(),
            **config["generator_optimizer_params"]),
        "discriminator": RAdam(
            model["discriminator"].parameters(),
            **config["discriminator_optimizer_params"]),
    }
    scheduler = {
        "generator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"]),
        "discriminator": torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"]),
    }
    if args.distributed:
        # wrap model for distributed training
        try:
            from apex.parallel import DistributedDataParallel
        except ImportError:
            raise ImportError("apex is not installed. please check https://github.com/NVIDIA/apex.")
        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(model["discriminator"])
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # load pretrained parameters from checkpoint
    if len(args.pretrain) != 0:
        trainer.load_checkpoint(args.pretrain, load_only_params=True)
        logging.info(f"Successfully load parameters from {args.pretrain}.")

    # resume from checkpoint
    if len(args.resume) != 0:
        trainer.load_checkpoint(args.resume)
        logging.info(f"Successfully resumed from {args.resume}.")

    # run training loop
    try:
        trainer.run()
    except KeyboardInterrupt:
        trainer.save_checkpoint(
            os.path.join(config["outdir"], f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
예제 #5
0
def main():
    """Normalize dumped features with pre-computed statistics.

    Command-line entry point. Exactly one of ``--rootdir`` or ``--feats-scp``
    selects the input. Each feature matrix is standardized with the mean and
    scale restored from ``--stats`` and saved as float32 in the layout given
    by ``config["format"]`` ("hdf5" or "npy").

    With ``--rootdir`` and ``--skip-wav-copy``, ``--ftype`` picks which
    features are normalized: ``"mel"`` (``*.mel.npy`` in npy format) or
    ``"spec"`` (``*.spec.npy``, npy format only). If ``--dumpdir`` is the
    empty string and the format is npy, each result is written in place as
    ``<feature file without .npy>-norm.npy``; this needs ``--skip-wav-copy``
    and a dataset that yields the feature file path.

    Raises:
        ValueError: for conflicting or missing input arguments, an unsupported
            ``config["format"]`` or ``--ftype``, or an unsupported option
            combination.
    """
    parser = argparse.ArgumentParser(
        description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py).")
    parser.add_argument("--rootdir", default=None, type=str,
                        help="directory including feature files to be normalized. "
                             "you need to specify either *-scp or rootdir.")
    parser.add_argument("--wav-scp", default=None, type=str,
                        help="kaldi-style wav.scp file. "
                             "you need to specify either *-scp or rootdir.")
    parser.add_argument("--feats-scp", default=None, type=str,
                        help="kaldi-style feats.scp file. "
                             "you need to specify either *-scp or rootdir.")
    parser.add_argument("--segments", default=None, type=str,
                        help="kaldi-style segments file.")
    parser.add_argument("--dumpdir", type=str, required=True,
                        help="directory to dump normalized feature files.")
    parser.add_argument("--stats", type=str, required=True,
                        help="statistics file.")
    parser.add_argument("--skip-wav-copy", default=False, action="store_true",
                        help="whether to skip the copy of wav files.")
    parser.add_argument("--config", type=str, required=True,
                        help="yaml format configuration file.")
    parser.add_argument("--ftype", default='mel', type=str,
                        help="feature type")
    parser.add_argument("--verbose", type=int, default=1,
                        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG, format=log_format)
    elif args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=log_format)
    else:
        logging.basicConfig(level=logging.WARN, format=log_format)
        logging.warning('Skip DEBUG/INFO messages')

    # load config
    # NOTE: yaml.Loader can build arbitrary Python objects; only use trusted config files.
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    # check arguments
    if (args.feats_scp is not None and args.rootdir is not None) or \
            (args.feats_scp is None and args.rootdir is None):
        raise ValueError("Please specify either --rootdir or --feats-scp.")
    fmt = config["format"]
    if fmt not in ("hdf5", "npy"):
        raise ValueError("support only hdf5 or npy format.")

    # An empty --dumpdir means "write next to the source feature file" (npy only);
    # that needs the feature file path, which only the skip-wav-copy datasets provide.
    in_place = args.dumpdir == ""
    if in_place:
        if fmt == "npy" and not args.skip_wav_copy:
            raise ValueError(
                "Please include --skip-wav-copy in arguments when --dumpdir is empty.")
    else:
        os.makedirs(args.dumpdir, exist_ok=True)

    # get dataset (build only the one that is actually normalized)
    if args.rootdir is not None:
        if args.skip_wav_copy:
            if args.ftype == "mel":
                if fmt == "hdf5":
                    mel_query = "*.h5"
                    mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
                else:
                    mel_query, mel_load_fn = "*.mel.npy", np.load
                dataset = MelDatasetNew(
                    root_dir=args.rootdir,
                    mel_query=mel_query,
                    mel_load_fn=mel_load_fn,
                    return_utt_id=True,
                )
            elif args.ftype == "spec":
                if fmt != "npy":
                    raise ValueError("--ftype spec is supported only for npy format.")
                dataset = SpcDatasetNew(
                    root_dir=args.rootdir,
                    spc_query="*.spec.npy",
                    spc_load_fn=np.load,
                    return_utt_id=True,
                )
            else:
                raise ValueError(f"--ftype must be 'mel' or 'spec', got '{args.ftype}'.")
        else:
            if fmt == "hdf5":
                audio_query, mel_query = "*.h5", "*.h5"
                audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
                mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
            else:
                audio_query, mel_query = "*.wav.npy", "*.mel.npy"
                audio_load_fn = np.load
                mel_load_fn = np.load
            dataset = AudioMelDataset(
                root_dir=args.rootdir,
                audio_query=audio_query,
                mel_query=mel_query,
                audio_load_fn=audio_load_fn,
                mel_load_fn=mel_load_fn,
                return_utt_id=True,
            )
    else:
        # --ftype is not used here: the features are whatever feats.scp lists.
        if not args.skip_wav_copy:
            dataset = AudioMelSCPDataset(
                wav_scp=args.wav_scp,
                feats_scp=args.feats_scp,
                segments=args.segments,
                return_utt_id=True,
            )
        else:
            dataset = MelSCPDataset(
                feats_scp=args.feats_scp,
                return_utt_id=True,
            )
    logging.info(f"The number of files = {len(dataset)}.")

    # restore scaler
    scaler = StandardScaler()
    if fmt == "hdf5":
        scaler.mean_ = read_hdf5(args.stats, "mean")
        scaler.scale_ = read_hdf5(args.stats, "scale")
    else:
        stats = np.load(args.stats)  # row 0: mean, row 1: scale
        scaler.mean_ = stats[0]
        scaler.scale_ = stats[1]
    # from version 0.23.0, this information is needed
    scaler.n_features_in_ = scaler.mean_.shape[0]

    # process each file
    for items in tqdm(dataset):
        audio, feat_file = None, None
        if args.skip_wav_copy:
            # MelDatasetNew / SpcDatasetNew yield (utt_id, feat, feat_file);
            # MelSCPDataset presumably yields only (utt_id, feat) -- TODO confirm.
            utt_id, feat, *rest = items
            if rest:
                feat_file = rest[0]
        else:
            utt_id, audio, feat = items

        # normalize; identical to (feat - scaler.mean_) / scaler.scale_
        feat = scaler.transform(feat).astype(np.float32)

        # save
        if fmt == "hdf5":
            h5_path = os.path.join(args.dumpdir, f"{utt_id}.h5")
            write_hdf5(h5_path, "feats", feat)
            if not args.skip_wav_copy:
                write_hdf5(h5_path, "wave", audio.astype(np.float32))
        elif in_place:
            if feat_file is None:
                raise ValueError(
                    "Normalizing in place (empty --dumpdir) requires a dataset "
                    "that yields the feature file path.")
            # strip only the trailing extension, not every ".npy" in the path
            stem = feat_file[:-len(".npy")] if feat_file.endswith(".npy") else feat_file
            np.save(stem + "-norm.npy", feat, allow_pickle=False)
        else:
            np.save(os.path.join(args.dumpdir, f"{utt_id}.npy"),
                    feat, allow_pickle=False)
            if not args.skip_wav_copy:
                np.save(os.path.join(args.dumpdir, f"{utt_id}.wav.npy"),
                        audio.astype(np.float32), allow_pickle=False)
예제 #6
0
def main():
    """Normalize dumped features with a pre-computed mean/scale.

    Command-line entry point. Exactly one of ``--rootdir`` or ``--feats-scp``
    selects the input. Each feature matrix is standardized with
    ``StandardScaler`` statistics restored from ``--stats`` and written as
    float32 into ``--dumpdir``. The waveform is written too unless
    ``--skip-wav-copy`` is given. The on-disk layout (hdf5 or npy) follows
    ``config["format"]``.

    Raises:
        ValueError: if both or neither of ``--rootdir`` / ``--feats-scp`` are
            given, or if ``config["format"]`` is neither "hdf5" nor "npy".
    """
    either_hint = "you need to specify either *-scp or rootdir."
    parser = argparse.ArgumentParser(
        description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
    )
    # The three mutually-alternative input options share the same help suffix.
    for flag, lead in (
        ("--rootdir", "directory including feature files to be normalized. "),
        ("--wav-scp", "kaldi-style wav.scp file. "),
        ("--feats-scp", "kaldi-style feats.scp file. "),
    ):
        parser.add_argument(flag, default=None, type=str, help=lead + either_hint)
    parser.add_argument("--segments", default=None, type=str,
                        help="kaldi-style segments file.")
    parser.add_argument("--dumpdir", type=str, required=True,
                        help="directory to dump normalized feature files.")
    parser.add_argument("--stats", type=str, required=True,
                        help="statistics file.")
    parser.add_argument("--skip-wav-copy", default=False, action="store_true",
                        help="whether to skip the copy of wav files.")
    parser.add_argument("--config", type=str, required=True,
                        help="yaml format configuration file.")
    parser.add_argument("--verbose", type=int, default=1,
                        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # Map verbosity to a logging level, then configure logging once.
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if args.verbose <= 0:
        logging.warning("Skip DEBUG/INFO messages")

    # Read the yaml config and overlay the command-line values on top of it.
    with open(args.config) as stream:
        config = yaml.load(stream, Loader=yaml.Loader)
    config.update(vars(args))

    # Exactly one input source is allowed: both-set and both-unset are errors.
    if (args.feats_scp is None) == (args.rootdir is None):
        raise ValueError("Please specify either --rootdir or --feats-scp.")

    if not os.path.exists(args.dumpdir):
        os.makedirs(args.dumpdir)

    # Pick the dataset that matches the input source and copy mode.
    if args.rootdir is None:
        if args.skip_wav_copy:
            dataset = MelSCPDataset(feats_scp=args.feats_scp, return_utt_id=True)
        else:
            dataset = AudioMelSCPDataset(
                wav_scp=args.wav_scp,
                feats_scp=args.feats_scp,
                segments=args.segments,
                return_utt_id=True,
            )
    else:
        layout = config["format"]
        if layout == "hdf5":
            audio_query = mel_query = "*.h5"

            def audio_load_fn(path):
                return read_hdf5(path, "wave")

            def mel_load_fn(path):
                return read_hdf5(path, "feats")

        elif layout == "npy":
            audio_query, mel_query = "*-wave.npy", "*-feats.npy"
            audio_load_fn = mel_load_fn = np.load
        else:
            raise ValueError("support only hdf5 or npy format.")

        if args.skip_wav_copy:
            dataset = MelDataset(
                root_dir=args.rootdir,
                mel_query=mel_query,
                mel_load_fn=mel_load_fn,
                return_utt_id=True,
            )
        else:
            dataset = AudioMelDataset(
                root_dir=args.rootdir,
                audio_query=audio_query,
                mel_query=mel_query,
                audio_load_fn=audio_load_fn,
                mel_load_fn=mel_load_fn,
                return_utt_id=True,
            )
    logging.info(f"The number of files = {len(dataset)}.")

    # Rebuild a fitted StandardScaler from the stored statistics.
    scaler = StandardScaler()
    stats_layout = config["format"]
    if stats_layout == "hdf5":
        scaler.mean_ = read_hdf5(args.stats, "mean")
        scaler.scale_ = read_hdf5(args.stats, "scale")
    elif stats_layout == "npy":
        mean_and_scale = np.load(args.stats)
        scaler.mean_, scaler.scale_ = mean_and_scale[0], mean_and_scale[1]
    else:
        raise ValueError("support only hdf5 or npy format.")
    # from version 0.23.0, this information is needed
    scaler.n_features_in_ = scaler.mean_.shape[0]

    out_layout = config["format"]
    for items in tqdm(dataset):
        if args.skip_wav_copy:
            utt_id, mel = items
        else:
            utt_id, audio, mel = items

        normalized = scaler.transform(mel).astype(np.float32)

        if out_layout == "hdf5":
            h5_path = os.path.join(args.dumpdir, f"{utt_id}.h5")
            write_hdf5(h5_path, "feats", normalized)
            if not args.skip_wav_copy:
                write_hdf5(h5_path, "wave", audio.astype(np.float32))
        elif out_layout == "npy":
            np.save(
                os.path.join(args.dumpdir, f"{utt_id}-feats.npy"),
                normalized,
                allow_pickle=False,
            )
            if not args.skip_wav_copy:
                np.save(
                    os.path.join(args.dumpdir, f"{utt_id}-wave.npy"),
                    audio.astype(np.float32),
                    allow_pickle=False,
                )
        else:
            raise ValueError("support only hdf5 or npy format.")
예제 #7
0
def main():
    """Run training process."""
    parser = argparse.ArgumentParser(description=(
        "Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py)."
    ))
    parser.add_argument(
        "--train-wav-scp",
        default=None,
        type=str,
        help=("kaldi-style wav.scp file for training. "
              "you need to specify either train-*-scp or train-dumpdir."),
    )
    parser.add_argument(
        "--train-feats-scp",
        default=None,
        type=str,
        help=("kaldi-style feats.scp file for training. "
              "you need to specify either train-*-scp or train-dumpdir."),
    )
    parser.add_argument(
        "--train-segments",
        default=None,
        type=str,
        help="kaldi-style segments file for training.",
    )
    parser.add_argument(
        "--train-dumpdir",
        default=None,
        type=str,
        help=("directory including training data. "
              "you need to specify either train-*-scp or train-dumpdir."),
    )
    parser.add_argument(
        "--dev-wav-scp",
        default=None,
        type=str,
        help=("kaldi-style wav.scp file for validation. "
              "you need to specify either dev-*-scp or dev-dumpdir."),
    )
    parser.add_argument(
        "--dev-feats-scp",
        default=None,
        type=str,
        help=("kaldi-style feats.scp file for vaidation. "
              "you need to specify either dev-*-scp or dev-dumpdir."),
    )
    parser.add_argument(
        "--dev-segments",
        default=None,
        type=str,
        help="kaldi-style segments file for validation.",
    )
    parser.add_argument(
        "--dev-dumpdir",
        default=None,
        type=str,
        help=("directory including development data. "
              "you need to specify either dev-*-scp or dev-dumpdir."),
    )
    parser.add_argument(
        "--outdir",
        type=str,
        required=True,
        help="directory to save checkpoints.",
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="yaml format configuration file.",
    )
    parser.add_argument(
        "--pretrain",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to load pretrained params. (default="")',
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--rank",
        "--local_rank",
        default=0,
        type=int,
        help="rank for distributed training. no need to explictly specify.",
    )
    args = parser.parse_args()

    args.distributed = False
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        # effective when using fixed size inputs
        # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        torch.backends.cudnn.benchmark = True
        torch.cuda.set_device(args.rank)
        # setup for distributed training
        # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed
        if "WORLD_SIZE" in os.environ:
            args.world_size = int(os.environ["WORLD_SIZE"])
            args.distributed = args.world_size > 1
        if args.distributed:
            torch.distributed.init_process_group(backend="nccl",
                                                 init_method="env://")

    # suppress logging for distributed training
    if args.rank != 0:
        sys.stdout = open(os.devnull, "w")

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # check arguments
    if (args.train_feats_scp is not None and args.train_dumpdir
            is not None) or (args.train_feats_scp is None
                             and args.train_dumpdir is None):
        raise ValueError(
            "Please specify either --train-dumpdir or --train-*-scp.")
    if (args.dev_feats_scp is not None and args.dev_dumpdir is not None) or (
            args.dev_feats_scp is None and args.dev_dumpdir is None):
        raise ValueError("Please specify either --dev-dumpdir or --dev-*-scp.")

    # load and save config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = parallel_wavegan.__version__  # add version info
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["batch_max_steps"] // config[
            "hop_size"] + 2 * config["generator_params"].get(
                "aux_context_window", 0)
    else:
        mel_length_threshold = None
    if args.train_wav_scp is None or args.dev_wav_scp is None:
        if config["format"] == "hdf5":
            audio_query, mel_query = "*.h5", "*.h5"
            audio_load_fn = lambda x: read_hdf5(x, "wave")  # NOQA
            mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
        elif config["format"] == "npy":
            audio_query, mel_query = "*-wave.npy", "*-feats.npy"
            audio_load_fn = np.load
            mel_load_fn = np.load
        else:
            raise ValueError("support only hdf5 or npy format.")
    if args.train_dumpdir is not None:
        train_dataset = AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    else:
        train_dataset = AudioMelSCPDataset(
            wav_scp=args.train_wav_scp,
            feats_scp=args.train_feats_scp,
            segments=args.train_segments,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    logging.info(f"The number of training files = {len(train_dataset)}.")
    if args.dev_dumpdir is not None:
        dev_dataset = AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    else:
        dev_dataset = AudioMelSCPDataset(
            wav_scp=args.dev_wav_scp,
            feats_scp=args.dev_feats_scp,
            segments=args.dev_segments,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    logging.info(f"The number of development files = {len(dev_dataset)}.")
    dataset = {
        "train": train_dataset,
        "dev": dev_dataset,
    }

    # get data loader
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        # keep compatibility
        aux_context_window=config["generator_params"].get(
            "aux_context_window", 0),
        # keep compatibility
        use_noise_input=config.get("generator_type",
                                   "ParallelWaveGANGenerator")
        in ["ParallelWaveGANGenerator"],
    )
    sampler = {"train": None, "dev": None}
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler

        sampler["train"] = DistributedSampler(
            dataset=dataset["train"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=True,
        )
        sampler["dev"] = DistributedSampler(
            dataset=dataset["dev"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=False,
        )
    data_loader = {
        "train":
        DataLoader(
            dataset=dataset["train"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=sampler["train"],
            pin_memory=config["pin_memory"],
        ),
        "dev":
        DataLoader(
            dataset=dataset["dev"],
            shuffle=False if args.distributed else True,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=sampler["dev"],
            pin_memory=config["pin_memory"],
        ),
    }

    # define models
    generator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("generator_type", "ParallelWaveGANGenerator"),
    )
    discriminator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("discriminator_type", "ParallelWaveGANDiscriminator"),
    )
    model = {
        "generator":
        generator_class(**config["generator_params"], ).to(device),
        "discriminator":
        discriminator_class(**config["discriminator_params"], ).to(device),
    }

    # define criterions
    criterion = {
        "gen_adv":
        GeneratorAdversarialLoss(
            # keep compatibility
            **config.get("generator_adv_loss_params", {})).to(device),
        "dis_adv":
        DiscriminatorAdversarialLoss(
            # keep compatibility
            **config.get("discriminator_adv_loss_params", {})).to(device),
    }
    if config.get("use_stft_loss", True):  # keep compatibility
        config["use_stft_loss"] = True
        criterion["stft"] = MultiResolutionSTFTLoss(
            **config["stft_loss_params"], ).to(device)
    if config.get("use_subband_stft_loss", False):  # keep compatibility
        assert config["generator_params"]["out_channels"] > 1
        criterion["sub_stft"] = MultiResolutionSTFTLoss(
            **config["subband_stft_loss_params"], ).to(device)
    else:
        config["use_subband_stft_loss"] = False
    if config.get("use_feat_match_loss", False):  # keep compatibility
        criterion["feat_match"] = FeatureMatchLoss(
            # keep compatibility
            **config.get("feat_match_loss_params", {}), ).to(device)
    else:
        config["use_feat_match_loss"] = False
    if config.get("use_mel_loss", False):  # keep compatibility
        if config.get("mel_loss_params", None) is None:
            criterion["mel"] = MelSpectrogramLoss(
                fs=config["sampling_rate"],
                fft_size=config["fft_size"],
                hop_size=config["hop_size"],
                win_length=config["win_length"],
                window=config["window"],
                num_mels=config["num_mels"],
                fmin=config["fmin"],
                fmax=config["fmax"],
            ).to(device)
        else:
            criterion["mel"] = MelSpectrogramLoss(**config["mel_loss_params"],
                                                  ).to(device)
    else:
        config["use_mel_loss"] = False

    # define special module for subband processing
    if config["generator_params"]["out_channels"] > 1:
        criterion["pqmf"] = PQMF(
            subbands=config["generator_params"]["out_channels"],
            # keep compatibility
            **config.get("pqmf_params", {}),
        ).to(device)

    # define optimizers and schedulers
    generator_optimizer_class = getattr(
        parallel_wavegan.optimizers,
        # keep compatibility
        config.get("generator_optimizer_type", "RAdam"),
    )
    discriminator_optimizer_class = getattr(
        parallel_wavegan.optimizers,
        # keep compatibility
        config.get("discriminator_optimizer_type", "RAdam"),
    )
    optimizer = {
        "generator":
        generator_optimizer_class(
            model["generator"].parameters(),
            **config["generator_optimizer_params"],
        ),
        "discriminator":
        discriminator_optimizer_class(
            model["discriminator"].parameters(),
            **config["discriminator_optimizer_params"],
        ),
    }
    generator_scheduler_class = getattr(
        torch.optim.lr_scheduler,
        # keep compatibility
        config.get("generator_scheduler_type", "StepLR"),
    )
    discriminator_scheduler_class = getattr(
        torch.optim.lr_scheduler,
        # keep compatibility
        config.get("discriminator_scheduler_type", "StepLR"),
    )
    scheduler = {
        "generator":
        generator_scheduler_class(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"],
        ),
        "discriminator":
        discriminator_scheduler_class(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"],
        ),
    }
    if args.distributed:
        # wrap model for distributed training
        try:
            from apex.parallel import DistributedDataParallel
        except ImportError:
            raise ImportError(
                "apex is not installed. please check https://github.com/NVIDIA/apex."
            )
        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(
            model["discriminator"])

    # show settings
    logging.info(model["generator"])
    logging.info(model["discriminator"])
    logging.info(optimizer["generator"])
    logging.info(optimizer["discriminator"])
    logging.info(scheduler["generator"])
    logging.info(scheduler["discriminator"])
    for criterion_ in criterion.values():
        logging.info(criterion_)

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        sampler=sampler,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # load pretrained parameters from checkpoint
    if len(args.pretrain) != 0:
        trainer.load_checkpoint(args.pretrain, load_only_params=True)
        logging.info(f"Successfully load parameters from {args.pretrain}.")

    # resume from checkpoint
    if len(args.resume) != 0:
        trainer.load_checkpoint(args.resume)
        logging.info(f"Successfully resumed from {args.resume}.")

    # run training loop
    try:
        trainer.run()
    finally:
        trainer.save_checkpoint(
            os.path.join(config["outdir"],
                         f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")