예제 #1
0
def setup_model(config):
    """Instantiate the TTS model named by ``config`` and size its character set.

    The model class is looked up under ``TTS.tts.models`` using
    ``config.base_model`` when that key is present and not None, otherwise
    ``config.model``. The number of input characters is then computed and
    written to ``config.num_chars`` (and to ``config.model_params`` /
    ``config.model_args`` when those keys exist) before the model is built.

    Args:
        config: Model config. Read for ``model``, ``base_model``,
            ``characters``, ``use_phonemes`` and ``add_blank``. It is mutated
            in place: ``characters`` is filled in when it was None, and
            ``num_chars`` is always set.

    Returns:
        The object produced by calling the resolved model class with ``config``.
    """
    print(f" > Using model: {config.model}")
    # fetch the right model implementation.
    if "base_model" in config and config["base_model"] is not None:
        MyModel = find_module("TTS.tts.models", config.base_model.lower())
    else:
        MyModel = find_module("TTS.tts.models", config.model.lower())
    # define set of characters used by the model
    if config.characters is not None:
        # set characters from config
        if hasattr(MyModel, "make_symbols"):
            # the model class builds its own symbol table from the config
            symbols = MyModel.make_symbols(config)
        else:
            symbols, phonemes = make_symbols(**config.characters)
            # Count the phoneme set when the model is fed phonemes, the same
            # way the default-characters branch below does. Previously the
            # phoneme list was dropped here and `num_chars` was always the
            # grapheme count.
            if getattr(config, "use_phonemes", False):
                symbols = phonemes
    else:
        from TTS.tts.utils.text.symbols import (  # pylint: disable=import-outside-toplevel
            phonemes, symbols,
        )

        if config.use_phonemes:
            symbols = phonemes  # noqa: F811
        # use default characters and assign them to config
        config.characters = parse_symbols()
    # consider special `blank` character if `add_blank` is set True
    # (a bool adds as 0/1)
    num_chars = len(symbols) + getattr(config, "add_blank", False)
    config.num_chars = num_chars
    # compatibility fix: mirror the count into whichever nested args
    # container this config carries
    if "model_params" in config:
        config.model_params.num_chars = num_chars
    if "model_args" in config:
        config.model_args.num_chars = num_chars
    return MyModel(config)
예제 #2
0
def process_args(args):
    """Process parsed command line arguments.

    Args:
        args (argparse.Namespace or tuple): Parsed input arguments, or a
            ``(args, coqpit_overrides)`` pair where ``coqpit_overrides`` is
            handed to ``config.parse_known_args`` to override config values.

    Returns:
        config: Config loaded from ``args.config_path``.
        experiment_path (str): Path to save models and logging.
        audio_path (str): Path to save generated test audios.
        c_logger (ConsoleLogger): Class that does logging to the console.
        tb_logger (TensorboardLogger or None): Class that does the
            TensorBoard logging; None unless ``args.rank == 0``.
    """
    # Bind the overrides up front so a plain namespace (non-tuple) input does
    # not hit an unbound name further down.
    coqpit_overrides = None
    if isinstance(args, tuple):
        args, coqpit_overrides = args
    if args.continue_path:
        # continue a previous training from its output folder
        args.config_path = os.path.join(args.continue_path, "config.json")
        args.restore_path, best_model = get_last_checkpoint(args.continue_path)
        if not args.best_path:
            args.best_path = best_model
    # setup output paths and read configs
    config = load_config(args.config_path)
    # override values from command-line args (only when the caller actually
    # supplied an overrides object alongside `args`)
    if coqpit_overrides is not None:
        config.parse_known_args(coqpit_overrides, relaxed_parser=True)
    if config.mixed_precision:
        print("   >  Mixed precision mode is ON")
    experiment_path = args.continue_path
    if not experiment_path:
        experiment_path = create_experiment_folder(config.output_path,
                                                   config.run_name, args.debug)
    audio_path = os.path.join(experiment_path, "test_audios")
    # setup rank 0 process in distributed training
    tb_logger = None
    if args.rank == 0:
        os.makedirs(audio_path, exist_ok=True)
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        # Save the default character set with the copied config for future
        # compatibility.
        # NOTE(review): the condition checks for a `characters_config` field,
        # not for a missing `characters` value - confirm this is intended.
        if config.has("characters_config"):
            used_characters = parse_symbols()
            new_fields["characters"] = used_characters
        copy_model_files(config, experiment_path, new_fields)
        os.chmod(audio_path, 0o775)
        os.chmod(experiment_path, 0o775)
        tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
        # write model desc to tensorboard
        tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>",
                              0)
    c_logger = ConsoleLogger()
    return config, experiment_path, audio_path, c_logger, tb_logger
