import pytest
import torch

from parallel_wavegan.layers import PQMF


# The parametrization values below are assumed; the PQMF default is 4 subbands.
@pytest.mark.parametrize("subbands", [4])
def test_pqmf(subbands):
    """Check that PQMF analysis/synthesis preserve the expected shapes."""
    pqmf = PQMF(subbands)
    x = torch.randn(1, 1, subbands * 32)
    y = pqmf.analysis(x)
    assert y.shape[2] * subbands == x.shape[2]
    x_hat = pqmf.synthesis(y)
    assert x.shape[2] == x_hat.shape[2]
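
# A minimal standalone sketch (not part of the test suite; the helper name and
# the 4-subband setting are illustrative) showing the shape relation the test
# asserts: analysis divides the time axis by the number of subbands, and
# synthesis restores it.
def _pqmf_shape_demo():
    pqmf = PQMF(4)
    x = torch.randn(1, 1, 4 * 32)   # (batch, channels, time)
    y = pqmf.analysis(x)            # -> (1, 4, 32): one channel per subband
    x_hat = pqmf.synthesis(y)       # -> (1, 1, 128): time axis restored
    print(x.shape, y.shape, x_hat.shape)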
def main(): """Run decoding process.""" parser = argparse.ArgumentParser( description= "Decode dumped features with trained Parallel WaveGAN Generator " "(See detail in parallel_wavegan/bin/decode.py).") parser.add_argument("--feats-scp", "--scp", default=None, type=str, help="kaldi-style feats.scp file. " "you need to specify either feats-scp or dumpdir.") parser.add_argument("--dumpdir", default=None, type=str, help="directory including feature files. " "you need to specify either feats-scp or dumpdir.") parser.add_argument("--outdir", type=str, required=True, help="directory to save generated speech.") parser.add_argument("--checkpoint", type=str, required=True, help="checkpoint file to be loaded.") parser.add_argument( "--config", default=None, type=str, help="yaml format configuration file. if not explicitly provided, " "it will be searched in the checkpoint directory. (default=None)") parser.add_argument( "--verbose", type=int, default=1, help="logging level. higher is more logging. (default=1)") args = parser.parse_args() # set logger if args.verbose > 1: logging.basicConfig( level=logging.DEBUG, format= "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") elif args.verbose > 0: logging.basicConfig( level=logging.INFO, format= "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") else: logging.basicConfig( level=logging.WARN, format= "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") logging.warning("Skip DEBUG/INFO messages") # check directory existence if not os.path.exists(args.outdir): os.makedirs(args.outdir) # load config if args.config is None: dirname = os.path.dirname(args.checkpoint) args.config = os.path.join(dirname, "config.yml") with open(args.config) as f: config = yaml.load(f, Loader=yaml.Loader) config.update(vars(args)) # check arguments if (args.feats_scp is not None and args.dumpdir is not None) or \ (args.feats_scp is None and args.dumpdir is None): raise ValueError("Please specify either --dumpdir or --feats-scp.") # get dataset if args.dumpdir is not None: if config["format"] == "hdf5": mel_query = "*.h5" mel_load_fn = lambda x: read_hdf5(x, "feats") # NOQA elif config["format"] == "npy": mel_query = "*-feats.npy" mel_load_fn = np.load else: raise ValueError("support only hdf5 or npy format.") dataset = MelDataset( args.dumpdir, mel_query=mel_query, mel_load_fn=mel_load_fn, return_utt_id=True, ) else: dataset = MelSCPDataset( feats_scp=args.feats_scp, return_utt_id=True, ) logging.info(f"The number of features to be decoded = {len(dataset)}.") # setup if torch.cuda.is_available(): device = torch.device("cuda") else: device = torch.device("cpu") model_class = getattr( parallel_wavegan.models, config.get("generator_type", "ParallelWaveGANGenerator")) model = model_class(**config["generator_params"]) model.load_state_dict( torch.load(args.checkpoint, map_location="cpu")["model"]["generator"]) logging.info(f"Loaded model parameters from {args.checkpoint}.") model.remove_weight_norm() model = model.eval().to(device) use_noise_input = not isinstance(model, parallel_wavegan.models.MelGANGenerator) pad_fn = torch.nn.ReplicationPad1d(config["generator_params"].get( "aux_context_window", 0)) if config["generator_params"]["out_channels"] > 1: pqmf = PQMF(config["generator_params"]["out_channels"]).to(device) # start generation total_rtf = 0.0 with torch.no_grad(), tqdm(dataset, desc="[decode]") as pbar: for idx, (utt_id, c) in enumerate(pbar, 1): # setup input x = () if use_noise_input: z = torch.randn(1, 1, len(c) * 
config["hop_size"]).to(device) x += (z, ) c = pad_fn( torch.tensor(c, dtype=torch.float).unsqueeze(0).transpose( 2, 1)).to(device) x += (c, ) # generate start = time.time() if config["generator_params"]["out_channels"] == 1: y = model(*x).view(-1).cpu().numpy() else: y = pqmf.synthesis(model(*x)).view(-1).cpu().numpy() rtf = (time.time() - start) / (len(y) / config["sampling_rate"]) pbar.set_postfix({"RTF": rtf}) total_rtf += rtf # save as PCM 16 bit wav file sf.write(os.path.join(config["outdir"], f"{utt_id}_gen.wav"), y, config["sampling_rate"], "PCM_16") # report average RTF logging.info( f"Finished generation of {idx} utterances (RTF = {total_rtf / idx:.03f})." )
import argparse
import subprocess
import sys
import wave

import numpy
import torch
import yaml

# Assumed project-local imports (exact paths depend on the surrounding
# repository): `rnn` (the RETURNN entry module), `Vocabulary` and
# `StaticDataset` from RETURNN, `pwg_models` (parallel_wavegan.models),
# `PQMF` (parallel_wavegan.layers), and `number_convert` (a num2words wrapper).


def main():
    parser = argparse.ArgumentParser(
        description="TTS decoder running RETURNN TTS and an MB-MelGAN vocoder")
    parser.add_argument("--returnn_config", type=str,
                        help="RETURNN config file (.config)")
    parser.add_argument("--vocab_file", type=str,
                        help="RETURNN vocab file (.pkl)")
    parser.add_argument("--pronunciation_lexicon", type=str,
                        help="CMU style pronunciation lexicon")
    parser.add_argument("--pwg_config", type=str,
                        help="ParallelWaveGAN config (.yaml)")
    parser.add_argument("--pwg_checkpoint", type=str,
                        help="ParallelWaveGAN checkpoint (.pkl)")
    args = parser.parse_args()

    # Initialize RETURNN
    rnn.init(args.returnn_config)
    rnn.engine.use_search_flag = True  # enable search mode
    rnn.engine.init_network_from_config(rnn.config)

    returnn_vocab = Vocabulary(vocab_file=args.vocab_file, unknown_label=None)
    returnn_output_dict = {
        'output': rnn.engine.network.get_default_output_layer().output.placeholder}

    # Initialize PWG
    pwg_config = yaml.load(open(args.pwg_config), Loader=yaml.Loader)
    pyt_device = torch.device("cpu")
    generator = pwg_models.MelGANGenerator(**pwg_config['generator_params'])
    generator.load_state_dict(
        torch.load(args.pwg_checkpoint, map_location="cpu")["model"]["generator"])
    generator.remove_weight_norm()
    pwg_model = generator.eval().to(pyt_device)
    pwg_pad_fn = torch.nn.ReplicationPad1d(
        pwg_config["generator_params"].get("aux_context_window", 0))
    pwg_pqmf = PQMF(pwg_config["generator_params"]["out_channels"]).to(pyt_device)

    # load a CMU dict style pronunciation table
    pronunciation_dictionary = {}
    with open(args.pronunciation_lexicon, "rt") as lexicon:
        for lexicon_entry in lexicon.readlines():
            word, phonemes = lexicon_entry.strip().split(" ", maxsplit=1)
            pronunciation_dictionary[word] = phonemes.split(" ")

    # Tokenizer perl command
    tokenizer = ["perl", "./scripts/tokenizer/tokenizer.perl",
                 "-l", "en", "-no-escape"]

    audios = []
    for line in sys.stdin.readlines():
        line = line.strip().lower()

        # run the perl tokenizer as an external script
        p = subprocess.Popen(tokenizer, stdin=subprocess.PIPE,
                             stdout=subprocess.PIPE)
        line = p.communicate(input=line.encode("UTF-8"))[0].decode("UTF-8").strip()
        p.terminate()
        print(line)

        # apply num2words and the pronunciation dict
        words = list(map(number_convert, line.split(" ")))
        print(words)
        phoneme_sequence = " _ ".join(
            [" ".join(pronunciation_dictionary[w])
             for w in words if w in pronunciation_dictionary.keys()])
        phoneme_sequence += " _ ~"

        try:
            classes = numpy.asarray(returnn_vocab.get_seq(phoneme_sequence),
                                    dtype="int32")
            feed_dict = {'classes': classes}
            dataset = StaticDataset([feed_dict], output_dim={'classes': (77, 1)})
            result = rnn.engine.run_single(dataset, 0, returnn_output_dict)
        except Exception as e:
            print(e)
            raise
        feature_data = numpy.squeeze(result['output']).T
        print(feature_data.shape)

        # vocode the predicted features with the MB-MelGAN generator + PQMF
        with torch.no_grad():
            input_features = pwg_pad_fn(
                torch.from_numpy(feature_data).unsqueeze(0)).to(pyt_device)
            audio_waveform = pwg_pqmf.synthesis(
                pwg_model(input_features)).view(-1).cpu().numpy()
        audios.append(numpy.asarray(audio_waveform * (2 ** 15 - 1),
                                    dtype="int16").tobytes())

    # write each synthesized utterance as a 16 kHz, 16-bit mono wav
    for i, audio in enumerate(audios):
        wave_writer = wave.open("out_%i.wav" % i, "wb")
        wave_writer.setnchannels(1)
        wave_writer.setframerate(16000)
        wave_writer.setsampwidth(2)
        wave_writer.writeframes(audio)
        wave_writer.close()
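
if __name__ == "__main__":
    main()


# Example invocation (script and file names are illustrative only):
#   echo "hello world" | python tts_decode.py \
#       --returnn_config tts.config --vocab_file vocab.pkl \
#       --pronunciation_lexicon cmudict.dict \
#       --pwg_config pwg.yaml --pwg_checkpoint checkpoint.pkl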