예제 #1
0
def load_forward_taco(
        checkpoint_path: str) -> Tuple[ForwardTacotron, Dict[str, Any]]:
    """Restore a ForwardTacotron model together with its config.

    The checkpoint is deserialized onto the CPU. It is read as a mapping
    with a 'config' entry, which is handed to ``ForwardTacotron.from_config``
    to build the network, and a 'model' entry, which is loaded into that
    network as its state dict.

    Args:
        checkpoint_path: Path of the checkpoint file to read.

    Returns:
        A ``(model, config)`` pair: the model with the stored weights
        loaded, and the config object taken from the checkpoint.
    """
    print(f'Loading tts checkpoint {checkpoint_path}')

    # Always map tensors to the CPU, independent of where they were saved.
    cpu = torch.device('cpu')
    stored = torch.load(checkpoint_path, map_location=cpu)

    model_config = stored['config']
    restored = ForwardTacotron.from_config(model_config)
    restored.load_state_dict(stored['model'])

    print(f'Loaded forward taco with step {restored.get_step()}')
    return restored, model_config
예제 #2
0
 def __init__(self, tts_path: str, voc_path: str, device='cuda'):
     """Load every component needed to synthesize speech.

     Args:
         tts_path: Path of the ForwardTacotron checkpoint; it is read as a
             mapping with 'config' and 'model' entries.
         voc_path: Path of the WaveRNN checkpoint, passed to
             ``WaveRNN.from_checkpoint``.
         device: Device specifier used for deserializing the tts
             checkpoint and for placing the MelGAN vocoder.
     """
     self.device = torch.device(device)

     # --- Acoustic model -------------------------------------------------
     checkpoint = torch.load(tts_path, map_location=self.device)
     config = checkpoint['config']
     forward_taco = ForwardTacotron.from_config(config)
     forward_taco.load_state_dict(checkpoint['model'])
     # NOTE(review): the tts model is not explicitly moved to `device`
     # here -- confirm whether that happens elsewhere.
     self.tts_model = forward_taco

     # --- Vocoders -------------------------------------------------------
     self.wavernn = WaveRNN.from_checkpoint(voc_path)
     # MelGAN is fetched through torch.hub, then placed on the requested
     # device and switched to evaluation mode.
     self.melgan = torch.hub.load('seungwonpark/melgan', 'melgan')
     self.melgan.to(device).eval()

     # --- Text / audio helpers, configured from the tts config ------------
     self.cleaner = Cleaner.from_config(config)
     self.tokenizer = Tokenizer()
     self.dsp = DSP.from_config(config)
예제 #3
0
        config['git_hash'] = try_get_git_hash()
    dsp = DSP.from_config(config)
    paths = Paths(config['data_path'], config['voc_model_id'],
                  config['tts_model_id'])

    assert len(os.listdir(paths.alg)) > 0, f'Could not find alignment files in {paths.alg}, please predict ' \
                                           f'alignments first with python train_tacotron.py --force_align!'

    force_gta = args.force_gta
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    print('Using device:', device)

    # Instantiate Forward TTS Model
    print('\nInitialising Forward TTS Model...\n')
    model = ForwardTacotron.from_config(config).to(device)
    optimizer = optim.Adam(model.parameters())
    restore_checkpoint(model=model,
                       optim=optimizer,
                       path=paths.forward_checkpoints / 'latest_model.pt',
                       device=device)

    if force_gta:
        print('Creating Ground Truth Aligned Dataset...\n')
        train_set, val_set = get_tts_datasets(paths.data,
                                              8,
                                              r=1,
                                              model_type='forward',
                                              filter_attention=False,
                                              max_mel_len=None)
        create_gta_features(model, train_set, val_set, paths.gta)