Example #1
def main():
    """Run training process."""
    parser = argparse.ArgumentParser(
        description="Train Tacotron2 (See detail in tensorflow_tts/bin/train-tacotron2.py)"
    )
    parser.add_argument(
        "--train-dir",
        default=None,
        type=str,
        help="directory including training data. ",
    )
    parser.add_argument(
        "--dev-dir",
        default=None,
        type=str,
        help="directory including development data. ",
    )
    parser.add_argument(
        "--use-norm", default=1, type=int, help="use norm-mels (1) or raw mels (0)."
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save checkpoints."
    )
    parser.add_argument(
        "--config", type=str, required=True, help="yaml format configuration file."
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument(
        "--mixed_precision",
        default=0,
        type=int,
        help="using mixed precision for generator or not.",
    )
    parser.add_argument(
        "--pretrained",
        default="",
        type=str,
        nargs="?",
        help="pretrained weights .h5 file to load weights from. Auto-skips non-matching layers.",
    )
    args = parser.parse_args()

    # return strategy
    STRATEGY = return_strategy()

    # set mixed precision config
    if args.mixed_precision == 1:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

    args.mixed_precision = bool(args.mixed_precision)
    args.use_norm = bool(args.use_norm)

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # check arguments
    if args.train_dir is None:
        raise ValueError("Please specify --train-dir")
    if args.dev_dir is None:
        raise ValueError("Please specify --valid-dir")

    # load and save config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config["version"] = tensorflow_tts.__version__

    # get dataset
    if config["remove_short_samples"]:
        mel_length_threshold = config["mel_length_threshold"]
    else:
        mel_length_threshold = 0

    if config["format"] == "npy":
        charactor_query = "*-ids.npy"
        mel_query = "*-raw-feats.npy" if args.use_norm is False else "*-norm-feats.npy"
        charactor_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("Only npy are supported.")

    train_dataset = CharactorMelDataset(
        dataset=config["tacotron2_params"]["dataset"],
        root_dir=args.train_dir,
        charactor_query=charactor_query,
        mel_query=mel_query,
        charactor_load_fn=charactor_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
        reduction_factor=config["tacotron2_params"]["reduction_factor"],
        use_fixed_shapes=config["use_fixed_shapes"],
    )

    # update max_mel_length and max_char_length to config
    config.update({"max_mel_length": int(train_dataset.max_mel_length)})
    config.update({"max_char_length": int(train_dataset.max_char_length)})

    with open(os.path.join(args.outdir, "config.yml"), "w") as f:
        yaml.dump(config, f, Dumper=yaml.Dumper)
    for key, value in config.items():
        logging.info(f"{key} = {value}")

    train_dataset = train_dataset.create(
        is_shuffle=config["is_shuffle"],
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync,
    )

    valid_dataset = CharactorMelDataset(
        dataset=config["tacotron2_params"]["dataset"],
        root_dir=args.dev_dir,
        charactor_query=charactor_query,
        mel_query=mel_query,
        charactor_load_fn=charactor_load_fn,
        mel_load_fn=mel_load_fn,
        mel_length_threshold=mel_length_threshold,
        reduction_factor=config["tacotron2_params"]["reduction_factor"],
        use_fixed_shapes=False,  # don't need apply fixed shape for evaluation.
    ).create(
        is_shuffle=config["is_shuffle"],
        allow_cache=config["allow_cache"],
        batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync,
    )

    # define trainer
    trainer = Tacotron2Trainer(
        config=config,
        strategy=STRATEGY,
        steps=0,
        epochs=0,
        is_mixed_precision=args.mixed_precision,
    )

    with STRATEGY.scope():
        # define model.
        tacotron_config = Tacotron2Config(**config["tacotron2_params"])
        tacotron2 = TFTacotron2(config=tacotron_config, training=True, name="tacotron2")
        tacotron2._build()  # build model to be able to load weights.
        tacotron2.summary()
        
        if args.pretrained:
            tacotron2.load_weights(args.pretrained, by_name=True, skip_mismatch=True)
            logging.info(f"Successfully loaded pretrained weights from {args.pretrained}.")

        # AdamW for tacotron2
        learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
            initial_learning_rate=config["optimizer_params"]["initial_learning_rate"],
            decay_steps=config["optimizer_params"]["decay_steps"],
            end_learning_rate=config["optimizer_params"]["end_learning_rate"],
        )

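        # wrap the schedule so the learning rate ramps up linearly over the
        # first warmup_steps, then hands off to the polynomial decay above.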
        learning_rate_fn = WarmUp(
            initial_learning_rate=config["optimizer_params"]["initial_learning_rate"],
            decay_schedule_fn=learning_rate_fn,
            warmup_steps=int(
                config["train_max_steps"]
                * config["optimizer_params"]["warmup_proportion"]
            ),
        )

        optimizer = AdamWeightDecay(
            learning_rate=learning_rate_fn,
            weight_decay_rate=config["optimizer_params"]["weight_decay"],
            beta_1=0.9,
            beta_2=0.98,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
        )

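        # touching .iterations forces the optimizer to create its variables
        # inside the strategy scope.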
        _ = optimizer.iterations

    # compile trainer
    trainer.compile(model=tacotron2, optimizer=optimizer)

    # start training
    try:
        trainer.fit(
            train_dataset,
            valid_dataset,
            saved_path=os.path.join(config["outdir"], "checkpoints/"),
            resume=args.resume,
        )
    except KeyboardInterrupt:
        trainer.save_checkpoint()
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
Example #2
def main():
    """Running decode tacotron-2 mel-spectrogram."""
    parser = argparse.ArgumentParser(
        description=
        "Decode mel-spectrogram from folder ids with trained Tacotron-2 "
        "(See detail in tensorflow_tts/example/tacotron2/decode_tacotron2.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        required=True,
        help="directory including ids/durations files.",
    )
    parser.add_argument("--outdir",
                        type=str,
                        required=True,
                        help="directory to save generated speech.")
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True,
                        help="checkpoint file to be loaded.")
    parser.add_argument("--use-norm",
                        default=1,
                        type=int,
                        help="usr norm-mels for train or raw.")
    parser.add_argument("--batch-size",
                        default=8,
                        type=int,
                        help="batch size.")
    parser.add_argument("--win-front", default=3, type=int, help="win-front.")
    parser.add_argument("--win-back", default=3, type=int, help="win-front.")
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        required=True,
        help="yaml format configuration file.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    args = parser.parse_args()

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # load config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    if config["format"] == "npy":
        char_query = "*-ids.npy"
        mel_query = "*-raw-feats.npy" if args.use_norm is False else "*-norm-feats.npy"
        char_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("Only npy is supported.")

    # define data-loader
    dataset = CharactorMelDataset(
        dataset=config["tacotron2_params"]["dataset"],
        root_dir=args.rootdir,
        charactor_query=char_query,
        mel_query=mel_query,
        charactor_load_fn=char_load_fn,
        mel_load_fn=mel_load_fn,
        reduction_factor=config["tacotron2_params"]["reduction_factor"])
    dataset = dataset.create(allow_cache=True, batch_size=args.batch_size)

    # define model and load checkpoint
    tacotron2 = TFTacotron2(
        config=Tacotron2Config(**config["tacotron2_params"]),
        name="tacotron2",
    )
    tacotron2._build()  # build model to be able load_weights.
    tacotron2.load_weights(args.checkpoint)

    # setup attention window: constrain attention to win_front/win_back frames around the current position during inference.
    tacotron2.setup_window(win_front=args.win_front, win_back=args.win_back)

    for data in tqdm(dataset, desc="[Decoding]"):
        utt_ids = data["utt_ids"]
        utt_ids = utt_ids.numpy()

        # tacotron2 inference.
        (
            mel_outputs,
            post_mel_outputs,
            stop_outputs,
            alignment_historys,
        ) = tacotron2.inference(
            input_ids=data["input_ids"],
            input_lengths=data["input_lengths"],
            speaker_ids=data["speaker_ids"],
        )

        # convert to numpy
        post_mel_outputs = post_mel_outputs.numpy()

        for i, post_mel_output in enumerate(post_mel_outputs):
            stop_token = tf.math.round(tf.nn.sigmoid(stop_outputs[i]))  # [T]
            real_length = tf.math.reduce_sum(
                tf.cast(tf.math.equal(stop_token, 0.0), tf.int32), -1)
            post_mel_output = post_mel_output[:real_length, :]

            saved_name = utt_ids[i].decode("utf-8")

            # save decoded mel-spectrogram to folder.
            np.save(
                os.path.join(args.outdir, f"{saved_name}-norm-feats.npy"),
                post_mel_output.astype(np.float32),
                allow_pickle=False,
            )
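
The trimming loop above drops frames after the predicted stop point: the stop logits are squashed with a sigmoid, rounded to 0/1, and the kept length is the count of frames still at 0. A tiny self-contained illustration of that computation with toy logits (not real model output):

import tensorflow as tf

# toy stop logits for one utterance: the first three frames are "keep"
# (sigmoid < 0.5 -> rounds to 0), the last two are past the stop point.
stop_logits = tf.constant([-6.0, -4.0, -1.0, 2.0, 5.0])
stop_token = tf.math.round(tf.nn.sigmoid(stop_logits))  # [0, 0, 0, 1, 1]
real_length = tf.math.reduce_sum(
    tf.cast(tf.math.equal(stop_token, 0.0), tf.int32), -1
)
print(int(real_length))  # -> 3, so post_mel_output[:3, :] is kept
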
Example #3
def main():
    """Running extract tacotron-2 durations."""
    parser = argparse.ArgumentParser(
        description="Extract durations from charactor with trained Tacotron-2 "
        "(See detail in tensorflow_tts/example/tacotron-2/extract_duration.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        required=True,
        help="directory including ids/durations files.",
    )
    parser.add_argument("--outdir",
                        type=str,
                        required=True,
                        help="directory to save generated speech.")
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True,
                        help="checkpoint file to be loaded.")
    parser.add_argument("--use-norm",
                        default=1,
                        type=int,
                        help="usr norm-mels for train or raw.")
    parser.add_argument("--batch-size",
                        default=8,
                        type=int,
                        help="batch size.")
    parser.add_argument("--win-front", default=2, type=int, help="win-front.")
    parser.add_argument("--win-back", default=2, type=int, help="win-front.")
    parser.add_argument("--save-alignment",
                        default=0,
                        type=int,
                        help="save-alignment.")
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        required=True,
        help="yaml format configuration file.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    args = parser.parse_args()

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # load config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    if config["format"] == "npy":
        char_query = "*-ids.npy"
        mel_query = "*-raw-feats.npy" if args.use_norm is False else "*-norm-feats.npy"
        char_load_fn = np.load
        mel_load_fn = np.load
    else:
        raise ValueError("Only npy is supported.")

    # define data-loader
    dataset = CharactorMelDataset(
        root_dir=args.rootdir,
        charactor_query=char_query,
        mel_query=mel_query,
        charactor_load_fn=char_load_fn,
        mel_load_fn=mel_load_fn,
        return_utt_id=True,
        return_guided_attention=False,
    )
    dataset = dataset.create(allow_cache=True, batch_size=args.batch_size)

    # define model and load checkpoint
    tacotron2 = TFTacotron2(
        config=Tacotron2Config(**config["tacotron2_params"]),
        training=True,  # enable teacher forcing mode.
        name="tacotron2",
    )
    tacotron2._build()  # build model to be able load_weights.
    tacotron2.load_weights(args.checkpoint)

    for data in tqdm(dataset, desc="[Extract Duration]"):
        utt_id, charactor, char_length, mel, mel_length = data
        utt_id = utt_id.numpy()

        # tacotron2 inference.
        mel_outputs, post_mel_outputs, stop_outputs, alignment_historys = tacotron2(
            charactor,
            char_length,
            speaker_ids=tf.zeros(shape=[tf.shape(charactor)[0]]),
            mel_outputs=mel,
            mel_lengths=mel_length,
            use_window_mask=True,
            win_front=args.win_front,
            win_back=args.win_back,
            training=True,
        )

        # convert to numpy
        alignment_historys = alignment_historys.numpy()

        for i, alignment in enumerate(alignment_historys):
            real_char_length = char_length[i].numpy() - 1  # minus 1 because chars include an eos token.
            real_mel_length = mel_length[i].numpy()
            alignment = alignment[:real_char_length, :real_mel_length]
            d = get_duration_from_alignment(alignment)  # [max_char_len]

            saved_name = utt_id[i].decode("utf-8")

            # check that lengths are compatible.
            assert (
                len(d) == real_char_length
            ), f"mismatch between len_char and len_durations: {len(d)} and {real_char_length}"

            assert (
                np.sum(d) == real_mel_length
            ), f"mismatch between sum_durations and len_mel: {np.sum(d)} and {real_mel_length}"

            # save D to folder.
            np.save(
                os.path.join(args.outdir, f"{saved_name}-durations.npy"),
                d.astype(np.int32),
                allow_pickle=False,
            )

            # save alignment to debug.
            if args.save_alignment == 1:
                figname = os.path.join(args.outdir,
                                       f"{saved_name}_alignment.png")
                fig = plt.figure(figsize=(8, 6))
                ax = fig.add_subplot(111)
                ax.set_title(f"Alignment of {saved_name}")
                im = ax.imshow(alignment,
                               aspect="auto",
                               origin="lower",
                               interpolation="none")
                fig.colorbar(im, ax=ax)
                xlabel = "Decoder timestep"
                plt.xlabel(xlabel)
                plt.ylabel("Encoder timestep")
                plt.tight_layout()
                plt.savefig(figname)
                plt.close()
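
get_duration_from_alignment() is used above but not defined in this example. A plausible sketch, assuming the usual argmax heuristic (each decoder frame is credited to the character it attends to most; the actual tensorflow_tts implementation may differ):

import numpy as np

def get_duration_from_alignment(alignment):
    """Turn an attention matrix into per-character durations.

    alignment: np.ndarray of shape [char_len, mel_len] with attention weights.
    Returns an int32 array of shape [char_len]; its sum equals mel_len, which
    is exactly what the assertions in the loop above check.
    """
    durations = np.zeros(alignment.shape[0], dtype=np.int32)
    for t in range(alignment.shape[1]):
        # credit this decoder frame to the most-attended character.
        durations[np.argmax(alignment[:, t])] += 1
    return durations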