예제 #3
0
파일: base_tts.py 프로젝트: stjordanis/TTS
    def get_characters(config: Coqpit) -> tuple:
        """Resolve the character set for the model and count its entries.

        When ``config.characters`` is set, the symbol and phoneme lists are
        built from it with ``make_symbols``. Otherwise the defaults from
        ``TTS.tts.utils.text.symbols`` are used and ``config.characters`` is
        filled in from ``parse_symbols()``.

        Args:
            config (Coqpit): Model config. Read for ``characters``,
                ``use_phonemes`` and ``add_blank``; ``characters`` is assigned
                in place when it was None.

        Returns:
            tuple: ``(model_characters, config, num_chars)`` where
            ``model_characters`` is the phoneme list if ``config.use_phonemes``
            is truthy, else the symbol list, and ``num_chars`` is its length
            plus ``config.add_blank`` (taken as False when absent).
        """
        # TODO: implement CharacterProcessor
        if config.characters is not None:
            symbols, phonemes = make_symbols(**config.characters)
        else:
            from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols

            config.characters = CharactersConfig(**parse_symbols())
        model_characters = phonemes if config.use_phonemes else symbols
        # a bool `add_blank` adds as 0/1
        num_chars = len(model_characters) + getattr(config, "add_blank", False)
        return model_characters, config, num_chars
예제 #4
0
    def get_characters(config: Coqpit):
        """Pick the symbol table for the model and report its size.

        With no ``config.characters`` set, the defaults from
        ``TTS.tts.utils.text.symbols`` are used (phonemes when
        ``config.use_phonemes`` is truthy) and ``config.characters`` is
        assigned from ``parse_symbols()``. Otherwise the table comes from
        ``Vits.make_symbols(config)``.

        Returns:
            ``(symbols, config, num_chars)``; ``num_chars`` is the table length
            plus ``config.add_blank`` (taken as False when absent).
        """
        if config.characters is None:
            from TTS.tts.utils.text.symbols import (  # pylint: disable=import-outside-toplevel
                parse_symbols,
                phonemes,
                symbols as default_symbols,
            )

            # persist the default character definition on the config first
            config.characters = parse_symbols()
            chosen = phonemes if config.use_phonemes else default_symbols
        else:
            chosen = Vits.make_symbols(config)
        total = len(chosen) + getattr(config, "add_blank", False)
        return chosen, config, total
예제 #5
0
    def get_characters(config: Coqpit) -> tuple:
        """Resolve the character set the model should use.

        When ``config.characters`` is set, the symbol and phoneme lists are
        built from it with ``make_symbols``. Otherwise the defaults from
        ``TTS.tts.utils.text.symbols`` are used and ``config.characters`` is
        assigned from ``parse_symbols()``.

        Args:
            config (Coqpit): Model config. Read for ``characters`` and
                ``use_phonemes``; ``characters`` is assigned in place when it
                was None.

        Returns:
            tuple: ``(model_characters, config)`` where ``model_characters`` is
            the phoneme list if ``config.use_phonemes`` is truthy, else the
            symbol list.
        """
        # TODO: implement CharacterProcessor
        if config.characters is not None:
            symbols, phonemes = make_symbols(**config.characters)
        else:
            from TTS.tts.utils.text.symbols import (  # pylint: disable=import-outside-toplevel
                parse_symbols,
                phonemes,
                symbols,
            )

            config.characters = parse_symbols()
        model_characters = phonemes if config.use_phonemes else symbols
        return model_characters, config
