Example #1
def world_speech_synthesis(queue, wav_list, config):
    """WORLD speech synthesis

    Args:
        queue (multiprocessing.Queue): Queue used to signal completion ('Finish' is put when done)
        wav_list (list): list of the wav files
        config (dict): feature extraction config

    """
    # define synthesizer
    synthesizer = Synthesizer(fs=config['sampling_rate'],
                              fftl=config['fft_size'],
                              shiftms=config['shiftms'])
    # synthesis
    for i, wav_name in enumerate(wav_list):
        logging.info("now processing %s (%d/%d)" %
                     (wav_name, i + 1, len(wav_list)))

        # load acoustic features
        feat_name = path_replace(wav_name,
                                 config['indir'],
                                 config['outdir'],
                                 extname=config['feature_format'])
        if check_hdf5(feat_name, "/world"):
            h = read_hdf5(feat_name, "/world")
        else:
            logging.error("%s is not existed." % (feat_name))
            sys.exit(1)
        if check_hdf5(feat_name, "/f0"):
            f0 = read_hdf5(feat_name, "/f0")
        else:
            uv = h[:, config['uv_dim_idx']].copy(order='C')
            f0 = h[:, config['f0_dim_idx']].copy(order='C')  # cont_f0_lpf
            fz_idx = np.where(uv == 0.0)
            f0[fz_idx] = 0.0
        if check_hdf5(feat_name, "/ap"):
            ap = read_hdf5(feat_name, "/ap")
        else:
            codeap = h[:, config['ap_dim_start']:config['ap_dim_end']].copy(
                order='C')
            ap = pyworld.decode_aperiodicity(codeap, config['sampling_rate'],
                                             config['fft_size'])
        mcep = h[:, config['mcep_dim_start']:config['mcep_dim_end']].copy(
            order='C')

        # waveform synthesis
        wav = synthesizer.synthesis(f0, mcep, ap, alpha=config['mcep_alpha'])
        wav = np.int16(np.clip(wav, -32768, 32767))  # clip before casting to avoid int16 overflow

        # save restored wav
        restored_name = path_replace(wav_name, "wav", "world", extname="wav")
        wavfile.write(restored_name, config['sampling_rate'], wav)

    queue.put('Finish')
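
A minimal driver sketch for running this worker over several processes. The chunking helper below is illustrative, not the repository's actual runner, and it assumes the imports the excerpt relies on (logging, sys, numpy as np, pyworld, scipy.io.wavfile, plus the package helpers Synthesizer, path_replace, check_hdf5, read_hdf5) are already in scope:

# hypothetical driver: split the wav list across processes (illustrative only)
from multiprocessing import Process, Queue

def run_parallel(wav_list, config, n_jobs=4):
    queue = Queue()
    chunks = [wav_list[i::n_jobs] for i in range(n_jobs)]   # interleaved, roughly equal chunks
    processes = [Process(target=world_speech_synthesis, args=(queue, chunk, config))
                 for chunk in chunks if chunk]
    for p in processes:
        p.start()
    for _ in processes:
        queue.get()       # each worker puts 'Finish' when done
    for p in processes:
        p.join()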
Example #2
def calc_stats(file_list, config, shift=1):
    """Calcute statistics

    Args:
        file_list (list): File list.
        config (dict): Dictionary of config.
        shift (int): Shift of the feature dimension (number of leading dimensions excluded from the statistics).

    """
    scaler = StandardScaler()

    # accumulate statistics over all files
    for i, filename in enumerate(file_list):
        logging.info("now processing %s (%d/%d)" % (filename, i + 1, len(file_list)))
        feat = read_hdf5(filename, "/%s" % config['feat_type'])
        scaler.partial_fit(feat[:, shift:])

    dump(scaler, config['stats'])
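
A minimal usage sketch, assuming `read_txt`, `read_hdf5`, `dump`, and `load` are the package helpers used elsewhere in these examples (`dump`/`load` behaving like `joblib`); the list file and stats path are placeholders:

# illustrative call; file names are placeholders
file_list = read_txt("data/train_feat.list")            # one feature file per line
config = {"feat_type": "world", "stats": "data/stats.pkl"}
calc_stats(file_list, config, shift=1)                  # exclude dim 0 (e.g. U/V) from the stats

# the fitted StandardScaler can later be restored with the matching loader,
# as done in the dataset example below: scaler = load(config["stats"])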
Example #3
def main():
    """Run training process."""
    parser = argparse.ArgumentParser(
        description="Train Quasi-Periodic Parallel WaveGAN (See detail in qppwg/bin/train.py).")
    parser.add_argument("--train_audio", required=True, type=str,
                        help="list of training wav files")
    parser.add_argument("--train_feat", required=True, type=str,
                        help="list of training feat files")
    parser.add_argument("--valid_audio", required=True, type=str,
                        help="list of validation wav files")
    parser.add_argument("--valid_feat", required=True, type=str,
                        help="list of validation feat files")
    parser.add_argument("--stats", required=True, type=str, 
                        help="hdf5 file including statistics")
    parser.add_argument("--outdir", required=True, type=str,
                        help="directory to save checkpoints.")
    parser.add_argument("--config", required=True, type=str,
                        help="yaml format configuration file.")
    parser.add_argument("--pretrain", default="", type=str, nargs="?",
                        help="checkpoint file path to load pretrained params. (default=\"\")")
    parser.add_argument("--resume", default="", type=str, nargs="?",
                        help="checkpoint file path to resume training. (default=\"\")")
    parser.add_argument("--verbose", default=1, type=int,
                        help="logging level. higher is more logging. (default=1)")
    parser.add_argument("--rank", "--local_rank", default=0, type=int,
                        help="rank for distributed training. no need to explictly specify.")
    parser.add_argument("--seed", default=1, type=int, 
                        help="seed number")
    args = parser.parse_args()

    args.distributed = False
    if not torch.cuda.is_available():
        print("CPU")
        device = torch.device("cpu")
    else:
        print("GPU")
        device = torch.device("cuda")
        # effective when using fixed size inputs
        # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        torch.backends.cudnn.benchmark = True
        torch.cuda.set_device(args.rank)
        # setup for distributed training
        # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed
        if "WORLD_SIZE" in os.environ:
            args.world_size = int(os.environ["WORLD_SIZE"])
            args.distributed = args.world_size > 1
        if args.distributed:
            torch.distributed.init_process_group(backend="nccl", init_method="env://")

    # fix seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    os.environ['PYTHONHASHSEED'] = str(args.seed)

    # suppress logging for distributed training
    if args.rank != 0:
        sys.stdout = open(os.devnull, "w")

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG, stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO, stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    else:
        logging.basicConfig(
            level=logging.WARN, stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # load and save config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = qppwg.__version__  # add version info
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        feat_length_threshold = config["batch_max_steps"] // config["hop_size"] + \
            2 * config["generator_params"].get("aux_context_window", 0)
    else:
        feat_length_threshold = None

    audio_load_fn = sf.read
    feat_load_fn = lambda x: read_hdf5(x, config.get("feat_type", "world"))

    train_dataset = AudioFeatDataset(
        stats=args.stats,
        audio_list=args.train_audio,
        audio_load_fn=audio_load_fn,
        feat_list=args.train_feat,
        feat_load_fn=feat_load_fn,
        feat_length_threshold=feat_length_threshold,
        allow_cache=config.get("allow_cache", False),
        hop_size=config["hop_size"],
        dense_factor=config.get("dense_factor", 4),
        f0_threshold=config.get("f0_threshold", 0),
        f0_cont=config.get("f0_cont", True),
        f0_dim_idx=config.get("f0_dim_idx", 1),
        uv_dim_idx=config.get("uv_dim_idx", 0),
        mean_path=config.get("mean_path", "/world/mean"),
        scale_path=config.get("scale_path", "/world/scale"),
        shift=config.get("stats_shift", 1),
    )
    logging.info(f"The number of training files = {len(train_dataset)}.")

    valid_dataset = AudioFeatDataset(
        stats=args.stats,
        audio_list=args.valid_audio,
        audio_load_fn=audio_load_fn,
        feat_list=args.valid_feat,
        feat_load_fn=feat_load_fn,
        feat_length_threshold=feat_length_threshold,
        allow_cache=config.get("allow_cache", False),
        hop_size=config["hop_size"],
        dense_factor=config.get("dense_factor", 4),
        f0_threshold=config.get("f0_threshold", 0),
        f0_cont=config.get("f0_cont", True),
        f0_dim_idx=config.get("f0_dim_idx", 0),
        uv_dim_idx=config.get("uv_dim_idx", 1),
        mean_path=config.get("mean_path", "/world/mean"),
        scale_path=config.get("scale_path", "/world/scale"),
        shift=config.get("stats_shift", 1),
    )
    logging.info(f"The number of validation files = {len(valid_dataset)}.")

    dataset = {
        "train": train_dataset,
        "valid": valid_dataset,
    }

    # get data loader
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        # keep compatibility
        aux_context_window=config["generator_params"].get("aux_context_window", 0),
        # keep compatibility
        input_type=config.get("input_type", "noise"),
    )
    train_sampler, valid_sampler = None, None
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler
        train_sampler = DistributedSampler(
            dataset=dataset["train"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=True,
        )
        valid_sampler = DistributedSampler(
            dataset=dataset["valid"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=False,
        )

    data_loader = {
        "train": DataLoader(
            dataset=dataset["train"],
            shuffle=not args.distributed,  # the DistributedSampler handles shuffling
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=train_sampler,
            pin_memory=config["pin_memory"],
        ),
        "valid": DataLoader(
            dataset=dataset["valid"],
            shuffle=not args.distributed,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=valid_sampler,
            pin_memory=config["pin_memory"],
        ),
    }

    # define models and optimizers
    generator_class = getattr(
        qppwg.models,
        # keep compatibility
        config.get("generator_type", "QPPWGGenerator"),
    )
    discriminator_class = getattr(
        qppwg.models,
        # keep compatibility
        config.get("discriminator_type", "QPPWGDiscriminator"),
    )
    model = {
        "generator": generator_class(
            **config["generator_params"]).to(device),
        "discriminator": discriminator_class(
            **config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft": MultiResolutionSTFTLoss(
            **config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
    if config.get("use_feat_match_loss", False):  # keep compatibility
        criterion["l1"] = torch.nn.L1Loss().to(device)
    generator_optimizer_class = getattr(
        qppwg.optimizers,
        # keep compatibility
        config.get("generator_optimizer_type", "RAdam"),
    )
    discriminator_optimizer_class = getattr(
        qppwg.optimizers,
        # keep compatibility
        config.get("discriminator_optimizer_type", "RAdam"),
    )
    optimizer = {
        "generator": generator_optimizer_class(
            model["generator"].parameters(),
            **config["generator_optimizer_params"],
        ),
        "discriminator": discriminator_optimizer_class(
            model["discriminator"].parameters(),
            **config["discriminator_optimizer_params"],
        ),
    }
    generator_scheduler_class = getattr(
        torch.optim.lr_scheduler,
        # keep compatibility
        config.get("generator_scheduler_type", "StepLR"),
    )
    discriminator_scheduler_class = getattr(
        torch.optim.lr_scheduler,
        # keep compatibility
        config.get("discriminator_scheduler_type", "StepLR"),
    )
    scheduler = {
        "generator": generator_scheduler_class(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"],
        ),
        "discriminator": discriminator_scheduler_class(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"],
        ),
    }
    if args.distributed:
        # wrap model for distributed training
        try:
            from apex.parallel import DistributedDataParallel
        except ImportError:
            raise ImportError("apex is not installed. please check https://github.com/NVIDIA/apex.")
        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(model["discriminator"])
    logging.debug(model["generator"])
    logging.debug(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # load pretrained/resume parameters from checkpoint
    if os.path.exists(args.resume):
        trainer.load_checkpoint(args.resume)
        logging.info(f"Successfully resumed from {args.resume}.")
    elif os.path.exists(args.pretrain):
        trainer.load_checkpoint(args.pretrain, load_only_params=True)
        logging.info(f"Successfully load parameters from {args.pretrain}.")
    else:
        logging.info("Start a new training process.")
        
    # run training loop
    try:
        trainer.run()
    except KeyboardInterrupt:
        trainer.save_checkpoint(
            os.path.join(config["outdir"], f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
Example #4
def main():
    """Run decoding process."""
    parser = argparse.ArgumentParser(
        description="Decode dumped features with trained Quasi-Periodic Parallel WaveGAN Generator "
                    "(See detail in qppwg/bin/decode.py).")
    parser.add_argument("--eval_feat", required=True, type=str,
                        help="list of evaluation aux feat files")
    parser.add_argument("--stats", required=True, type=str,
                        help="hdf5 file including statistics")
    parser.add_argument("--indir", required=True, type=str,
                        help="directory of input feature files")
    parser.add_argument("--outdir", type=str, required=True,
                        help="directory to output generated speech.")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="checkpoint file to be loaded.")
    parser.add_argument("--config", default=None, type=str,
                        help="yaml format configuration file. if not explicitly provided, "
                             "it will be searched in the checkpoint directory. (default=None)")
    parser.add_argument("--verbose", type=int, default=1,
                        help="logging level. higher is more logging. (default=1)")
    parser.add_argument("--seed", default=100, type=int,
                        help="seed number")
    parser.add_argument("--f0_factor", default=1.0, type=float,
                        help="f0 scaled factor")
    args = parser.parse_args()

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    else:
        logging.basicConfig(
            level=logging.WARN, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
        logging.warning("Skip DEBUG/INFO messages")

    # fix seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    os.environ['PYTHONHASHSEED'] = str(args.seed)

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # load config
    if args.config is None:
        dirname = os.path.dirname(args.checkpoint)
        args.config = os.path.join(dirname, "config.yml")
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    # get dataset
    feat_load_fn = lambda x: read_hdf5(x, config.get("feat_type", "world"))
    f0_factor = args.f0_factor
    dataset = FeatDataset(
        stats=args.stats,
        feat_list=args.eval_feat,
        feat_load_fn=feat_load_fn,
        return_filename=True,
        hop_size=config["hop_size"],
        dense_factor=config.get("dense_factor", 4),
        f0_threshold=config.get("f0_threshold", 0),
        f0_cont=config.get("f0_cont", True),
        f0_dim_idx=config.get("f0_dim_idx", 1),   # same defaults as the training script
        uv_dim_idx=config.get("uv_dim_idx", 0),
        mean_path=config.get("mean_path", "/world/mean"),
        scale_path=config.get("scale_path", "/world/scale"),
        f0_factor=f0_factor,
        fs=config.get("sampling_rate", 22050),
        shift=config.get("stats_shift", 1),
    )
    logging.info(f"The number of features to be decoded = {len(dataset)}.")

    # setup
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    model_class = getattr(
        qppwg.models,
        config.get("generator_type", "ParallelWaveGANGenerator"))
    model = model_class(**config["generator_params"])
    model.load_state_dict(
        torch.load(args.checkpoint, map_location="cpu")["model"]["generator"])
    logging.info(f"Loaded model parameters from {args.checkpoint}.")
    model.remove_weight_norm()
    model = model.eval().to(device)
    input_type = config.get("input_type", "noise")
    pad_fn = torch.nn.ReplicationPad1d(
        config["generator_params"].get("aux_context_window", 0))

    # start generation
    total_rtf = 0.0
    with torch.no_grad(), tqdm(dataset, desc="[decode]") as pbar:
        for idx, (feat_path, c, d) in enumerate(pbar, 1):
            # setup input
            x = ()
            if input_type == "noise":
                z = torch.randn(1, 1, len(c) * config["hop_size"]).to(device)
                x += (z,)
            else:
                raise NotImplementedError("Currently only 'noise' input is supported.")
            c = pad_fn(torch.FloatTensor(c).unsqueeze(0).transpose(2, 1)).to(device)
            d = torch.FloatTensor(d).view(1, 1, -1).to(device)
            x += (c, d,)

            # generate
            start = time.time()
            y = model(*x).view(-1).cpu().numpy()
            rtf = (time.time() - start) / (len(y) / config["sampling_rate"])
            pbar.set_postfix({"RTF": rtf})
            total_rtf += rtf

            # save as PCM 16 bit wav file
            feat_path = os.path.splitext(feat_path)[0]
            feat_path = feat_path.replace(args.indir, args.outdir)
            if f0_factor == 1.0:  # unchanged
                wav_filename = "%s.wav" % (feat_path)
            else:  # scaled f0
                wav_filename = "%s_f%.2f.wav" % (feat_path, f0_factor)
            if not os.path.exists(os.path.dirname(wav_filename)):
                os.makedirs(os.path.dirname(wav_filename))
            sf.write(wav_filename, y, config.get("sampling_rate", 22050), "PCM_16")

    # report average RTF
    logging.info(f"Finished generation of {idx} utterances (RTF = {total_rtf / idx:.03f}).")
Example #5
    def __init__(
        self,
        stats,
        audio_list,
        feat_list,
        audio_load_fn=sf.read,
        feat_load_fn=lambda x: read_hdf5(x, "world"),
        audio_length_threshold=None,
        feat_length_threshold=None,
        return_filename=False,
        allow_cache=False,
        hop_size=110,
        dense_factor=4,
        f0_threshold=0,
        f0_cont=True,
        f0_dim_idx=1,
        uv_dim_idx=0,
        mean_path="/world/mean",
        scale_path="/world/scale",
        shift=1,
    ):
        """Initialize dataset.

        Args:
            stats (str): Filename of the statistic hdf5 file.
            audio_list (str): Filename of the list of audio files.
            feat_list (str): Filename of the list of feature files.
            audio_load_fn (func): Function to load audio file.
            feat_load_fn (func): Function to load feature file.
            audio_length_threshold (int): Threshold to remove short audio files.
            feat_length_threshold (int): Threshold to remove short feature files.
            return_filename (bool): Whether to return the filename with arrays.
            allow_cache (bool): Whether to allow cache of the loaded files.
            hop_size (int): Hop size of the acoustic features.
            dense_factor (int): Number of taps in one cycle.
            f0_threshold (float): Lower bound of pitch.
            f0_cont (bool): Whether to get dilated factor by continuous f0.
            f0_dim_idx (int): Dimension index of f0. (if set -1, all dilated factors will be 1)
            uv_dim_idx (int): Dimension index of U/V.
            mean_path (str): The data path (channel) of the mean in the statistic hdf5 file.
            scale_path (str): The data path (channel) of the scale in the statistic hdf5 file.
            shift (int): Shift of the feature dimension (number of leading dimensions excluded from normalization).

        """
        # load audio and feature files & check filename
        audio_files = read_txt(audio_list)
        feat_files = read_txt(feat_list)
        assert check_filename(audio_files, feat_files)

        # filter by threshold
        if audio_length_threshold is not None:
            # NOTE: this assumes audio_load_fn returns a bare array;
            # the sf.read default returns (data, sr) and would need [0] here
            audio_lengths = [audio_load_fn(f).shape[0] for f in audio_files]
            idxs = [
                idx for idx in range(len(audio_files))
                if audio_lengths[idx] > audio_length_threshold
            ]
            if len(audio_files) != len(idxs):
                logging.warning(
                    f"Some files are filtered by audio length threshold "
                    f"({len(audio_files)} -> {len(idxs)}).")
            audio_files = [audio_files[idx] for idx in idxs]
            feat_files = [feat_files[idx] for idx in idxs]
        if feat_length_threshold is not None:
            feat_lengths = [feat_load_fn(f).shape[0] for f in feat_files]
            idxs = [
                idx for idx in range(len(feat_files))
                if feat_lengths[idx] > feat_length_threshold
            ]
            if len(feat_files) != len(idxs):
                logging.warning(
                    f"Some files are filtered by feature length threshold "
                    f"({len(feat_files)} -> {len(idxs)}).")
            audio_files = [audio_files[idx] for idx in idxs]
            feat_files = [feat_files[idx] for idx in idxs]

        # assert the number of files
        assert len(audio_files) != 0, f"{audio_list} is empty."
        assert len(audio_files) == len(feat_files), \
            f"Numbers of audio and feature files are different ({len(audio_files)} vs {len(feat_files)})."

        self.audio_files = audio_files
        self.audio_load_fn = audio_load_fn
        self.feat_load_fn = feat_load_fn
        self.feat_files = feat_files
        self.return_filename = return_filename
        self.allow_cache = allow_cache
        self.hop_size = hop_size
        self.f0_threshold = f0_threshold
        self.dense_factor = dense_factor
        self.f0_cont = f0_cont
        self.f0_dim_idx = f0_dim_idx
        self.uv_dim_idx = uv_dim_idx
        self.shift = shift

        if allow_cache:
            # NOTE(kan-bayashi): Manager is needed to share memory across dataloader workers when num_workers > 0
            self.manager = Manager()
            self.caches = self.manager.list()
            self.caches += [() for _ in range(len(audio_files))]

        # define feature pre-processing function
        scaler = load(stats)
        self.feat_transform = lambda x: scaler.transform(x)
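
The excerpt ends with __init__; a minimal sketch of the companion accessors such a dataset typically pairs with follows. This is not the repository's code: the real __getitem__ also derives the per-sample dilated factor d from f0, and the exact return layout here is an assumption.

    def __len__(self):
        """Return the number of utterances."""
        return len(self.audio_files)

    def __getitem__(self, idx):
        """Return (optionally cached) audio and normalized features for one utterance."""
        if self.allow_cache and len(self.caches[idx]) != 0:
            return self.caches[idx]
        audio, fs = self.audio_load_fn(self.audio_files[idx])  # assumes an sf.read-style loader
        feat = self.feat_load_fn(self.feat_files[idx])
        # normalize only the dimensions covered by the statistics (see calc_stats above)
        feat[:, self.shift:] = self.feat_transform(feat[:, self.shift:])
        if self.return_filename:
            items = (self.feat_files[idx], audio, feat)
        else:
            items = (audio, feat)
        if self.allow_cache:
            self.caches[idx] = items
        return items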