Example #1
import tensorflow as tf

# FastSpeechConfig/TFFastSpeech import paths follow the TensorFlowTTS package layout (assumed).
from tensorflow_tts.configs import FastSpeechConfig
from tensorflow_tts.models import TFFastSpeech


def test_fastspeech_trainable(num_hidden_layers, n_speakers):
    config = FastSpeechConfig(num_hidden_layers=num_hidden_layers, n_speakers=n_speakers)

    fastspeech = TFFastSpeech(config, name='fastspeech')
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

    # fake inputs
    input_ids = tf.convert_to_tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], tf.int32)
    attention_mask = tf.convert_to_tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], tf.int32)
    speaker_ids = tf.convert_to_tensor([0], tf.int32)
    duration_gts = tf.convert_to_tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], tf.int32)

    mel_gts = tf.random.uniform(shape=[1, 10, 80], dtype=tf.float32)

    @tf.function
    def one_step_training():
        with tf.GradientTape() as tape:
            mel_outputs_before, _, duration_outputs = fastspeech(
                input_ids, attention_mask, speaker_ids, duration_gts, training=True)
            duration_loss = tf.keras.losses.MeanSquaredError()(duration_gts, duration_outputs)
            mel_loss = tf.keras.losses.MeanSquaredError()(mel_gts, mel_outputs_before)
            loss = duration_loss + mel_loss
        gradients = tape.gradient(loss, fastspeech.trainable_variables)
        optimizer.apply_gradients(zip(gradients, fastspeech.trainable_variables))

        tf.print(loss)

    import time
    for i in range(2):
        if i == 1:
            start = time.time()
        one_step_training()
    print(time.time() - start)
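In the upstream test suite this function is presumably parametrized by pytest; to run the snippet directly, a minimal call with illustrative argument values (an assumption, not taken from the source) is:

# Any small positive integers are enough for a trainability smoke test.
test_fastspeech_trainable(num_hidden_layers=2, n_speakers=1)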
Example #2
from tensorflow_tts.configs import FastSpeechConfig  # import paths assumed, as in Example #1
from tensorflow_tts.models import TFFastSpeech


def test_fastspeech_resize_positional_embeddings(new_size):
    config = FastSpeechConfig()
    fastspeech = TFFastSpeech(config, name="fastspeech")
    fastspeech._build()
    fastspeech.save_weights("./test.h5")
    fastspeech.resize_positional_embeddings(new_size)
    fastspeech.load_weights("./test.h5", by_name=True, skip_mismatch=True)
Example #3
import yaml

from tensorflow_tts.configs import FastSpeechConfig  # import paths assumed, as in Example #1
from tensorflow_tts.models import TFFastSpeech


def get_model():
    # get_weight_path is a project-local helper (not shown) that resolves the bundled
    # config/weight file paths.
    with open(get_weight_path('fastspeech_config.yml')) as f:
        config = yaml.load(f, Loader=yaml.Loader)

    config = FastSpeechConfig(**config['fastspeech_params'])
    fastspeech = TFFastSpeech(config=config, name='fastspeech')
    fastspeech._build()
    fastspeech.load_weights(get_weight_path('fastspeech-150k.h5'))

    return fastspeech
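A sketch of how the returned model can be used for inference; the inference() signature (speaker_ids, speed_ratios) and its three return values mirror the call in Example #4, while the input ids below are placeholders:

import tensorflow as tf

fastspeech = get_model()
input_ids = tf.convert_to_tensor([[1, 2, 3, 4, 5]], tf.int32)  # placeholder character ids
mel_before, mel_after, durations = fastspeech.inference(
    input_ids,
    speaker_ids=tf.zeros([1], tf.int32),
    speed_ratios=tf.ones([1], tf.float32),
)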
Example #4
import argparse
import logging
import os

import numpy as np
import tensorflow as tf
import yaml
from tqdm import tqdm

from tensorflow_tts.configs import FastSpeechConfig  # import paths assumed, as in Example #1
from tensorflow_tts.models import TFFastSpeech

# CharactorDataset (the "charactor" spelling is the repo's own) lives in the example's
# local dataset module; the exact import path here is an assumption.
from fastspeech_dataset import CharactorDataset


def main():
    """Run FastSpeech decoding over a folder of character-id files."""
    parser = argparse.ArgumentParser(
        description="Decode soft-mel features from character ids with a trained FastSpeech "
        "model (see examples/fastspeech/decode_fastspeech.py for details).")
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        required=True,
        help="directory including ids/durations files.",
    )
    parser.add_argument("--outdir",
                        type=str,
                        required=True,
                        help="directory to save generated speech.")
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True,
                        help="checkpoint file to be loaded.")
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        required=True,
        help="yaml format configuration file. if not explicitly provided, "
        "it will be searched in the checkpoint directory. (default=None)",
    )
    parser.add_argument(
        "--batch-size",
        default=8,
        type=int,
        required=False,
        help="Batch size for inference.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    args = parser.parse_args()

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # load config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    if config["format"] == "npy":
        char_query = "*-ids.npy"
        char_load_fn = np.load
    else:
        raise ValueError("Only npy is supported.")

    # define data-loader
    dataset = CharactorDataset(
        root_dir=args.rootdir,
        charactor_query=char_query,
        charactor_load_fn=char_load_fn,
    )
    dataset = dataset.create(batch_size=args.batch_size)

    # define model and load checkpoint
    fastspeech = TFFastSpeech(
        config=FastSpeechConfig(**config["fastspeech_params"]),
        name="fastspeech")
    fastspeech._build()
    fastspeech.load_weights(args.checkpoint)

    for data in tqdm(dataset, desc="Decoding"):
        utt_ids = data["utt_ids"]
        char_ids = data["input_ids"]

        # fastspeech inference.
        masked_mel_before, masked_mel_after, duration_outputs = fastspeech.inference(
            char_ids,
            speaker_ids=tf.zeros(shape=[tf.shape(char_ids)[0]],
                                 dtype=tf.int32),
            speed_ratios=tf.ones(shape=[tf.shape(char_ids)[0]],
                                 dtype=tf.float32),
        )

        # convert to numpy
        masked_mel_befores = masked_mel_before.numpy()
        masked_mel_afters = masked_mel_after.numpy()

        for (utt_id, mel_before, mel_after,
             durations) in zip(utt_ids, masked_mel_befores, masked_mel_afters,
                               duration_outputs):
            # real len of mel predicted
            real_length = durations.numpy().sum()
            utt_id = utt_id.numpy().decode("utf-8")
            # save to folder.
            np.save(
                os.path.join(args.outdir, f"{utt_id}-fs-before-feats.npy"),
                mel_before[:real_length, :].astype(np.float32),
                allow_pickle=False,
            )
            np.save(
                os.path.join(args.outdir, f"{utt_id}-fs-after-feats.npy"),
                mel_after[:real_length, :].astype(np.float32),
                allow_pickle=False,
            )
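Assuming the script is saved as decode_fastspeech.py (the file its description points to), it would be invoked as, for example, python decode_fastspeech.py --rootdir ./dump/valid --outdir ./predictions --checkpoint ./checkpoints/model-150000.h5 --config ./conf/fastspeech.yaml (all paths are placeholders; --batch-size and --verbose are optional). For each utterance it writes {utt_id}-fs-before-feats.npy and {utt_id}-fs-after-feats.npy into --outdir, trimmed to the total predicted duration.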