コード例 #1
0
def test_multi_band_melgan(dict_g):
    """Check that PQMF analysis/synthesis shapes agree with the MB-MelGAN generator.

    The generator's sub-band output must have the same shape as the PQMF
    analysis of a reference waveform. PQMF synthesis of that output must give
    back a tensor shaped like the reference waveform.
    """
    generator_args = make_multi_band_melgan_generator_args(**dict_g)
    generator_config = MultiBandMelGANGeneratorConfig(**generator_args)

    generator = TFMelGANGenerator(generator_config, name="multi_band_melgan")
    generator._build()

    pqmf = TFPQMF(generator_config, name="pqmf")

    mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32)
    # 100 frames * 256 samples/frame; assumes the configured upsampling
    # (times the number of sub-bands) totals 256 -- TODO confirm for dict_g.
    waveform = tf.random.uniform(shape=[1, 100 * 256, 1], dtype=tf.float32)

    predicted_subbands = generator(mels)
    reconstructed = pqmf.synthesis(predicted_subbands)
    reference_subbands = pqmf.analysis(waveform)

    assert np.shape(reference_subbands) == np.shape(predicted_subbands)
    assert np.shape(waveform) == np.shape(reconstructed)
コード例 #2
0
def test_melgan_trainable(dict_g, dict_d, dict_loss):
    """Construct a MelGAN generator and multi-scale discriminator from test args.

    In the lines visible here the body only builds the two models;
    ``dict_loss``, ``batch_size`` and ``batch_length`` are not used.
    NOTE(review): this looks like a truncated excerpt of a longer training
    test -- confirm against the full source before relying on it.
    """
    batch_size = 4
    batch_length = 4096
    # Turn the raw parameter dicts into generator/discriminator argument dicts.
    args_g = make_melgan_generator_args(**dict_g)
    args_d = make_melgan_discriminator_args(**dict_d)

    # Wrap the argument dicts in their config objects.
    args_g = MelGANGeneratorConfig(**args_g)
    args_d = MelGANDiscriminatorConfig(**args_d)

    generator = TFMelGANGenerator(args_g)
    discriminator = TFMelGANMultiScaleDiscriminator(args_d)
コード例 #3
0
ファイル: melgan.py プロジェクト: taverok/tts_melgan
def get_model():
    """Build the MelGAN generator and load its pretrained weights.

    Reads ``melgan_config.yml`` (located via ``get_weight_path``), builds a
    ``TFMelGANGenerator`` from its ``generator_params`` section, and loads
    the ``melgan-1M6.h5`` weights into it.

    Returns:
        The built ``TFMelGANGenerator`` with weights loaded.
    """
    config_path = get_weight_path('melgan_config.yml')
    with open(config_path) as config_file:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # acceptable for a trusted bundled config, unsafe for external files.
        raw_config = yaml.load(config_file, Loader=yaml.Loader)

    generator_config = MelGANGeneratorConfig(**raw_config["generator_params"])
    model = TFMelGANGenerator(config=generator_config, name="melgan_generator")
    model._build()
    model.load_weights(get_weight_path('melgan-1M6.h5'))
    return model
コード例 #4
0
ファイル: synthesys.py プロジェクト: sce-tts/tts-server
def load_mb_melgan(config_path, model_path):
    """Load a multi-band MelGAN generator and its matching PQMF filter bank.

    Args:
        config_path: Path to a YAML config whose ``generator_params`` section
            is passed to ``MultiBandMelGANGeneratorConfig``.
        model_path: Path to the generator weights passed to ``load_weights``.

    Returns:
        A ``(mb_melgan, pqmf)`` tuple: the built ``TFMelGANGenerator`` with
        weights loaded, and a ``TFPQMF`` built from the same config.
    """
    # Only the YAML parse needs the file; close it before the (slow) model
    # build and weight load instead of holding the handle throughout.
    with open(config_path) as f:
        # NOTE(review): yaml.Loader can instantiate arbitrary Python objects;
        # only use it with trusted config files.
        raw_config = yaml.load(f, Loader=yaml.Loader)

    mb_melgan_config = MultiBandMelGANGeneratorConfig(
        **raw_config["generator_params"])
    mb_melgan = TFMelGANGenerator(config=mb_melgan_config,
                                  name="melgan_generator")
    mb_melgan._build()
    mb_melgan.load_weights(model_path)
    # The generator emits sub-band signals; PQMF (same config) recombines them.
    pqmf = TFPQMF(config=mb_melgan_config, name="pqmf")
    return (mb_melgan, pqmf)
