Exemplo n.º 1
0
def test_pqmf(subbands):
    """Check that PQMF analysis followed by synthesis preserves signal length.

    A random (1, 1, T) signal with T a multiple of ``subbands`` is split by
    ``analysis``; the last axis must shrink by exactly ``subbands``.  Feeding
    the result to ``synthesis`` must give back a signal of the original length.
    """
    pqmf = PQMF(subbands)
    num_samples = subbands * 32
    waveform = torch.randn(1, 1, num_samples)

    # analysis: the time axis is expected to shrink by a factor of `subbands`
    bands = pqmf.analysis(waveform)
    assert bands.shape[2] * subbands == waveform.shape[2]

    # synthesis: recombining must restore the original number of samples
    reconstructed = pqmf.synthesis(bands)
    assert waveform.shape[2] == reconstructed.shape[2]
Exemplo n.º 2
0
def main():
    """Run decoding process.

    Parse command-line options, load the generator weights stored under
    ``["model"]["generator"]`` in ``--checkpoint`` (with its YAML config),
    decode every feature in ``--dumpdir`` or ``--feats-scp`` into a waveform,
    and write each one to ``--outdir`` as ``<utt_id>_gen.wav`` (PCM_16).  The
    real-time factor (RTF) of each utterance is shown on the progress bar and
    the average RTF is logged at the end.

    Raises:
        ValueError: if neither or both of ``--dumpdir`` / ``--feats-scp`` are
            given, or if ``config["format"]`` is neither ``hdf5`` nor ``npy``.
    """
    parser = argparse.ArgumentParser(
        description=
        "Decode dumped features with trained Parallel WaveGAN Generator "
        "(See detail in parallel_wavegan/bin/decode.py).")
    parser.add_argument("--feats-scp",
                        "--scp",
                        default=None,
                        type=str,
                        help="kaldi-style feats.scp file. "
                        "you need to specify either feats-scp or dumpdir.")
    parser.add_argument("--dumpdir",
                        default=None,
                        type=str,
                        help="directory including feature files. "
                        "you need to specify either feats-scp or dumpdir.")
    parser.add_argument("--outdir",
                        type=str,
                        required=True,
                        help="directory to save generated speech.")
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True,
                        help="checkpoint file to be loaded.")
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        help="yaml format configuration file. if not explicitly provided, "
        "it will be searched in the checkpoint directory. (default=None)")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger: verbose > 1 -> DEBUG, verbose == 1 -> INFO, otherwise WARN
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        format=
        "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    if log_level == logging.WARN:
        logging.warning("Skip DEBUG/INFO messages")

    # check arguments first, so a usage error is reported before any
    # directory is created or any config file is read.
    # Exactly one of the two inputs must be given (XOR on "is provided").
    if (args.feats_scp is None) == (args.dumpdir is None):
        raise ValueError("Please specify either --dumpdir or --feats-scp.")

    # check directory existence
    os.makedirs(args.outdir, exist_ok=True)

    # load config; without --config, use config.yml next to the checkpoint
    if args.config is None:
        dirname = os.path.dirname(args.checkpoint)
        args.config = os.path.join(dirname, "config.yml")
    with open(args.config) as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # only use with trusted config files (yaml.SafeLoader otherwise).
        config = yaml.load(f, Loader=yaml.Loader)
    # command-line values (including the resolved config path) override /
    # extend the YAML config
    config.update(vars(args))

    # get dataset
    if args.dumpdir is not None:
        if config["format"] == "hdf5":
            mel_query = "*.h5"
            mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
        elif config["format"] == "npy":
            mel_query = "*-feats.npy"
            mel_load_fn = np.load
        else:
            raise ValueError("support only hdf5 or npy format.")
        dataset = MelDataset(
            args.dumpdir,
            mel_query=mel_query,
            mel_load_fn=mel_load_fn,
            return_utt_id=True,
        )
    else:
        dataset = MelSCPDataset(
            feats_scp=args.feats_scp,
            return_utt_id=True,
        )
    logging.info(f"The number of features to be decoded = {len(dataset)}.")

    # setup
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    model_class = getattr(
        parallel_wavegan.models,
        config.get("generator_type", "ParallelWaveGANGenerator"))
    model = model_class(**config["generator_params"])
    model.load_state_dict(
        torch.load(args.checkpoint, map_location="cpu")["model"]["generator"])
    logging.info(f"Loaded model parameters from {args.checkpoint}.")
    model.remove_weight_norm()
    model = model.eval().to(device)
    # MelGAN generators take only the conditioning features; other
    # generators additionally receive a noise tensor as first input.
    use_noise_input = not isinstance(model,
                                     parallel_wavegan.models.MelGANGenerator)
    pad_fn = torch.nn.ReplicationPad1d(config["generator_params"].get(
        "aux_context_window", 0))
    out_channels = config["generator_params"]["out_channels"]
    # multi-channel generator output is merged back into one waveform by PQMF
    pqmf = PQMF(out_channels).to(device) if out_channels > 1 else None

    # start generation
    total_rtf = 0.0
    num_utts = 0  # stays 0 when the dataset is empty
    with torch.no_grad(), tqdm(dataset, desc="[decode]") as pbar:
        for num_utts, (utt_id, c) in enumerate(pbar, 1):
            # setup input
            x = ()
            if use_noise_input:
                # noise length = number of feature frames * hop_size
                z = torch.randn(1, 1, len(c) * config["hop_size"]).to(device)
                x += (z, )
            # c is 2-D; add a batch axis and swap the last two axes.
            # presumably (frames, feature_dim) -> (1, feature_dim, frames);
            # TODO confirm against the dataset
            c = pad_fn(
                torch.tensor(c, dtype=torch.float).unsqueeze(0).transpose(
                    2, 1)).to(device)
            x += (c, )

            # generate (.cpu() forces completion, so timing covers the model)
            start = time.time()
            if out_channels == 1:
                y = model(*x).view(-1).cpu().numpy()
            else:
                y = pqmf.synthesis(model(*x)).view(-1).cpu().numpy()
            rtf = (time.time() - start) / (len(y) / config["sampling_rate"])
            pbar.set_postfix({"RTF": rtf})
            total_rtf += rtf

            # save as PCM 16 bit wav file
            sf.write(os.path.join(config["outdir"], f"{utt_id}_gen.wav"), y,
                     config["sampling_rate"], "PCM_16")

    # report average RTF (guarded: an empty dataset would divide by zero)
    if num_utts == 0:
        logging.warning("No features were decoded.")
        return
    logging.info(
        f"Finished generation of {num_utts} utterances "
        f"(RTF = {total_rtf / num_utts:.03f}).")
