コード例 #1
0
ファイル: sdk_api.py プロジェクト: waws520waws/ttskit
def load_models(mellotron_path=_mellotron_path,
                waveglow_path=_waveglow_path,
                ge2e_path=_ge2e_path,
                mellotron_hparams_path=_mellotron_hparams_path,
                **kwargs):
    """Load the Mellotron/WaveGlow models once and return the shared dataloader.

    If every path argument equals its module-level default,
    ``download_resource()`` is called first (presumably to fetch the default
    model files -- verify against its definition).  The models and the
    ``TextMelLoader`` are built only on the first call; later calls return the
    cached loader stored in the module global ``_dataloader``.

    Args:
        mellotron_path: Mellotron checkpoint passed to
            ``mellotron.load_mellotron_torch``; skipped when falsy.
        waveglow_path: WaveGlow checkpoint passed to
            ``waveglow.load_waveglow_torch``.  A falsy value or one of
            ``'_'``, ``'gf'``, ``'griffinlim'`` skips WaveGlow loading and
            leaves ``_use_waveglow`` untouched (presumably selecting a
            Griffin-Lim vocoder elsewhere -- confirm with callers).
        ge2e_path: Stored on the hparams as ``encoder_model_fpath``.
        mellotron_hparams_path: Path of a UTF-8 text file whose contents are
            passed to ``mellotron.create_hparams``.
        **kwargs: Accepted for call compatibility; not used.

    Returns:
        The module-level ``mellotron.TextMelLoader`` created with
        ``mode='test'``.  The same object is returned on every call.
    """
    global _use_waveglow
    global _dataloader

    # download_resource() is only invoked when all four paths are the
    # module defaults; any custom path bypasses it.
    if (mellotron_path == _mellotron_path and waveglow_path == _waveglow_path
            and ge2e_path == _ge2e_path
            and mellotron_hparams_path == _mellotron_hparams_path):
        download_resource()

    # Already initialised: return the cached loader (previously this returned
    # None, so repeat callers got a different result than the first caller).
    if _dataloader is not None:
        return _dataloader

    # These sentinel values mean "no WaveGlow checkpoint to load".
    if waveglow_path and waveglow_path not in {'_', 'gf', 'griffinlim'}:
        waveglow.load_waveglow_torch(waveglow_path)
        _use_waveglow = 1

    if mellotron_path:
        mellotron.load_mellotron_torch(mellotron_path)

    # Use a context manager so the hparams file handle is closed promptly
    # instead of being left for the garbage collector.
    with open(mellotron_hparams_path, encoding='utf8') as hparams_file:
        mellotron_hparams = mellotron.create_hparams(hparams_file.read())
    mellotron_hparams.encoder_model_fpath = ge2e_path
    _dataloader = mellotron.TextMelLoader(audiopaths_and_text='',
                                          hparams=mellotron_hparams,
                                          speaker_ids=None,
                                          mode='test')
    return _dataloader
コード例 #2
0
ファイル: demo_inference.py プロジェクト: wushidong17/zhrtvc
                                                  'hparams.json').__str__()
        texts_path = workdir.joinpath('metadata', 'validation.txt').__str__()
        output_dir = workdir.joinpath(
            'test', f'{mellotron_stem}.{waveglow_stem}').__str__()
    else:
        mellotron_hparams_path = args.mellotron_hparams
        texts_path = args.input
        output_dir = args.output

    # 模型导入
    load_models(args)

    mellotron_hparams = mellotron.create_hparams(
        open(mellotron_hparams_path, encoding='utf8').read())
    dataloader = mellotron.TextMelLoader(audiopaths_and_text='',
                                         hparams=mellotron_hparams,
                                         speaker_ids=None,
                                         mode='test')

    waveglow_kwargs = json.loads(args.waveglow_kwargs)
    # 模型测试
    with tempfile.TemporaryDirectory() as tmpdir:
        audio = os.path.join(tmpdir, 'audio_example.wav')
        pydub.AudioSegment.silent(3000, frame_rate=args.sampling_rate).export(
            audio, format='wav')

        text = '这是个试水的例子。'
        speaker = 'speaker'
        text_data, style_data, speaker_data, f0_data = transform_mellotron_input_data(
            dataloader=dataloader,
            text=text,
            speaker=speaker,