コード例 #5
0
ファイル: tts_utils.py プロジェクト: allen-n/EPUB-to-MP3
 def _load_melgan(self, path='./model_files/melgan'):
     """Build the MelGAN vocoder stored under ``path`` and load its weights.

     Args:
         path: Directory containing ``config.yml`` and ``generator-1670000.h5``.

     Returns:
         A built ``TFMelGANGenerator`` with the checkpoint weights loaded.
     """
     config_file = os.path.join(path, 'config.yml')
     weights_file = os.path.join(path, 'generator-1670000.h5')

     with open(config_file) as stream:
         raw = yaml.load(stream, Loader=yaml.Loader)
     generator_params = raw["generator_params"]

     vocoder = TFMelGANGenerator(
         config=MelGANGeneratorConfig(**generator_params),
         name='melgan_generator',
     )
     vocoder._build()
     vocoder.load_weights(weights_file)
     return vocoder
コード例 #6
0
def main():
    """Run training process.

    Parses command-line arguments and loads the YAML config given by
    ``--config``. The config is merged with the CLI arguments and written to
    ``<outdir>/config.yml``. The function then builds the train/dev
    ``AudioMelDataset`` pipelines and constructs the MelGAN generator and
    multi-scale discriminator under the strategy from ``return_strategy()``.
    Finally it runs ``MelganTrainer.fit``, saving a checkpoint if training is
    interrupted with Ctrl-C.

    Raises:
        ValueError: if ``--train-dir`` or ``--dev-dir`` is missing, or the
            config ``format`` is not ``"npy"``.
    """
    parser = argparse.ArgumentParser(
        description=
        "Train MelGAN (See detail in tensorflow_tts/bin/train-melgan.py)")
    parser.add_argument(
        "--train-dir",
        default=None,
        type=str,
        help="directory including training data. ",
    )
    parser.add_argument(
        "--dev-dir",
        default=None,
        type=str,
        help="directory including development data. ",
    )
    parser.add_argument("--use-norm",
                        default=1,
                        type=int,
                        help="use norm mels for training or raw.")
    parser.add_argument("--outdir",
                        type=str,
                        required=True,
                        help="directory to save checkpoints.")
    parser.add_argument("--config",
                        type=str,
                        required=True,
                        help="yaml format configuration file.")
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--generator_mixed_precision",
        default=0,
        type=int,
        help="using mixed precision for generator or not.",
    )
    parser.add_argument(
        "--discriminator_mixed_precision",
        default=0,
        type=int,
        help="using mixed precision for discriminator or not.",
    )
    args = parser.parse_args()

    # distribution strategy used for dataset batch sizing and model creation
    strategy = return_strategy()

    # Mixed precision is a graph-wide optimizer option, so it is turned on
    # if either model requests it.
    if args.generator_mixed_precision == 1 or args.discriminator_mixed_precision == 1:
        tf.config.optimizer.set_experimental_options(
            {"auto_mixed_precision": True})

    args.generator_mixed_precision = bool(args.generator_mixed_precision)
    args.discriminator_mixed_precision = bool(
        args.discriminator_mixed_precision)
    args.use_norm = bool(args.use_norm)

    # set logger: verbose > 1 -> DEBUG, verbose == 1 -> INFO, else WARN
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        stream=sys.stdout,
        format=
        "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if log_level == logging.WARN:
        logging.warning("Skip DEBUG/INFO messages")

    # Validate required inputs before creating anything on disk.
    if args.train_dir is None:
        raise ValueError("Please specify --train-dir")
    if args.dev_dir is None:
        raise ValueError("Please specify --dev-dir")

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.outdir, exist_ok=True)

    # load config, merge CLI args into it, and save the result for reproducibility
    with open(args.config) as f:
        # NOTE(review): yaml.Loader can instantiate arbitrary Python objects;
        # only use it with trusted config files.
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = tensorflow_tts.__version__
    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    # Minimum mel length (in frames) needed to cut one batch_max_steps audio
    # segment plus the generator's aux context on both sides.
    if config["remove_short_samples"]:
        mel_length_threshold = (
            config["batch_max_steps"] // config["hop_size"]
            + 2 * config["melgan_generator_params"].get("aux_context_window", 0)
        )
    else:
        mel_length_threshold = None

    if config["format"] != "npy":
        raise ValueError("Only npy are supported.")
    audio_query = "*-wave.npy"
    mel_query = "*-norm-feats.npy" if args.use_norm else "*-raw-feats.npy"
    audio_load_fn = np.load
    mel_load_fn = np.load

    def _create_dataset(root_dir, batch_max_steps):
        """Build a batched, collated AudioMelDataset pipeline over root_dir."""
        return AudioMelDataset(
            root_dir=root_dir,
            audio_query=audio_query,
            mel_query=mel_query,
            audio_load_fn=audio_load_fn,
            mel_load_fn=mel_load_fn,
            mel_length_threshold=mel_length_threshold,
        ).create(
            is_shuffle=config["is_shuffle"],
            map_fn=lambda items: collater(
                items,
                batch_max_steps=tf.constant(batch_max_steps, dtype=tf.int32),
                hop_size=tf.constant(config["hop_size"], dtype=tf.int32),
            ),
            allow_cache=config["allow_cache"],
            # global batch = per-replica batch * number of replicas
            batch_size=config["batch_size"] * strategy.num_replicas_in_sync,
        )

    # define train/valid dataset (validation uses its own segment length)
    train_dataset = _create_dataset(args.train_dir, config["batch_max_steps"])
    valid_dataset = _create_dataset(args.dev_dir,
                                    config["batch_max_steps_valid"])

    # define trainer
    trainer = MelganTrainer(
        steps=0,
        epochs=0,
        config=config,
        strategy=strategy,
        is_generator_mixed_precision=args.generator_mixed_precision,
        is_discriminator_mixed_precision=args.discriminator_mixed_precision,
    )

    # Models and optimizers must be created inside the strategy scope so
    # their variables are placed/mirrored by the strategy.
    with strategy.scope():
        generator = TFMelGANGenerator(
            MELGAN_CONFIG.MelGANGeneratorConfig(
                **config["melgan_generator_params"]),
            name="melgan_generator",
        )

        discriminator = TFMelGANMultiScaleDiscriminator(
            MELGAN_CONFIG.MelGANDiscriminatorConfig(
                **config["melgan_discriminator_params"]),
            name="melgan_discriminator",
        )

        # dummy forward pass to build both models' variables
        fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32)
        y_hat = generator(fake_mels)
        discriminator(y_hat)

        generator.summary()
        discriminator.summary()

        gen_optimizer = tf.keras.optimizers.Adam(
            **config["generator_optimizer_params"])
        dis_optimizer = tf.keras.optimizers.Adam(
            **config["discriminator_optimizer_params"])

    trainer.compile(
        gen_model=generator,
        dis_model=discriminator,
        gen_optimizer=gen_optimizer,
        dis_optimizer=dis_optimizer,
    )

    # start training; on Ctrl-C keep the progress made so far
    try:
        trainer.fit(
            train_dataset,
            valid_dataset,
            saved_path=os.path.join(config["outdir"], "checkpoints/"),
            resume=args.resume,
        )
    except KeyboardInterrupt:
        trainer.save_checkpoint()
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
コード例 #7
0
def main():
    """Run melgan decoding from folder.

    Loads every mel-spectrogram under ``--rootdir`` that matches the
    configured query (norm or raw feats, ``npy`` only). It vocodes them in
    batches with a ``TFMelGANGenerator`` restored from ``--checkpoint`` and
    writes one 16-bit PCM WAV per utterance to ``--outdir``. Each WAV is
    trimmed to ``mel_length * hop_size`` samples.

    Raises:
        ValueError: if the config ``format`` is not ``"npy"``.
    """
    parser = argparse.ArgumentParser(
        description="Generate Audio from melspectrogram with trained melgan "
        "(See detail in example/melgan/decode_melgan.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        required=True,
        help="directory including ids/durations files.",
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save generated speech."
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="checkpoint file to be loaded."
    )
    parser.add_argument(
        "--use-norm", type=int, default=1, help="Use norm or raw melspectrogram."
    )
    parser.add_argument("--batch-size", type=int, default=8, help="batch_size.")
    # The argument is required, so the help must not promise a fallback search.
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        required=True,
        help="yaml format configuration file.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    args = parser.parse_args()

    # set logger: verbose > 1 -> DEBUG, verbose == 1 -> INFO, else WARN
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if log_level == logging.WARN:
        logging.warning("Skip DEBUG/INFO messages")

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.outdir, exist_ok=True)

    # load config and merge CLI args into it
    with open(args.config) as f:
        # NOTE(review): yaml.Loader can instantiate arbitrary Python objects;
        # only use it with trusted config files.
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    if config["format"] == "npy":
        mel_query = "*-norm-feats.npy" if args.use_norm == 1 else "*-raw-feats.npy"
        mel_load_fn = np.load
    else:
        raise ValueError("Only npy is supported.")

    # define data-loader
    dataset = MelDataset(
        root_dir=args.rootdir,
        mel_query=mel_query,
        mel_load_fn=mel_load_fn,
    ).create(batch_size=args.batch_size)

    # define model and load checkpoint
    melgan = TFMelGANGenerator(
        config=MelGANGeneratorConfig(**config["generator_params"]), name="melgan"
    )
    melgan._build()
    melgan.load_weights(args.checkpoint)

    # loop-invariant config values
    hop_size = config["hop_size"]
    sampling_rate = config["sampling_rate"]

    for data in tqdm(dataset, desc="[Decoding]"):
        utt_ids, mels, mel_lengths = data["utt_ids"], data["mels"], data["mel_lengths"]
        # melgan inference, converted to numpy; output is batch-first
        # (exact per-item shape, e.g. [T] vs [T, 1], depends on the model -- confirm)
        generated_audios = melgan(mels).numpy()

        # save to outdir, trimming padding beyond mel_length * hop_size samples
        for utt_id, audio, mel_length in zip(utt_ids, generated_audios, mel_lengths):
            sf.write(
                os.path.join(args.outdir, f"{utt_id.numpy().decode('utf-8')}.wav"),
                audio[: mel_length.numpy() * hop_size],
                sampling_rate,
                "PCM_16",
            )
