예제 #1
0
def main(cmd=None):
    """Entry point for TTS training.

    Thin wrapper that forwards ``cmd`` unchanged to ``TTSTask.main``.

    Args:
        cmd: Command-line style argument list handed to ``TTSTask.main``.
            Defaults to ``None``.

    Example:

        % python tts_train.py asr --print_config --optim adadelta
        % python tts_train.py --config conf/train_asr.yaml

    NOTE(review): the examples above mention ``asr``; they look like they
    were copied from the ASR training script -- confirm the intended usage.
    """
    TTSTask.main(cmd=cmd)
예제 #2
0
def config_file(tmp_path: Path, token_list):
    """Create a default configuration through a dry run and return its path.

    ``TTSTask.main`` is invoked in dry-run mode with ``tmp_path`` as the
    output directory; the returned path points at ``config.yaml`` inside
    that directory.

    Args:
        tmp_path: Directory used as ``--output_dir``.
        token_list: Value passed (stringified) as ``--token_list``.

    Returns:
        ``tmp_path / "config.yaml"``.
    """
    option_pairs = [
        ("--dry_run", "true"),
        ("--output_dir", str(tmp_path)),
        ("--token_list", str(token_list)),
        ("--token_type", "char"),
        ("--cleaner", "none"),
        ("--g2p", "none"),
        ("--normalize", "none"),
    ]
    # Flatten [(flag, value), ...] into [flag, value, flag, value, ...]
    argv = [item for pair in option_pairs for item in pair]
    TTSTask.main(cmd=argv)
    return tmp_path / "config.yaml"
예제 #3
0
    def __init__(
        self,
        train_config: Optional[Union[Path, str]],
        model_file: Optional[Union[Path, str]] = None,
        threshold: float = 0.5,
        minlenratio: float = 0.0,
        maxlenratio: float = 10.0,
        use_teacher_forcing: bool = False,
        use_att_constraint: bool = False,
        backward_window: int = 1,
        forward_window: int = 3,
        speed_control_alpha: float = 1.0,
        vocoder_conf: dict = None,
        dtype: str = "float32",
        device: str = "cpu",
    ):
        """Load the TTS model and prepare decoding / vocoder settings.

        Args:
            train_config: Training configuration passed to
                ``TTSTask.build_model_from_file``.
            model_file: Model parameter file passed to
                ``TTSTask.build_model_from_file``.
            threshold: Decoding option, forwarded only for Tacotron2 and
                Transformer models.
            minlenratio: Decoding option, forwarded only for Tacotron2 and
                Transformer models.
            maxlenratio: Decoding option, forwarded only for Tacotron2 and
                Transformer models.
            use_teacher_forcing: Decoding option, always forwarded.
            use_att_constraint: Decoding option, forwarded only for Tacotron2.
            backward_window: Decoding option, forwarded only for Tacotron2.
            forward_window: Decoding option, forwarded only for Tacotron2.
            speed_control_alpha: Forwarded as ``alpha`` for FastSpeech and
                FastSpeech2 models.
            vocoder_conf: Keyword arguments for ``Spectrogram2Waveform``.
                The dict is copied, so the caller's object is not modified.
            dtype: Name of the ``torch`` dtype the model is converted to.
            device: Device passed to ``TTSTask.build_model_from_file``.
        """
        assert check_argument_types()

        model, train_args = TTSTask.build_model_from_file(
            train_config, model_file, device)
        model.to(dtype=getattr(torch, dtype)).eval()
        self.device = device
        self.dtype = dtype
        self.train_args = train_args
        self.model = model
        self.tts = model.tts
        self.normalize = model.normalize
        self.feats_extract = model.feats_extract
        self.duration_calculator = DurationCalculator()
        self.preprocess_fn = TTSTask.build_preprocess_fn(train_args, False)
        self.use_teacher_forcing = use_teacher_forcing

        logging.info(f"Normalization:\n{self.normalize}")
        logging.info(f"TTS:\n{self.tts}")

        # Only pass the decoding options that the loaded model type accepts.
        decode_config = {}
        if isinstance(self.tts, (Tacotron2, Transformer)):
            decode_config.update({
                "threshold": threshold,
                "maxlenratio": maxlenratio,
                "minlenratio": minlenratio,
            })
        if isinstance(self.tts, Tacotron2):
            decode_config.update({
                "use_att_constraint": use_att_constraint,
                "forward_window": forward_window,
                "backward_window": backward_window,
            })
        if isinstance(self.tts, (FastSpeech, FastSpeech2)):
            decode_config.update({"alpha": speed_control_alpha})
        decode_config.update({"use_teacher_forcing": use_teacher_forcing})

        self.decode_config = decode_config

        # Work on a private copy: the feature-extractor parameters merged in
        # below must not leak into the dict owned by the caller.
        vocoder_conf = {} if vocoder_conf is None else dict(vocoder_conf)
        if self.feats_extract is not None:
            vocoder_conf.update(self.feats_extract.get_parameters())
        if ("n_fft" in vocoder_conf and "n_shift" in vocoder_conf
                and "fs" in vocoder_conf):
            self.spc2wav = Spectrogram2Waveform(**vocoder_conf)
            logging.info(f"Vocoder: {self.spc2wav}")
        else:
            self.spc2wav = None
            logging.info(
                "Vocoder is not used because vocoder_conf is not sufficient")