예제 #6
0
def process_args(args, model_class):
    """Process parsed command line arguments based on model class (tts or vocoder).

    Args:
        args (argparse.Namespace or dict like): Parsed input arguments. When
            ``args.continue_path`` is set, ``output_path``, ``config_path``,
            ``restore_path`` and (if unset) ``best_path`` are overwritten on
            ``args`` so training resumes from that folder.
        model_class (str): Model class name. The value ``'tts'`` enables
            saving the default character set, and the upper-cased value is
            used as the TensorBoard model name.

    Returns:
        c: Config loaded from ``args.config_path``.
        out_path (str): Path to save models and logging.
        audio_path (str): Path to save generated test audios.
        c_logger (ConsoleLogger): Class that does logging to the console.
        tb_logger (TensorboardLogger or None): Class that does the
            TensorBoard logging; None unless ``args.rank == 0``.
    """
    if args.continue_path:
        # continue a previous training from its output folder
        args.output_path = args.continue_path
        args.config_path = os.path.join(args.continue_path, "config.json")
        args.restore_path, best_model = get_last_checkpoint(args.continue_path)
        if not args.best_path:
            args.best_path = best_model

    # setup output paths and read configs
    c = load_config(args.config_path)

    if 'mixed_precision' in c and c.mixed_precision:
        print("   >  Mixed precision mode is ON")

    out_path = args.continue_path
    if not out_path:
        out_path = create_experiment_folder(c.output_path, c.run_name,
                                            args.debug)

    audio_path = os.path.join(out_path, "test_audios")

    c_logger = ConsoleLogger()
    tb_logger = None

    # only the rank 0 process writes files and owns the TensorBoard logger
    if args.rank == 0:
        os.makedirs(audio_path, exist_ok=True)
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        # if model characters are not set in the config file
        # save the default set to the config file for future
        # compatibility.
        if model_class == 'tts' and 'characters' not in c:
            used_characters = parse_symbols()
            new_fields['characters'] = used_characters
        copy_model_files(c, args.config_path, out_path, new_fields)
        os.chmod(audio_path, 0o775)
        os.chmod(out_path, 0o775)

        log_path = out_path

        tb_logger = TensorboardLogger(log_path, model_name=model_class.upper())

        # write model desc to tensorboard
        tb_logger.tb_add_text("model-description", c["run_description"], 0)

    return c, out_path, audio_path, c_logger, tb_logger
예제 #7
0
파일: training.py 프로젝트: stjordanis/TTS
def process_args(args, config=None):
    """Process parsed command line arguments and initialize the config if not provided.

    Args:
        args (argparse.Namespace or tuple): Parsed input arguments, or a
            ``(args, coqpit_overrides)`` pair where ``coqpit_overrides`` is
            handed to ``parse_known_args`` to override config values.
        config (Coqpit): Model config. If None, it is loaded from
            ``args.config_path`` when that is set, otherwise built from the
            console arguments. Defaults to None.

    Returns:
        config: The loaded or initialized config, with ``output_log_path`` set
            to the experiment path.
        experiment_path (str): Path to save models and logging.
        audio_path (str): Path to save generated test audios.
        c_logger (ConsoleLogger): Class that does logging to the console.
        dashboard_logger: Value returned by ``init_dashboard_logger(config)``;
            None unless ``args.rank == 0``.

    TODO:
        - Interactive config definition.
    """
    # Bind the overrides up front so a plain namespace (non-tuple) input does
    # not hit an unbound name further down.
    coqpit_overrides = None
    if isinstance(args, tuple):
        args, coqpit_overrides = args
    if args.continue_path:
        # continue a previous training from its output folder
        args.config_path = os.path.join(args.continue_path, "config.json")
        args.restore_path, best_model = get_last_checkpoint(args.continue_path)
        if not args.best_path:
            args.best_path = best_model
    # init config if not already defined
    if config is None:
        if args.config_path:
            # init from a file
            config = load_config(args.config_path)
        else:
            # init from console args
            from TTS.config.shared_configs import BaseTrainingConfig  # pylint: disable=import-outside-toplevel

            config_base = BaseTrainingConfig()
            # NOTE(review): with no overrides this passes None; presumably
            # Coqpit then falls back to its own default argument source -
            # confirm against the Coqpit API.
            config_base.parse_known_args(coqpit_overrides)
            config = register_config(config_base.model)()
    # override values from command-line args (only when the caller actually
    # supplied an overrides object alongside `args`)
    if coqpit_overrides is not None:
        config.parse_known_args(coqpit_overrides, relaxed_parser=True)
    experiment_path = args.continue_path
    if not experiment_path:
        experiment_path = get_experiment_folder_path(config.output_path,
                                                     config.run_name)
    audio_path = os.path.join(experiment_path, "test_audios")
    config.output_log_path = experiment_path
    # setup rank 0 process in distributed training
    dashboard_logger = None
    if args.rank == 0:
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        # if model characters are not set in the config file
        # save the default set to the config file for future
        # compatibility.
        if config.has("characters") and config.characters is None:
            used_characters = parse_symbols()
            new_fields["characters"] = used_characters
        copy_model_files(config, experiment_path, new_fields)
        dashboard_logger = init_dashboard_logger(config)
    c_logger = ConsoleLogger()
    return config, experiment_path, audio_path, c_logger, dashboard_logger