コード例 #8
0
from tensorflow_tts.configs import Tacotron2Config
from tensorflow_tts.configs import MelGANGeneratorConfig
from tensorflow_tts.models import TFTacotron2
from tensorflow_tts.models import TFMelGANGenerator
from tensorflow_tts.models import TFMBMelGANGenerator
from tensorflow_tts.configs import MultiBandMelGANGeneratorConfig
from tensorflow_tts.inference import AutoProcessor

from IPython.display import Audio
print(tf.__version__)  # developed against 2.5.0-dev20210103

# Initialize the multi-band MelGAN vocoder (normal pronunciation).
# The config section and pretrained checkpoint are multi-band, so build them
# with the multi-band config/model classes imported above. The plain
# TFMelGANGenerator only emits sub-band signals, which need a separate PQMF
# synthesis step to become audio. TFMBMelGANGenerator is the multi-band model
# class (per TensorFlowTTS it performs that PQMF synthesis itself -- confirm
# against the installed version).
with open(config_lp.multiband_melgan_baker) as f:
    # NOTE(review): yaml.Loader can instantiate arbitrary Python objects;
    # only use it with trusted config files.
    melgan_config = yaml.load(f, Loader=yaml.Loader)
melgan_config = MultiBandMelGANGeneratorConfig(
    **melgan_config["multiband_melgan_generator_params"])
melgan = TFMBMelGANGenerator(config=melgan_config, name='mb_melgan')
melgan._build()
melgan.load_weights(config_lp.multiband_melgan_pretrained_path)

# Optional (disabled): export the vocoder to TFLite.
# # Concrete Function
# melgan_concrete_function = melgan.inference_tflite.get_concrete_function()
# converter = tf.lite.TFLiteConverter.from_concrete_functions(
#     [melgan_concrete_function]
# )
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
#                                        tf.lite.OpsSet.SELECT_TF_OPS]
# tflite_model = converter.convert()
# # Save the TF Lite model.
# with open('./gen_model/melgan_baker.tflite', 'wb') as f:
#   f.write(tflite_model)