예제 #4
0
def inference(
    output_dir: str,
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    threshold: float,
    minlenratio: float,
    maxlenratio: float,
    use_teacher_forcing: bool,
    use_att_constraint: bool,
    backward_window: int,
    forward_window: int,
    speed_control_alpha: float,
    allow_variable_data_keys: bool,
    vocoder_conf: dict,
):
    """Perform TTS model decoding.

    Every utterance yielded by the streaming iterator is decoded one at a
    time with ``Text2Speech`` and the results are written below
    ``output_dir``:

    * ``norm/`` and ``denorm/``: features written through ``NpyScpWriter``.
    * ``speech_shape/speech_shape``: one ``<key> <shape>`` line per utterance.
    * ``durations/durations`` and ``focus_rates/focus_rates``: written only
      when a duration is returned.
    * ``att_ws/<key>.png`` and ``probs/<key>.png``: plots, when available.
    * ``wav/<key>.wav``: 16-bit PCM audio, when a waveform is returned.

    After the loop, the attention-related directories are removed when the
    last ``att_ws`` is ``None`` and ``probs/`` when the last ``probs`` is
    ``None`` (also the case when no utterance was processed at all).

    Raises:
        NotImplementedError: If ``batch_size > 1`` or ``ngpu > 1``.
        RuntimeError: If the attention weights are neither 2- nor
            4-dimensional.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build model
    text2speech = Text2Speech(
        train_config=train_config,
        model_file=model_file,
        threshold=threshold,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        use_teacher_forcing=use_teacher_forcing,
        use_att_constraint=use_att_constraint,
        backward_window=backward_window,
        forward_window=forward_window,
        speed_control_alpha=speed_control_alpha,
        vocoder_conf=vocoder_conf,
        dtype=dtype,
        device=device,
    )

    # 3. Build data-iterator
    if not text2speech.use_speech:
        # Drop the "speech" entries when the model does not consume them.
        data_path_and_name_and_type = list(
            filter(lambda x: x[1] != "speech", data_path_and_name_and_type))
    loader = TTSTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=TTSTask.build_preprocess_fn(text2speech.train_args,
                                                  False),
        collate_fn=TTSTask.build_collate_fn(text2speech.train_args, False),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    # 4. Start for-loop
    output_dir = Path(output_dir)
    (output_dir / "norm").mkdir(parents=True, exist_ok=True)
    (output_dir / "denorm").mkdir(parents=True, exist_ok=True)
    (output_dir / "speech_shape").mkdir(parents=True, exist_ok=True)
    (output_dir / "wav").mkdir(parents=True, exist_ok=True)
    (output_dir / "att_ws").mkdir(parents=True, exist_ok=True)
    (output_dir / "probs").mkdir(parents=True, exist_ok=True)
    (output_dir / "durations").mkdir(parents=True, exist_ok=True)
    (output_dir / "focus_rates").mkdir(parents=True, exist_ok=True)

    # Lazy load to avoid the backend error
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator

    # These are inspected after the loop; initialise them so that an empty
    # loader does not end in an UnboundLocalError during the cleanup below.
    att_ws = None
    probs = None

    with NpyScpWriter(
        output_dir / "norm",
        output_dir / "norm/feats.scp",
    ) as norm_writer, NpyScpWriter(
        output_dir / "denorm",
        output_dir / "denorm/feats.scp",
    ) as denorm_writer, open(
        output_dir / "speech_shape/speech_shape", "w"
    ) as shape_writer, open(
        output_dir / "durations/durations", "w"
    ) as duration_writer, open(
        output_dir / "focus_rates/focus_rates", "w"
    ) as focus_rate_writer:
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert _bs == 1, _bs

            # Change to single sequence and remove *_length
            # because inference() requires 1-seq, not mini-batch.
            batch = {
                k: v[0]
                for k, v in batch.items() if not k.endswith("_lengths")
            }

            start_time = time.perf_counter()
            wav, outs, outs_denorm, probs, att_ws, duration, focus_rate = text2speech(
                **batch)

            key = keys[0]
            insize = next(iter(batch.values())).size(0) + 1
            logging.info("inference speed = {:.1f} frames / sec.".format(
                int(outs.size(0)) / (time.perf_counter() - start_time)))
            logging.info(f"{key} (size:{insize}->{outs.size(0)})")
            if outs.size(0) == insize * maxlenratio:
                logging.warning(
                    f"output length reaches maximum length ({key}).")

            norm_writer[key] = outs.cpu().numpy()
            shape_writer.write(f"{key} " + ",".join(map(str, outs.shape)) +
                               "\n")

            denorm_writer[key] = outs_denorm.cpu().numpy()

            if duration is not None:
                # Save duration and focus rates
                duration_writer.write(f"{key} " +
                                      " ".join(map(str,
                                                   duration.cpu().numpy())) +
                                      "\n")
                focus_rate_writer.write(f"{key} {float(focus_rate):.5f}\n")

                # Plot attention weight
                att_ws = att_ws.cpu().numpy()

                # Promote a single 2-D map to 4-D so that the plotting code
                # below can always iterate over two leading axes.
                if att_ws.ndim == 2:
                    att_ws = att_ws[None][None]
                elif att_ws.ndim != 4:
                    raise RuntimeError(
                        f"Must be 2 or 4 dimension: {att_ws.ndim}")

                w, h = plt.figaspect(att_ws.shape[0] / att_ws.shape[1])
                fig = plt.Figure(figsize=(
                    w * 1.3 * min(att_ws.shape[0], 2.5),
                    h * 1.3 * min(att_ws.shape[1], 2.5),
                ))
                fig.suptitle(f"{key}")
                axes = fig.subplots(att_ws.shape[0], att_ws.shape[1])
                if len(att_ws) == 1:
                    axes = [[axes]]
                for ax, att_w in zip(axes, att_ws):
                    for ax_, att_w_ in zip(ax, att_w):
                        ax_.imshow(att_w_.astype(np.float32), aspect="auto")
                        ax_.set_xlabel("Input")
                        ax_.set_ylabel("Output")
                        ax_.xaxis.set_major_locator(MaxNLocator(integer=True))
                        ax_.yaxis.set_major_locator(MaxNLocator(integer=True))

                fig.set_tight_layout({"rect": [0, 0.03, 1, 0.95]})
                fig.savefig(output_dir / f"att_ws/{key}.png")
                fig.clf()

            if probs is not None:
                # Plot stop token prediction
                probs = probs.cpu().numpy()

                fig = plt.Figure()
                ax = fig.add_subplot(1, 1, 1)
                ax.plot(probs)
                ax.set_title(f"{key}")
                ax.set_xlabel("Output")
                ax.set_ylabel("Stop probability")
                ax.set_ylim(0, 1)
                ax.grid(which="both")

                fig.set_tight_layout(True)
                fig.savefig(output_dir / f"probs/{key}.png")
                fig.clf()

            # TODO(kamo): Write scp
            if wav is not None:
                sf.write(f"{output_dir}/wav/{key}.wav", wav.numpy(),
                         text2speech.fs, "PCM_16")

    # remove duration related files if attention is not provided
    if att_ws is None:
        shutil.rmtree(output_dir / "att_ws")
        shutil.rmtree(output_dir / "durations")
        shutil.rmtree(output_dir / "focus_rates")
    if probs is None:
        shutil.rmtree(output_dir / "probs")
예제 #5
0
    def __init__(
        self,
        train_config: Union[Path, str] = None,
        model_file: Union[Path, str] = None,
        threshold: float = 0.5,
        minlenratio: float = 0.0,
        maxlenratio: float = 10.0,
        use_teacher_forcing: bool = False,
        use_att_constraint: bool = False,
        backward_window: int = 1,
        forward_window: int = 3,
        speed_control_alpha: float = 1.0,
        noise_scale: float = 0.667,
        noise_scale_dur: float = 0.8,
        vocoder_config: Union[Path, str] = None,
        vocoder_file: Union[Path, str] = None,
        dtype: str = "float32",
        device: str = "cpu",
        seed: int = 777,
        always_fix_seed: bool = False,
    ):
        """Initialize Text2Speech module.

        Loads the TTS model (and a vocoder when the model requires one),
        stores the inference helpers on the instance and collects the
        decoding options that apply to the loaded model type in
        ``self.decode_conf``.
        """
        assert check_argument_types()

        # --- model -------------------------------------------------------
        model, train_args = TTSTask.build_model_from_file(
            train_config, model_file, device)
        torch_dtype = getattr(torch, dtype)
        model.to(dtype=torch_dtype).eval()

        self.device = device
        self.dtype = dtype
        self.train_args = train_args
        self.model = model
        self.tts = model.tts
        self.normalize = model.normalize
        self.feats_extract = model.feats_extract
        self.duration_calculator = DurationCalculator()
        self.preprocess_fn = TTSTask.build_preprocess_fn(train_args, False)
        self.use_teacher_forcing = use_teacher_forcing
        self.seed = seed
        self.always_fix_seed = always_fix_seed

        # --- vocoder (only when the TTS model asks for one) ---------------
        self.vocoder = None
        if self.tts.require_vocoder:
            built_vocoder = TTSTask.build_vocoder_from_file(
                vocoder_config, vocoder_file, model, device)
            # Only torch modules are cast and switched to eval mode.
            if isinstance(built_vocoder, torch.nn.Module):
                built_vocoder.to(dtype=torch_dtype).eval()
            self.vocoder = built_vocoder

        logging.info(f"Extractor:\n{self.feats_extract}")
        logging.info(f"Normalizer:\n{self.normalize}")
        logging.info(f"TTS:\n{self.tts}")
        if self.vocoder is not None:
            logging.info(f"Vocoder:\n{self.vocoder}")

        # --- decoding options, selected by model type ---------------------
        tts = self.tts
        options = {"use_teacher_forcing": use_teacher_forcing}
        if isinstance(tts, (Tacotron2, Transformer)):
            options["threshold"] = threshold
            options["maxlenratio"] = maxlenratio
            options["minlenratio"] = minlenratio
        if isinstance(tts, Tacotron2):
            options["use_att_constraint"] = use_att_constraint
            options["forward_window"] = forward_window
            options["backward_window"] = backward_window
        if isinstance(tts, (FastSpeech, FastSpeech2, VITS)):
            options["alpha"] = speed_control_alpha
        if isinstance(tts, VITS):
            options["noise_scale"] = noise_scale
            options["noise_scale_dur"] = noise_scale_dur
        self.decode_conf = options
예제 #6
0
def inference(
    output_dir: str,
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    threshold: float,
    minlenratio: float,
    maxlenratio: float,
    use_att_constraint: bool,
    backward_window: int,
    forward_window: int,
    allow_variable_data_keys: bool,
    vocoder_conf: dict,
):
    """Perform TTS model decoding.

    Decodes every utterance yielded by the streaming iterator one at a time
    and writes below ``output_dir``:

    * ``norm/`` and ``denorm/``: features written through ``NpyScpWriter``.
    * ``att_ws/<key>.png``: attention weight plot.
    * ``probs/<key>.png``: stop-token probability plot.
    * ``wav/<key>.wav``: 16-bit PCM audio, only when ``vocoder_conf``
      (merged with the feature-extractor parameters) provides ``n_fft``,
      ``n_shift`` and ``fs``.

    ``vocoder_conf`` is copied before it is extended, so the caller's dict
    is left untouched.

    Raises:
        NotImplementedError: If ``batch_size > 1`` or ``ngpu > 1``.
        RuntimeError: If the attention weights are neither 2- nor
            4-dimensional.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build model
    model, train_args = TTSTask.build_model_from_file(train_config, model_file,
                                                      device)
    model.to(dtype=getattr(torch, dtype)).eval()
    tts = model.tts
    normalize = model.normalize
    logging.info(f"Normalization:\n{normalize}")
    logging.info(f"TTS:\n{tts}")

    # 3. Build data-iterator
    loader = TTSTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=TTSTask.build_preprocess_fn(train_args, False),
        collate_fn=TTSTask.build_collate_fn(train_args),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    # 4. Build converter from spectrogram to waveform
    # Copy first: the caller's dict must not pick up the extractor parameters.
    vocoder_conf = dict(vocoder_conf)
    if model.feats_extract is not None:
        vocoder_conf.update(model.feats_extract.get_parameters())
    if "n_fft" in vocoder_conf and "n_shift" in vocoder_conf and "fs" in vocoder_conf:
        spc2wav = Spectrogram2Waveform(**vocoder_conf)
        logging.info(f"Vocoder: {spc2wav}")
    else:
        spc2wav = None
        logging.info(
            "Vocoder is not used because vocoder_conf is not sufficient")

    # 5. Start for-loop
    output_dir = Path(output_dir)
    (output_dir / "norm").mkdir(parents=True, exist_ok=True)
    (output_dir / "denorm").mkdir(parents=True, exist_ok=True)
    (output_dir / "wav").mkdir(parents=True, exist_ok=True)
    (output_dir / "att_ws").mkdir(parents=True, exist_ok=True)
    (output_dir / "probs").mkdir(parents=True, exist_ok=True)

    # Lazy load to avoid the backend error.
    # Done once here instead of on every loop iteration.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator

    # The decoding options do not depend on the utterance; build them once.
    decode_conf = {
        "threshold": threshold,
        "maxlenratio": maxlenratio,
        "minlenratio": minlenratio,
    }
    if isinstance(tts, Tacotron2):
        decode_conf.update({
            "use_att_constraint": use_att_constraint,
            "forward_window": forward_window,
            "backward_window": backward_window,
        })

    with NpyScpWriter(
            output_dir / "norm",
            output_dir / "norm/feats.scp",
    ) as f, NpyScpWriter(output_dir / "denorm",
                         output_dir / "denorm/feats.scp") as g:
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            batch = to_device(batch, device)

            key = keys[0]
            # Change to single sequence and remove *_length
            # because inference() requires 1-seq, not mini-batch.
            _data = {
                k: v[0]
                for k, v in batch.items() if not k.endswith("_lengths")
            }
            start_time = time.perf_counter()

            outs, probs, att_ws = tts.inference(**_data, **decode_conf)
            insize = next(iter(_data.values())).size(0) + 1
            logging.info("inference speed = {:.1f} frames / sec.".format(
                int(outs.size(0)) / (time.perf_counter() - start_time)))
            logging.info(f"{key} (size:{insize}->{outs.size(0)})")
            if outs.size(0) == insize * maxlenratio:
                logging.warning(
                    f"output length reaches maximum length ({key}).")
            # Store the normalized output BEFORE denormalizing (see NOTE).
            f[key] = outs.cpu().numpy()

            # NOTE: normalize.inverse is in-place operation
            outs_denorm = normalize.inverse(outs[None])[0][0]
            g[key] = outs_denorm.cpu().numpy()

            # Plot attention weight
            att_ws = att_ws.cpu().numpy()

            # Promote a single 2-D map to 4-D so that the plotting code
            # below can always iterate over two leading axes.
            if att_ws.ndim == 2:
                att_ws = att_ws[None][None]
            elif att_ws.ndim != 4:
                raise RuntimeError(f"Must be 2 or 4 dimension: {att_ws.ndim}")

            w, h = plt.figaspect(att_ws.shape[0] / att_ws.shape[1])
            fig = plt.Figure(figsize=(
                w * 1.3 * min(att_ws.shape[0], 2.5),
                h * 1.3 * min(att_ws.shape[1], 2.5),
            ))
            fig.suptitle(f"{key}")
            axes = fig.subplots(att_ws.shape[0], att_ws.shape[1])
            if len(att_ws) == 1:
                axes = [[axes]]
            for ax, att_w in zip(axes, att_ws):
                for ax_, att_w_ in zip(ax, att_w):
                    ax_.imshow(att_w_.astype(np.float32), aspect="auto")
                    ax_.set_xlabel("Input")
                    ax_.set_ylabel("Output")
                    ax_.xaxis.set_major_locator(MaxNLocator(integer=True))
                    ax_.yaxis.set_major_locator(MaxNLocator(integer=True))

            fig.tight_layout(rect=[0, 0.03, 1, 0.95])
            fig.savefig(output_dir / f"att_ws/{key}.png")
            fig.clf()

            # Plot stop token prediction
            probs = probs.cpu().numpy()

            fig = plt.Figure()
            ax = fig.add_subplot(1, 1, 1)
            ax.plot(probs)
            ax.set_title(f"{key}")
            ax.set_xlabel("Output")
            ax.set_ylabel("Stop probability")
            ax.set_ylim(0, 1)
            ax.grid(which="both")

            fig.tight_layout()
            fig.savefig(output_dir / f"probs/{key}.png")
            fig.clf()

            # TODO(kamo): Write scp
            if spc2wav is not None:
                wav = spc2wav(outs_denorm.cpu().numpy())
                sf.write(f"{output_dir}/wav/{key}.wav", wav, spc2wav.fs,
                         "PCM_16")
