Example 1
def test_pqmf(subbands):
    pqmf = PQMF(subbands)
    x = torch.randn(1, 1, subbands * 32)
    y = pqmf.analysis(x)
    assert y.shape[2] * subbands == x.shape[2]
    x_hat = pqmf.synthesis(y)
    assert x.shape[2] == x_hat.shape[2]
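For context, a minimal sketch of the shape contract these assertions encode, assuming parallel_wavegan is installed and PQMF is importable from parallel_wavegan.layers (as in the later examples):

import torch
from parallel_wavegan.layers import PQMF

pqmf = PQMF(subbands=4)
x = torch.randn(1, 1, 128)   # (batch, channels=1, time)
y = pqmf.analysis(x)         # (batch, subbands, time // subbands)
x_hat = pqmf.synthesis(y)    # (batch, 1, time): near-perfect reconstruction
assert y.shape[2] * 4 == x.shape[2]
assert x_hat.shape[2] == x.shape[2]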
Example 2
def load_model(checkpoint, config=None):
    """Load trained model.

    Args:
        checkpoint (str): Checkpoint path.
        config (dict): Configuration dict.

    Returns:
        torch.nn.Module: Model instance.

    """
    # load config from the checkpoint directory if not provided
    if config is None:
        dirname = os.path.dirname(checkpoint)
        config_path = os.path.join(dirname, "config.yml")
        with open(config_path) as f:
            config = yaml.load(f, Loader=yaml.Loader)

    # lazy import to avoid a circular import error
    import parallel_wavegan.models

    # get model and load parameters
    model_class = getattr(
        parallel_wavegan.models,
        config.get("generator_type", "ParallelWaveGANGenerator")
    )
    model = model_class(**config["generator_params"])
    model.load_state_dict(
        torch.load(checkpoint, map_location="cpu")["model"]["generator"]
    )

    # add pqmf if needed
    if config["generator_params"]["out_channels"] > 1:
        # lazy import to avoid a circular import error
        from parallel_wavegan.layers import PQMF

        pqmf_params = {}
        if LooseVersion(config.get("version", "0.1.0")) <= LooseVersion("0.4.2"):
            # for backward compatibility, set the default values used in versions <= 0.4.2
            pqmf_params.update(taps=62, cutoff_ratio=0.15, beta=9.0)
        model.pqmf = PQMF(
            subbands=config["generator_params"]["out_channels"],
            **config.get("pqmf_params", pqmf_params),
        )

    return model
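A hedged usage sketch (paths are placeholders; it assumes the checkpoint directory contains config.yml and that the generator exposes remove_weight_norm and inference, as in the decoding examples below):

import torch

model = load_model("exp/train/checkpoint-400000steps.pkl")  # placeholder path
model.remove_weight_norm()
model = model.eval()
with torch.no_grad():
    c = torch.randn(100, 80)  # dummy mel spectrogram: (frames, num_mels)
    wav = model.inference(c).view(-1)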
Example 3
def main():
    """Run training process."""
    parser = argparse.ArgumentParser(
        description=
        "Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py)."
    )
    parser.add_argument(
        "--train-wav-scp",
        default=None,
        type=str,
        help="kaldi-style wav.scp file for training. "
        "you need to specify either train-*-scp or train-dumpdir.")
    parser.add_argument(
        "--train-feats-scp",
        default=None,
        type=str,
        help="kaldi-style feats.scp file for training. "
        "you need to specify either train-*-scp or train-dumpdir.")
    parser.add_argument("--train-segments",
                        default=None,
                        type=str,
                        help="kaldi-style segments file for training.")
    parser.add_argument(
        "--train-dumpdir",
        default=None,
        type=str,
        help="directory including training data. "
        "you need to specify either train-*-scp or train-dumpdir.")
    parser.add_argument("--dev-wav-scp",
                        default=None,
                        type=str,
                        help="kaldi-style wav.scp file for validation. "
                        "you need to specify either dev-*-scp or dev-dumpdir.")
    parser.add_argument("--dev-feats-scp",
                        default=None,
                        type=str,
                        help="kaldi-style feats.scp file for vaidation. "
                        "you need to specify either dev-*-scp or dev-dumpdir.")
    parser.add_argument("--dev-segments",
                        default=None,
                        type=str,
                        help="kaldi-style segments file for validation.")
    parser.add_argument("--dev-dumpdir",
                        default=None,
                        type=str,
                        help="directory including development data. "
                        "you need to specify either dev-*-scp or dev-dumpdir.")
    parser.add_argument("--outdir",
                        type=str,
                        required=True,
                        help="directory to save checkpoints.")
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="yaml format configuration file.")
    parser.add_argument(
        "--pretrain",
        default="",
        type=str,
        nargs="?",
        help="checkpoint file path to load pretrained params. (default=\"\")")
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help="checkpoint file path to resume training. (default=\"\")")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    parser.add_argument(
        "--rank",
        "--local_rank",
        default=0,
        type=int,
        help="rank for distributed training. no need to explictly specify.")
    args = parser.parse_args()

    args.distributed = False
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        # effective when using fixed size inputs
        # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        torch.backends.cudnn.benchmark = True
        torch.cuda.set_device(args.rank)
        # setup for distributed training
        # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed
        if "WORLD_SIZE" in os.environ:
            args.world_size = int(os.environ["WORLD_SIZE"])
            args.distributed = args.world_size > 1
        if args.distributed:
            torch.distributed.init_process_group(backend="nccl",
                                                 init_method="env://")

    # suppress logging for distributed training
    if args.rank != 0:
        sys.stdout = open(os.devnull, "w")

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # check arguments
    if (args.train_feats_scp is not None and args.train_dumpdir is not None) or \
            (args.train_feats_scp is None and args.train_dumpdir is None):
        raise ValueError(
            "Please specify either --train-dumpdir or --train-*-scp.")
    if (args.dev_feats_scp is not None and args.dev_dumpdir is not None) or \
            (args.dev_feats_scp is None and args.dev_dumpdir is None):
        raise ValueError("Please specify either --dev-dumpdir or --dev-*-scp.")

    # load and save config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = parallel_wavegan.__version__  # add version info
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["batch_max_steps"] // config["hop_size"] + \
            2 * config["generator_params"].get("aux_context_window", 0)
    else:
        mel_length_threshold = None
    if args.train_wav_scp is None or args.dev_wav_scp is None:
        if config["format"] == "hdf5":
            audio_query, mel_query = "*.h5", "*.h5"
            feat_query = config["feat_query"]

            def audio_load_fn(x):
                return read_hdf5(x, "wave")  # NOQA

            def mel_load_fn(x):
                return read_hdf5(x, feat_query)  # NOQA
        elif config["format"] == "npy":
            audio_query, mel_query = "*-wave.npy", "*-feats.npy"
            audio_load_fn = np.load
            mel_load_fn = np.load
        else:
            raise ValueError("support only hdf5 or npy format.")
    if args.train_dumpdir is not None:
        train_dataset = AudioMelDataset(
            root_dir=args.train_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    else:
        train_dataset = AudioMelSCPDataset(
            wav_scp=args.train_wav_scp,
            feats_scp=args.train_feats_scp,
            segments=args.train_segments,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    logging.info(f"The number of training files = {len(train_dataset)}.")
    if args.dev_dumpdir is not None:
        dev_dataset = AudioMelDataset(
            root_dir=args.dev_dumpdir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    else:
        dev_dataset = AudioMelSCPDataset(
            wav_scp=args.dev_wav_scp,
            feats_scp=args.dev_feats_scp,
            segments=args.dev_segments,
            mel_length_threshold=mel_length_threshold,
            allow_cache=config.get("allow_cache", False),  # keep compatibility
        )
    logging.info(f"The number of development files = {len(dev_dataset)}.")
    dataset = {
        "train": train_dataset,
        "dev": dev_dataset,
    }

    # get data loader
    collater = Collater(
        batch_max_steps=config["batch_max_steps"],
        hop_size=config["hop_size"],
        # keep compatibility
        aux_context_window=config["generator_params"].get(
            "aux_context_window", 0),
        # keep compatibility
        use_noise_input=config.get(
            "generator_type", "ParallelWaveGANGenerator") != "MelGANGenerator",
        use_nemo_feature=config["use_nemo_feature"])
    sampler = {"train": None, "dev": None}
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler
        sampler["train"] = DistributedSampler(
            dataset=dataset["train"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=True,
        )
        sampler["dev"] = DistributedSampler(
            dataset=dataset["dev"],
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=False,
        )
    data_loader = {
        "train":
        DataLoader(
            dataset=dataset["train"],
            shuffle=not args.distributed,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=sampler["train"],
            pin_memory=config["pin_memory"],
        ),
        "dev":
        DataLoader(
            dataset=dataset["dev"],
            shuffle=not args.distributed,
            collate_fn=collater,
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            sampler=sampler["dev"],
            pin_memory=config["pin_memory"],
        ),
    }

    # define models and optimizers
    generator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("generator_type", "ParallelWaveGANGenerator"),
    )
    discriminator_class = getattr(
        parallel_wavegan.models,
        # keep compatibility
        config.get("discriminator_type", "ParallelWaveGANDiscriminator"),
    )
    model = {
        "generator":
        generator_class(**config["generator_params"]).to(device),
        "discriminator":
        discriminator_class(**config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft":
        MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device),
        "mse": torch.nn.MSELoss().to(device),
    }
    if config.get("use_feat_match_loss", False):  # keep compatibility
        criterion["l1"] = torch.nn.L1Loss().to(device)
    if config["generator_params"]["out_channels"] > 1:
        criterion["pqmf"] = PQMF(
            subbands=config["generator_params"]["out_channels"],
            # keep compatibility
            **config.get("pqmf_params", {})).to(device)
    if config.get("use_subband_stft_loss", False):  # keep compatibility
        assert config["generator_params"]["out_channels"] > 1
        criterion["sub_stft"] = MultiResolutionSTFTLoss(
            **config["subband_stft_loss_params"]).to(device)
    generator_optimizer_class = getattr(
        parallel_wavegan.optimizers,
        # keep compatibility
        config.get("generator_optimizer_type", "RAdam"),
    )
    discriminator_optimizer_class = getattr(
        parallel_wavegan.optimizers,
        # keep compatibility
        config.get("discriminator_optimizer_type", "RAdam"),
    )
    optimizer = {
        "generator":
        generator_optimizer_class(
            model["generator"].parameters(),
            **config["generator_optimizer_params"],
        ),
        "discriminator":
        discriminator_optimizer_class(
            model["discriminator"].parameters(),
            **config["discriminator_optimizer_params"],
        ),
    }
    generator_scheduler_class = getattr(
        torch.optim.lr_scheduler,
        # keep compatibility
        config.get("generator_scheduler_type", "StepLR"),
    )
    discriminator_scheduler_class = getattr(
        torch.optim.lr_scheduler,
        # keep compatibility
        config.get("discriminator_scheduler_type", "StepLR"),
    )
    scheduler = {
        "generator":
        generator_scheduler_class(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"],
        ),
        "discriminator":
        discriminator_scheduler_class(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"],
        ),
    }
    if args.distributed:
        # wrap model for distributed training
        try:
            from apex.parallel import DistributedDataParallel
        except ImportError:
            raise ImportError(
                "apex is not installed. please check https://github.com/NVIDIA/apex."
            )
        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(
            model["discriminator"])
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        sampler=sampler,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # load pretrained parameters from checkpoint
    if len(args.pretrain) != 0:
        trainer.load_checkpoint(args.pretrain, load_only_params=True)
        logging.info(f"Successfully load parameters from {args.pretrain}.")

    # resume from checkpoint
    if len(args.resume) != 0:
        trainer.load_checkpoint(args.resume)
        logging.info(f"Successfully resumed from {args.resume}.")

    # run training loop
    try:
        trainer.run()
    except KeyboardInterrupt:
        trainer.save_checkpoint(
            os.path.join(config["outdir"],
                         f"checkpoint-{trainer.steps}steps.pkl"))
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
Example 4
    def __init__(
        self,
        repeats=2,
        window_sizes=[512, 1024, 2048, 4096],
        pqmf_params=[
            [1, None, None, None],
            [2, 62, 0.26700, 9.0],
            [4, 62, 0.14200, 9.0],
            [8, 62, 0.07949, 9.0],
        ],
        discriminator_params={
            "out_channels": 1,
            "kernel_sizes": [5, 3],
            "channels": 16,
            "max_downsample_channels": 512,
            "bias": True,
            "downsample_scales": [4, 4, 4, 1],
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {
                "negative_slope": 0.2
            },
            "pad": "ReflectionPad1d",
            "pad_params": {},
        },
        use_weight_norm=True,
    ):
        """Initilize Style MelGAN discriminator.

        Args:
            repeats (int): Number of repititons to apply RWD.
            window_sizes (list): List of random window sizes.
            pqmf_params (list): List of list of Parameters for PQMF modules
            discriminator_params (dict): Parameters for base discriminator module.
            use_weight_nom (bool): Whether to apply weight normalization.

        """
        super().__init__()

        # window size check: every window must span the same number of
        # subband frames, i.e., window_size // subbands must be constant
        assert len(window_sizes) == len(pqmf_params)
        sizes = [ws // p[0] for ws, p in zip(window_sizes, pqmf_params)]
        assert len(window_sizes) == sum([sizes[0] == size for size in sizes])

        self.repeats = repeats
        self.window_sizes = window_sizes
        self.pqmfs = torch.nn.ModuleList()
        self.discriminators = torch.nn.ModuleList()
        for pqmf_param in pqmf_params:
            d_params = copy.deepcopy(discriminator_params)
            d_params["in_channels"] = pqmf_param[0]
            if pqmf_param[0] == 1:
                self.pqmfs += [torch.nn.Identity()]
            else:
                self.pqmfs += [PQMF(*pqmf_param)]
            self.discriminators += [BaseDiscriminator(**d_params)]

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

        # reset parameters
        self.reset_parameters()
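With the default arguments the window size check passes because every window spans the same number of subband frames: 512 / 1 = 1024 / 2 = 2048 / 4 = 4096 / 8 = 512.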
Example 5
def main():
    """Run decoding process."""
    parser = argparse.ArgumentParser(
        description=
        "Decode dumped features with trained Parallel WaveGAN Generator "
        "(See detail in parallel_wavegan/bin/decode.py).")
    parser.add_argument("--feats-scp",
                        "--scp",
                        default=None,
                        type=str,
                        help="kaldi-style feats.scp file. "
                        "you need to specify either feats-scp or dumpdir.")
    parser.add_argument("--dumpdir",
                        default=None,
                        type=str,
                        help="directory including feature files. "
                        "you need to specify either feats-scp or dumpdir.")
    parser.add_argument("--outdir",
                        type=str,
                        required=True,
                        help="directory to save generated speech.")
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True,
                        help="checkpoint file to be loaded.")
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        help="yaml format configuration file. if not explicitly provided, "
        "it will be searched in the checkpoint directory. (default=None)")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    else:
        logging.basicConfig(
            level=logging.WARN,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # load config
    if args.config is None:
        dirname = os.path.dirname(args.checkpoint)
        args.config = os.path.join(dirname, "config.yml")
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    # check arguments
    if (args.feats_scp is not None and args.dumpdir is not None) or \
            (args.feats_scp is None and args.dumpdir is None):
        raise ValueError("Please specify either --dumpdir or --feats-scp.")

    # get dataset
    if args.dumpdir is not None:
        if config["format"] == "hdf5":
            mel_query = "*.h5"
            mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
        elif config["format"] == "npy":
            mel_query = "*-feats.npy"
            mel_load_fn = np.load
        else:
            raise ValueError("support only hdf5 or npy format.")
        dataset = MelDataset(
            args.dumpdir,
            mel_query=mel_query,
            mel_load_fn=mel_load_fn,
            return_utt_id=True,
        )
    else:
        dataset = MelSCPDataset(
            feats_scp=args.feats_scp,
            return_utt_id=True,
        )
    logging.info(f"The number of features to be decoded = {len(dataset)}.")

    # setup
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    model_class = getattr(
        parallel_wavegan.models,
        config.get("generator_type", "ParallelWaveGANGenerator"))
    model = model_class(**config["generator_params"])
    model.load_state_dict(
        torch.load(args.checkpoint, map_location="cpu")["model"]["generator"])
    logging.info(f"Loaded model parameters from {args.checkpoint}.")
    model.remove_weight_norm()
    model = model.eval().to(device)
    use_noise_input = not isinstance(model,
                                     parallel_wavegan.models.MelGANGenerator)
    pad_fn = torch.nn.ReplicationPad1d(config["generator_params"].get(
        "aux_context_window", 0))
    if config["generator_params"]["out_channels"] > 1:
        pqmf = PQMF(config["generator_params"]["out_channels"]).to(device)

    # start generation
    total_rtf = 0.0
    with torch.no_grad(), tqdm(dataset, desc="[decode]") as pbar:
        for idx, (utt_id, c) in enumerate(pbar, 1):
            # setup input
            x = ()
            if use_noise_input:
                z = torch.randn(1, 1, len(c) * config["hop_size"]).to(device)
                x += (z, )
            c = pad_fn(
                torch.tensor(c, dtype=torch.float).unsqueeze(0).transpose(
                    2, 1)).to(device)
            x += (c, )

            # generate
            start = time.time()
            if config["generator_params"]["out_channels"] == 1:
                y = model(*x).view(-1).cpu().numpy()
            else:
                y = pqmf.synthesis(model(*x)).view(-1).cpu().numpy()
            rtf = (time.time() - start) / (len(y) / config["sampling_rate"])
            pbar.set_postfix({"RTF": rtf})
            total_rtf += rtf

            # save as PCM 16 bit wav file
            sf.write(os.path.join(config["outdir"], f"{utt_id}_gen.wav"), y,
                     config["sampling_rate"], "PCM_16")

    # report average RTF
    logging.info(
        f"Finished generation of {idx} utterances (RTF = {total_rtf / idx:.03f})."
    )
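A typical invocation, with placeholder paths, would be: python parallel_wavegan/bin/decode.py --checkpoint exp/train/checkpoint-400000steps.pkl --dumpdir dump/eval --outdir exp/train/wav (either --dumpdir or --feats-scp must be given, per the argument check above).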
Example 6
def main():

    parser = argparse.ArgumentParser(
        description="TTS decoder running RETURNN TTS and an MB-MelGAN vocoder")
    parser.add_argument("--returnn_config",
                        type=str,
                        help="RETURNN config file (.config)")
    parser.add_argument("--vocab_file",
                        type=str,
                        help="RETURNN vocab file (.pkl)")
    parser.add_argument("--pronunciation_lexicon",
                        type=str,
                        help="CMU style pronuncation lexicon")
    parser.add_argument("--pwg_config",
                        type=str,
                        help="ParallelWaveGAN config (.yaml)")
    parser.add_argument("--pwg_checkpoint",
                        type=str,
                        help="ParallelWaveGAN checkpoint (.pkl)")

    args = parser.parse_args()

    # Initialize RETURNN
    rnn.init(args.returnn_config)
    rnn.engine.use_search_flag = True  # enable search mode
    rnn.engine.init_network_from_config(rnn.config)

    returnn_vocab = Vocabulary(vocab_file=args.vocab_file, unknown_label=None)
    returnn_output_dict = {
        'output':
        rnn.engine.network.get_default_output_layer().output.placeholder
    }

    # Initialize PWG
    pwg_config = yaml.load(open(args.pwg_config), Loader=yaml.Loader)
    pyt_device = torch.device("cpu")
    generator = pwg_models.MelGANGenerator(**pwg_config['generator_params'])
    generator.load_state_dict(
        torch.load(args.pwg_checkpoint,
                   map_location="cpu")["model"]["generator"])
    generator.remove_weight_norm()
    pwg_model = generator.eval().to(pyt_device)
    pwg_pad_fn = torch.nn.ReplicationPad1d(pwg_config["generator_params"].get(
        "aux_context_window", 0))
    pwg_pqmf = PQMF(
        pwg_config["generator_params"]["out_channels"]).to(pyt_device)

    # load a CMU dict style pronunciation table
    pronunciation_dictionary = {}
    with open(args.pronunciation_lexicon, "rt") as lexicon:
        for lexicon_entry in lexicon.readlines():
            word, phonemes = lexicon_entry.strip().split(" ", maxsplit=1)
            pronunciation_dictionary[word] = phonemes.split(" ")

    # Tokenizer perl command
    tokenizer = [
        "perl", "./scripts/tokenizer/tokenizer.perl", "-l", "en", "-no-escape"
    ]

    audios = []

    for line in sys.stdin.readlines():
        line = line.strip().lower()
        # run perl tokenizer as external script
        p = subprocess.Popen(tokenizer,
                             stdin=subprocess.PIPE,
                             stdout=subprocess.PIPE)
        line = p.communicate(
            input=line.encode("UTF-8"))[0].decode("UTF-8").strip()
        p.terminate()
        print(line)

        # apply num2words and the pronunciation dict
        words = list(map(number_convert, line.split(" ")))
        print(words)
        phoneme_sequence = " _ ".join([
            " ".join(pronunciation_dictionary[w]) for w in words
            if w in pronunciation_dictionary.keys()
        ])
        phoneme_sequence += " _ ~"

        try:
            classes = numpy.asarray(returnn_vocab.get_seq(phoneme_sequence),
                                    dtype="int32")
            feed_dict = {'classes': classes}
            dataset = StaticDataset([feed_dict],
                                    output_dim={'classes': (77, 1)})
            result = rnn.engine.run_single(dataset, 0, returnn_output_dict)
        except Exception as e:
            print(e)
            raise e

        feature_data = numpy.squeeze(result['output']).T
        print(feature_data.shape)

        with torch.no_grad():
            input_features = pwg_pad_fn(
                torch.from_numpy(feature_data).unsqueeze(0)).to(pyt_device)
            audio_waveform = pwg_pqmf.synthesis(
                pwg_model(input_features)).view(-1).cpu().numpy()

        # scale the float waveform in [-1, 1] to signed 16-bit PCM
        audios.append(
            numpy.asarray(audio_waveform * (2**15 - 1),
                          dtype="int16").tobytes())

    for i, audio in enumerate(audios):
        wave_writer = wave.open("out_%i.wav" % i, "wb")
        wave_writer.setnchannels(1)
        wave_writer.setframerate(16000)
        wave_writer.setsampwidth(2)
        wave_writer.writeframes(audio)
        wave_writer.close()
Example 7
def main():
    """Run training process."""
    parser = argparse.ArgumentParser(
        description=
        "Train Parallel WaveGAN (See detail in parallel_wavegan/bin/train.py)."
    )
    parser.add_argument('--input_training_file', default='filelists/train.txt')
    parser.add_argument('--input_validation_file', default='filelists/val.txt')
    parser.add_argument('--checkpoint_path', default='parallel_wavegan/outdir')
    parser.add_argument(
        '--config',
        default='parallel_wavegan/config/multi_band_melgan.v2.yaml')
    parser.add_argument('--checkpoint_interval', default=0, type=int)
    parser.add_argument('--max_steps', default=1000000, type=int)
    parser.add_argument('--n_models_to_keep', default=1, type=int)
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    parser.add_argument(
        "--rank",
        "--local_rank",
        default=0,
        type=int,
        help="rank for distributed training. no need to explictly specify.")
    args = parser.parse_args()

    args.distributed = False
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        # effective when using fixed size inputs
        # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        torch.backends.cudnn.benchmark = True
        torch.cuda.set_device(args.rank)
        # setup for distributed training
        # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed
        if "WORLD_SIZE" in os.environ:
            args.world_size = int(os.environ["WORLD_SIZE"])
            args.distributed = args.world_size > 1
        if args.distributed:
            torch.distributed.init_process_group(backend="nccl",
                                                 init_method="env://")

    # suppress logging for distributed training
    if args.rank != 0:
        sys.stdout = open(os.devnull, "w")

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.checkpoint_path):
        os.makedirs(args.checkpoint_path)

    # load and save config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = parallel_wavegan.__version__  # add version info
    if config["checkpoint_interval"] > 0:
        config["save_interval_steps"] = config["checkpoint_interval"]
    if config["max_steps"] and config["max_steps"] > 0:
        config["train_max_steps"] = config["max_steps"]
    with open(os.path.join(args.checkpoint_path, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    data_loader = {}
    data_loader["train"], sampler = from_path(config["input_training_file"],
                                              config)
    data_loader["dev"] = from_path(config["input_validation_file"], config)[0]

    # define models and optimizers
    generator_class = getattr(
        parallel_wavegan.models,
        config.get("generator_type", "ParallelWaveGANGenerator"),
    )
    discriminator_class = getattr(
        parallel_wavegan.models,
        config.get("discriminator_type", "ParallelWaveGANDiscriminator"),
    )
    model = {
        "generator":
        generator_class(**config["generator_params"]).to(device),
        "discriminator":
        discriminator_class(**config["discriminator_params"]).to(device),
    }
    criterion = {
        "stft": MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device),  # reconstruction loss
        "mse": torch.nn.MSELoss().to(device),
    }
    if config.get("use_feat_match_loss", False):  # keep compatibility
        criterion["l1"] = torch.nn.L1Loss().to(device)
    if config["generator_params"]["out_channels"] > 1:
        criterion["pqmf"] = PQMF(
            subbands=config["generator_params"]["out_channels"],
            # keep compatibility
            **config.get("pqmf_params", {})).to(device)
    if config.get("use_subband_stft_loss", False):  # keep compatibility
        assert config["generator_params"]["out_channels"] > 1
        criterion["sub_stft"] = MultiResolutionSTFTLoss(
            **config["subband_stft_loss_params"]).to(device)
    generator_optimizer_class = getattr(
        parallel_wavegan.optimizers,
        config.get("generator_optimizer_type", "RAdam"),
    )
    discriminator_optimizer_class = getattr(
        parallel_wavegan.optimizers,
        config.get("discriminator_optimizer_type", "RAdam"),
    )
    optimizer = {
        "generator":
        generator_optimizer_class(
            model["generator"].parameters(),
            **config["generator_optimizer_params"],
        ),
        "discriminator":
        discriminator_optimizer_class(
            model["discriminator"].parameters(),
            **config["discriminator_optimizer_params"],
        ),
    }
    generator_scheduler_class = getattr(
        torch.optim.lr_scheduler,
        config.get("generator_scheduler_type", "StepLR"),
    )
    discriminator_scheduler_class = getattr(
        torch.optim.lr_scheduler,
        config.get("discriminator_scheduler_type", "StepLR"),
    )
    scheduler = {
        "generator":
        generator_scheduler_class(
            optimizer=optimizer["generator"],
            **config["generator_scheduler_params"],
        ),
        "discriminator":
        discriminator_scheduler_class(
            optimizer=optimizer["discriminator"],
            **config["discriminator_scheduler_params"],
        ),
    }
    if args.distributed:
        # wrap model for distributed training
        try:
            from apex.parallel import DistributedDataParallel
        except ImportError:
            raise ImportError(
                "apex is not installed. please check https://github.com/NVIDIA/apex."
            )
        model["generator"] = DistributedDataParallel(model["generator"])
        model["discriminator"] = DistributedDataParallel(
            model["discriminator"])
    logging.info(model["generator"])
    logging.info(model["discriminator"])

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        sampler=sampler,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # load pretrained parameters from checkpoint
    trainer.load_checkpoint()

    # run training loop
    try:
        trainer.run()
    except KeyboardInterrupt:
        logging.info(f"Saving checkpoint in 10...")
        for i in range(9, -1, -1):
            print(f"{i}...")
        trainer.save_checkpoint(
            os.path.join(config["checkpoint_path"],
                         f"checkpoint-{trainer.steps:08}steps.pkl"))
        logging.info(
            f"Successfully saved checkpoint @ {trainer.steps:08}steps.")
Example 8
def load_model(checkpoint, config=None, stats=None):
    """Load trained model.

    Args:
        checkpoint (str): Checkpoint path.
        config (dict): Configuration dict.
        stats (str): Statistics file path.

    Returns:
        torch.nn.Module: Model instance.

    """
    # load config from the checkpoint directory if not provided
    if config is None:
        dirname = os.path.dirname(checkpoint)
        config_path = os.path.join(dirname, "config.yml")
        with open(config_path) as f:
            config = yaml.load(f, Loader=yaml.Loader)

    # lazy import to avoid a circular import error
    import parallel_wavegan.models

    # get model and load parameters
    model_class = getattr(
        parallel_wavegan.models,
        config.get("generator_type", "ParallelWaveGANGenerator"),
    )
    # workaround for typo #295
    generator_params = {
        k.replace("upsample_kernal_sizes", "upsample_kernel_sizes"): v
        for k, v in config["generator_params"].items()
    }
    model = model_class(**generator_params)
    model.load_state_dict(
        torch.load(checkpoint, map_location="cpu")["model"]["generator"]
    )

    # check stats existence
    if stats is None:
        dirname = os.path.dirname(checkpoint)
        if config["format"] == "hdf5":
            ext = "h5"
        else:
            ext = "npy"
        if os.path.exists(os.path.join(dirname, f"stats.{ext}")):
            stats = os.path.join(dirname, f"stats.{ext}")

    # load stats
    if stats is not None:
        model.register_stats(stats)

    # add pqmf if needed
    if config["generator_params"]["out_channels"] > 1:
        # lazy import to avoid a circular import error
        from parallel_wavegan.layers import PQMF

        pqmf_params = {}
        if LooseVersion(config.get("version", "0.1.0")) <= LooseVersion("0.4.2"):
            # for backward compatibility, set the default values used in versions <= 0.4.2
            pqmf_params.update(taps=62, cutoff_ratio=0.15, beta=9.0)
        model.pqmf = PQMF(
            subbands=config["generator_params"]["out_channels"],
            **config.get("pqmf_params", pqmf_params),
        )

    return model


def main():
    device = torch.device("cuda")
    torch.manual_seed(1)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Define the seq2seq model
    model_dir = "/home/administrator/espnet/egs/tmp/tts1/exp/char_train_no_dev_pytorch_train_pytorch_transformer.v3.single/"
    model_path = os.path.join(model_dir, "results/model.loss.best")
    idim, odim, train_args = get_model_conf(model_path)
    print(f"Input dimension: {idim}, output dimension: {odim}")

    model_class = dynamic_import(train_args.model_module)
    model = model_class(idim, odim, train_args)
    torch_load(model_path, model)
    model = model.eval().to(device)
    inference_args = argparse.Namespace(
        **{
            "threshold": 0.5,
            "minlenratio": 0.0,
            "maxlenratio": 10.0,
            # Only for Tacotron 2
            "use_attention_constraint": True,
            "backward_window": 1,
            "forward_window": 3,
            # Only for fastspeech (lower than 1.0 is faster speech, higher than 1.0 is slower speech)
            "fastspeech_alpha": 1.0,
        })

    # Define ParallelWaveGAN
    pwgan_path = (
        "/home/administrator/espnet/utils/parallelwavegan/checkpoints/checkpoint-1000000steps.pkl"
    )
    pwgan_conf = "/home/administrator/espnet/utils/parallelwavegan/config.yml"
    with open(pwgan_conf) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    vocoder_class = config.get("generator_type", "ParallelWaveGANGenerator")
    vocoder = getattr(parallel_wavegan.models,
                      vocoder_class)(**config["generator_params"])
    vocoder.load_state_dict(
        torch.load(pwgan_path, map_location="cpu")["model"]["generator"])
    vocoder.remove_weight_norm()
    vocoder.eval()
    pad_fn = torch.nn.ReplicationPad1d(config["generator_params"].get(
        "aux_context_window", 0))
    use_noise_input = vocoder_class == "ParallelWaveGANGenerator"
    if config["generator_params"]["out_channels"] > 1:
        from parallel_wavegan.layers import PQMF

        pqmf = PQMF(config["generator_params"]["out_channels"]).to(device)
    vocoder = vocoder.cuda()

    # Define the dictionary
    if model_dir.find("char") != -1:
        trans_type = "char"
        dict_path = "/home/administrator/espnet/egs/tmp/tts1/data/lang_1char/char_train_no_dev_units.txt"
    elif model_dir.find("phn") != -1:
        trans_type = "phn"
        dict_path = "/home/administrator/espnet/egs/tmp/tts1/data/lang_1phn/phn_train_no_dev_units.txt"

    with open(dict_path, encoding="utf-8") as f:
        lines = f.readlines()
    lines = [line.replace("\n", "").split(" ") for line in lines]
    char_to_id = {c: int(i) for c, i in lines}
    id_to_char = {int(i): c for c, i in lines}

    # Get input texts
    input_texts = [
        "트랜스포머와 패럴렐 웨이브갠 기반 엔드투엔드 음성합성기 데모입니다.",
        "원하는 문장을 입력하세요.",
    ]

    total_time = 0
    syn_time = 0
    idx = 0
    with torch.no_grad():
        for input_text in input_texts:
            start1 = time.time()

            # text-to-sequence
            idseq = np.array(
                text_to_sequence(input_text, char_to_id, "korean_cleaners"))
            idseq = torch.from_numpy(idseq).to(device).long()

            # Transformer inference
            start2 = time.time()
            y_pred, _, attn = model.inference(idseq, inference_args)
            print("mel_outputs and attentions are of {} and {}".format(
                y_pred.shape,
                attn.shape))  # [T_out, 80] & [# layers, # heads, T_out, T_in]

            # define function for plot prob and att_ws
            def _plot_and_save(array, figname, figsize=(6, 4), dpi=150):
                import matplotlib.pyplot as plt

                shape = array.shape
                if len(shape) == 1:
                    # for eos probability
                    plt.figure(figsize=figsize, dpi=dpi)
                    plt.plot(array)
                    plt.xlabel("Frame")
                    plt.ylabel("Probability")
                    plt.ylim([0, 1])
                elif len(shape) == 2:
                    # for tacotron 2 attention weights, whose shape is (out_length, in_length)
                    plt.figure(figsize=figsize, dpi=dpi)
                    plt.imshow(array, aspect="auto")
                    plt.xlabel("Input")
                    plt.ylabel("Output")
                elif len(shape) == 4:
                    # for transformer attention weights,
                    # whose shape is (#layers, #heads, out_length, in_length)
                    plt.figure(figsize=(figsize[0] * shape[0],
                                        figsize[1] * shape[1]),
                               dpi=dpi)
                    for idx1, xs in enumerate(array):
                        for idx2, x in enumerate(xs, 1):
                            plt.subplot(shape[0], shape[1],
                                        idx1 * shape[1] + idx2)
                            plt.imshow(x, aspect="auto")
                            plt.xlabel("Input")
                            plt.ylabel("Output")
                else:
                    raise NotImplementedError(
                        "Support only from 1D to 4D array.")
                plt.tight_layout()
                if not os.path.exists(os.path.dirname(figname)):
                    # NOTE: exist_ok = True is needed for parallel process decoding
                    os.makedirs(os.path.dirname(figname), exist_ok=True)
                plt.savefig(figname)
                plt.close()

            # attention plot
            attnname = os.path.join(model_dir, "attention_{}.png".format(idx))
            _plot_and_save(attn.cpu().numpy(), attnname)

            # synthesize
            audio_pred = vocoder.inference(y_pred).view(-1)
            audio = audio_pred.cpu().numpy()
            # peak-normalize and scale to the 16-bit PCM range
            audio *= 32767 / max(0.01, np.max(np.abs(audio)))
            wavfile.write(
                os.path.join(model_dir, "sample_{}.wav".format(idx)),
                22050,
                audio.astype(np.int16),
            )
            total_time += time.time() - start1
            idx += 1
            syn_time += audio.size / 22050
    print(
        f"Generated {idx} waveforms that correspond to {syn_time} seconds in {total_time} seconds."
    )