Exemplo n.º 3
0
def main():
    """Synthesize speech for each text line read from stdin.

    Pipeline per line:
      1. lower-case it and tokenize it with the external perl tokenizer;
      2. map tokens through ``number_convert``, then to phonemes using a
         CMU-style pronunciation lexicon (words missing from the lexicon are
         dropped);
      3. run the RETURNN TTS network in search mode to get acoustic features;
      4. vocode them with a MelGAN generator followed by PQMF synthesis.

    The i-th input line is written to ``out_<i>.wav`` as mono 16-bit PCM.

    Raises:
        subprocess.CalledProcessError: if the tokenizer exits with an error.
    """
    parser = argparse.ArgumentParser(
        description="TTS decoder running RETURNN TTS and an MB-MelGAN vocoder")
    parser.add_argument("--returnn_config",
                        type=str,
                        help="RETURNN config file (.config)")
    parser.add_argument("--vocab_file",
                        type=str,
                        help="RETURNN vocab file (.pkl)")
    parser.add_argument("--pronunciation_lexicon",
                        type=str,
                        help="CMU style pronuncation lexicon")
    parser.add_argument("--pwg_config",
                        type=str,
                        help="ParallelWaveGAN config (.yaml)")
    parser.add_argument("--pwg_checkpoint",
                        type=str,
                        help="ParallelWaveGAN checkpoint (.pkl)")

    args = parser.parse_args()

    # Initialize RETURNN
    rnn.init(args.returnn_config)
    rnn.engine.use_search_flag = True  # enable search mode
    rnn.engine.init_network_from_config(rnn.config)

    returnn_vocab = Vocabulary(vocab_file=args.vocab_file, unknown_label=None)
    returnn_output_dict = {
        'output':
        rnn.engine.network.get_default_output_layer().output.placeholder
    }

    # Initialize PWG (the config file handle is now closed after reading)
    with open(args.pwg_config) as pwg_config_file:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # only use with trusted config files.
        pwg_config = yaml.load(pwg_config_file, Loader=yaml.Loader)
    pyt_device = torch.device("cpu")
    generator = pwg_models.MelGANGenerator(**pwg_config['generator_params'])
    generator.load_state_dict(
        torch.load(args.pwg_checkpoint,
                   map_location="cpu")["model"]["generator"])
    generator.remove_weight_norm()
    pwg_model = generator.eval().to(pyt_device)
    pwg_pad_fn = torch.nn.ReplicationPad1d(pwg_config["generator_params"].get(
        "aux_context_window", 0))
    pwg_pqmf = PQMF(
        pwg_config["generator_params"]["out_channels"]).to(pyt_device)
    # WAV frame rate: use the vocoder's configured rate when present instead
    # of assuming 16 kHz (the previously hard-coded value, kept as fallback)
    sampling_rate = pwg_config.get("sampling_rate", 16000)

    # load a CMU dict style pronunciation table: "WORD PH1 PH2 ..."
    pronunciation_dictionary = {}
    with open(args.pronunciation_lexicon, "rt") as lexicon:
        for lexicon_entry in lexicon:
            lexicon_entry = lexicon_entry.strip()
            if not lexicon_entry:
                continue  # blank lines would break the 2-way split below
            word, phonemes = lexicon_entry.split(" ", maxsplit=1)
            pronunciation_dictionary[word] = phonemes.split(" ")

    # Tokenizer perl command
    tokenizer = [
        "perl", "./scripts/tokenizer/tokenizer.perl", "-l", "en", "-no-escape"
    ]

    audios = []

    for line in sys.stdin:
        line = line.strip().lower()
        # run perl tokenizer as external script; a non-zero exit now raises
        # instead of silently continuing with empty output
        tokenized = subprocess.run(tokenizer,
                                   input=line.encode("UTF-8"),
                                   stdout=subprocess.PIPE,
                                   check=True)
        line = tokenized.stdout.decode("UTF-8").strip()
        print(line)

        # apply num2words and pronunciation dict (OOV words are skipped)
        words = [number_convert(token) for token in line.split(" ")]
        print(words)
        phoneme_sequence = " _ ".join(
            " ".join(pronunciation_dictionary[w]) for w in words
            if w in pronunciation_dictionary)
        phoneme_sequence += " _ ~"

        classes = numpy.asarray(returnn_vocab.get_seq(phoneme_sequence),
                                dtype="int32")
        feed_dict = {'classes': classes}
        # NOTE(review): 77 is presumably the size of the phoneme vocabulary;
        # confirm against the vocab file.
        dataset = StaticDataset([feed_dict],
                                output_dim={'classes': (77, 1)})
        result = rnn.engine.run_single(dataset, 0, returnn_output_dict)

        # squeeze then transpose the network output before vocoding;
        # presumably yields (feature_dim, frames) — TODO confirm
        feature_data = numpy.squeeze(result['output']).T
        print(feature_data.shape)

        with torch.no_grad():
            input_features = pwg_pad_fn(
                torch.from_numpy(feature_data).unsqueeze(0)).to(pyt_device)
            audio_waveform = pwg_pqmf.synthesis(
                pwg_model(input_features)).view(-1).cpu().numpy()

        # clip before scaling: float->int16 casting is not saturating, so
        # samples outside [-1, 1] would otherwise wrap into loud clicks
        audio_waveform = numpy.clip(audio_waveform, -1.0, 1.0)
        audios.append(
            numpy.asarray(audio_waveform * (2**15 - 1),
                          dtype="int16").tobytes())

    for i, audio in enumerate(audios):
        # context manager guarantees the file is finalized/closed on error
        with wave.open("out_%i.wav" % i, "wb") as wave_writer:
            wave_writer.setnchannels(1)
            wave_writer.setframerate(sampling_rate)
            wave_writer.setsampwidth(2)  # 2 bytes = 16-bit samples
            wave_writer.writeframes(audio)