예제 #7
0
def get_parser():
    """Return the argument parser built by ``TTSTask.get_parser``."""
    return TTSTask.get_parser()
예제 #8
0
def test_add_arguments_help():
    """Parsing ``--help`` with the TTS parser must raise ``SystemExit``."""
    tts_parser = TTSTask.get_parser()
    help_argv = ["--help"]
    with pytest.raises(SystemExit):
        tts_parser.parse_args(help_argv)
예제 #9
0
def inference(
    output_dir: str,
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    threshold: float,
    minlenratio: float,
    maxlenratio: float,
    use_att_constraint: bool,
    backward_window: int,
    forward_window: int,
    allow_variable_data_keys: bool,
    vocoder_conf: dict,
):
    """Perform TTS model decoding.

    Decodes every utterance yielded by the streaming iterator one at a time
    and writes below ``output_dir``:

    * ``norm/feats.{ark,scp}`` and ``denorm/feats.{ark,scp}``: features
      written through ``kaldiio.WriteHelper``.
    * ``wav/<key>.wav``: 16-bit PCM audio, only when ``vocoder_conf``
      (merged with the feature-extractor parameters) provides ``n_fft``,
      ``n_shift`` and ``fs``.

    ``use_att_constraint``, ``backward_window`` and ``forward_window`` are
    accepted but not used by this function. ``vocoder_conf`` is copied
    before it is extended, so the caller's dict is left untouched.

    Raises:
        NotImplementedError: If ``batch_size > 1`` or ``ngpu > 1``.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build model
    model, train_args = TTSTask.build_model_from_file(train_config, model_file,
                                                      device)
    model.to(dtype=getattr(torch, dtype)).eval()
    tts = model.tts
    normalize = model.normalize
    logging.info(f"Normalization:\n{normalize}")
    logging.info(f"TTS:\n{tts}")

    # 3. Build data-iterator
    loader = TTSTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=TTSTask.build_preprocess_fn(train_args, False),
        collate_fn=TTSTask.build_collate_fn(train_args),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    # 4. Build converter from spectrogram to waveform
    # Copy first: the caller's dict must not pick up the extractor parameters.
    vocoder_conf = dict(vocoder_conf)
    if model.feats_extract is not None:
        vocoder_conf.update(model.feats_extract.get_parameters())
    if "n_fft" in vocoder_conf and "n_shift" in vocoder_conf and "fs" in vocoder_conf:
        spc2wav = Spectrogram2Waveform(**vocoder_conf)
        logging.info(f"Vocoder: {spc2wav}")
    else:
        spc2wav = None
        logging.info(
            "Vocoder is not used because vocoder_conf is not sufficient")

    # 5. Start for-loop
    output_dir = Path(output_dir)
    (output_dir / "norm").mkdir(parents=True, exist_ok=True)
    (output_dir / "denorm").mkdir(parents=True, exist_ok=True)
    (output_dir / "wav").mkdir(parents=True, exist_ok=True)

    # FIXME(kamo): I think we shouldn't depend on kaldi-format any more.
    #  How about numpy or HDF5?
    #  >>> with NpyScpWriter() as f:
    with kaldiio.WriteHelper("ark,scp:{o}.ark,{o}.scp".format(
            o=output_dir / "norm/feats")) as f, kaldiio.WriteHelper(
                "ark,scp:{o}.ark,{o}.scp".format(o=output_dir /
                                                 "denorm/feats")) as g:
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            batch = to_device(batch, device)

            key = keys[0]
            # Change to single sequence and remove *_length
            # because inference() requires 1-seq, not mini-batch.
            _data = {
                k: v[0]
                for k, v in batch.items() if not k.endswith("_lengths")
            }
            start_time = time.perf_counter()

            # TODO(kamo): Now att_ws is not used.
            outs, probs, att_ws = tts.inference(
                **_data,
                threshold=threshold,
                maxlenratio=maxlenratio,
                minlenratio=minlenratio,
            )
            elapsed_sec = time.perf_counter() - start_time
            insize = next(iter(_data.values())).size(0)
            # seconds -> milliseconds, then divide by the number of frames
            logging.info("inference speed = {} msec / frame.".format(
                elapsed_sec * 1000 / int(outs.size(0))))
            logging.info(f"{key} (size:{insize}->{outs.size(0)})")
            if outs.size(0) == insize * maxlenratio:
                logging.warning(
                    f"output length reaches maximum length ({key}).")

            # Write the normalized features first: normalize.inverse is
            # treated as an in-place operation, so calling it beforehand
            # could overwrite ``outs`` with denormalized values.
            f[key] = outs.cpu().numpy()
            outs_denorm = normalize.inverse(outs[None])[0][0]
            g[key] = outs_denorm.cpu().numpy()

            # TODO(kamo): Write scp
            if spc2wav is not None:
                wav = spc2wav(outs_denorm.cpu().numpy())
                sf.write(f"{output_dir}/wav/{key}.wav", wav, spc2wav.fs,
                         "PCM_16")
예제 #10
0
def test_add_arguments():
    """Smoke test: building the TTS argument parser must not raise."""
    _ = TTSTask.get_parser()
예제 #11
0
def test_print_config_and_load_it(tmp_path):
    """Dump the config to a file, then parse it back through ``--config``."""
    dumped_config = tmp_path / "config.yaml"
    with dumped_config.open("w") as stream:
        TTSTask.print_config(stream)

    tts_parser = TTSTask.get_parser()
    argv = ["--config", str(dumped_config)]
    tts_parser.parse_args(argv)
예제 #12
0
def test_main_with_no_args():
    """``TTSTask.main`` with an empty argument list must raise ``SystemExit``."""
    no_args = []
    with pytest.raises(SystemExit):
        TTSTask.main(cmd=no_args)
예제 #13
0
def test_main_print_config():
    """``TTSTask.main`` with ``--print_config`` must raise ``SystemExit``."""
    argv = ["--print_config"]
    with pytest.raises(SystemExit):
        TTSTask.main(cmd=argv)
예제 #14
0
def test_main_help():
    """``TTSTask.main`` with ``--help`` must raise ``SystemExit``."""
    argv = ["--help"]
    with pytest.raises(SystemExit):
        TTSTask.main(cmd=argv)