Example #1
    def test_melspectrogram(self) -> None:
        config = read_config(self.resource_path / 'test_config.yaml')
        dsp = DSP.from_config(config)
        # Note: librosa.util.example_audio_file() is deprecated since librosa 0.8
        # and removed in 0.10; librosa.example(...) is its replacement.
        file = librosa.util.example_audio_file()
        y = dsp.load_wav(file)[:10000]
        mel = dsp.wav_to_mel(y)
        # Compare against a stored fixture to catch regressions in the DSP chain.
        expected = np.load(self.resource_path / 'test_mel.npy')
        np.testing.assert_allclose(expected, mel)
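
If the fixture ever needs regenerating after a deliberate DSP change, a minimal sketch could look like the following (assuming the same read_config/DSP API as above; the resource paths are assumptions, and the imports for read_config/DSP depend on the repository layout):

import librosa
import numpy as np

config = read_config('tests/resources/test_config.yaml')  # assumed path
dsp = DSP.from_config(config)
file = librosa.util.example_audio_file()  # deprecated; librosa.example(...) in newer versions
y = dsp.load_wav(file)[:10000]
np.save('tests/resources/test_mel.npy', dsp.wav_to_mel(y))  # assumed path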
Example #2
    def generate_samples(self, model: WaveRNN,
                         session: VocSession) -> Tuple[float, np.ndarray]:
        """
        Generates audio samples for cherry-picking models. Audio quality is
        evaluated via the L1 distance between the mel spectrograms of the
        predicted and target waveforms.
        """
        model.eval()
        mel_losses = []
        gen_wavs = []
        device = next(model.parameters()).device
        for i, sample in enumerate(session.val_set_samples, 1):
            m, x = sample['mel'], sample['x']
            if i > self.train_cfg['num_gen_samples']:
                break
            x = x[0].numpy()
            # Decode the quantized target back to a float waveform in [-1, 1].
            bits = 16 if self.dsp.voc_mode == 'MOL' else self.dsp.bits
            if self.dsp.mu_law and self.dsp.voc_mode != 'MOL':
                x = DSP.decode_mu_law(x, 2**bits, from_labels=True)
            else:
                x = DSP.label_2_float(x, bits)
            gen_wav = model.generate(mels=m,
                                     batched=self.train_cfg['gen_batched'],
                                     target=self.train_cfg['target'],
                                     overlap=self.train_cfg['overlap'],
                                     mu_law=self.dsp.mu_law,
                                     silent=True)

            gen_wavs.append(gen_wav)
            # Compare unnormalized mel spectrograms of target and generated audio.
            y_mel = self.dsp.wav_to_mel(x.squeeze(), normalize=False)
            y_mel = torch.tensor(y_mel).to(device)
            y_hat_mel = self.dsp.wav_to_mel(gen_wav, normalize=False)
            y_hat_mel = torch.tensor(y_hat_mel).to(device)
            loss = F.l1_loss(y_hat_mel, y_mel)
            mel_losses.append(loss.item())

            self.writer.add_audio(tag=f'Validation_Samples/target_{i}',
                                  snd_tensor=x,
                                  global_step=model.step,
                                  sample_rate=self.dsp.sample_rate)
            self.writer.add_audio(tag=f'Validation_Samples/generated_{i}',
                                  snd_tensor=gen_wav,
                                  global_step=model.step,
                                  sample_rate=self.dsp.sample_rate)

        # Return the mean mel L1 loss and the first generated waveform.
        return sum(mel_losses) / len(mel_losses), gen_wavs[0]
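
A hedged sketch of a call site for this method, assuming a trainer instance and a prepared VocSession as in the signature above (the variable names are assumptions):

# Hypothetical call site: mean mel L1 plus the first generated wav.
mean_mel_l1, first_wav = trainer.generate_samples(model=voc_model, session=session)
print(f'validation mel L1: {mean_mel_l1:.4f}')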
Example #3
    def __init__(self, tts_path: str, voc_path: str, device='cuda'):
        self.device = torch.device(device)
        tts_checkpoint = torch.load(tts_path, map_location=self.device)
        tts_config = tts_checkpoint['config']
        tts_model = ForwardTacotron.from_config(tts_config)
        tts_model.load_state_dict(tts_checkpoint['model'])
        # Put the TTS model on the target device and disable dropout for inference.
        self.tts_model = tts_model.to(self.device).eval()
        self.wavernn = WaveRNN.from_checkpoint(voc_path)
        self.melgan = torch.hub.load('seungwonpark/melgan', 'melgan')
        self.melgan.to(self.device).eval()
        self.cleaner = Cleaner.from_config(tts_config)
        self.tokenizer = Tokenizer()
        self.dsp = DSP.from_config(tts_config)
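
A minimal usage sketch; the surrounding class name is not shown in the excerpt, so Synthesizer and the checkpoint paths below are assumptions:

import torch

# Hypothetical usage; class name and paths are assumptions.
synth = Synthesizer(tts_path='checkpoints/forward_step90k.pt',
                    voc_path='checkpoints/wavernn_step500k.pt',
                    device='cuda' if torch.cuda.is_available() else 'cpu')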
Example #4
    parser.add_argument('--gta',
                        '-g',
                        action='store_true',
                        help='train wavernn on GTA features')
    parser.add_argument('--config',
                        metavar='FILE',
                        default='config.yaml',
                        help='The config containing all hyperparams.')
    args = parser.parse_args()

    config = read_config(args.config)
    paths = Paths(config['data_path'], config['voc_model_id'],
                  config['tts_model_id'])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Using device:', device)
    print('\nInitialising Model...\n')
    voc_model = WaveRNN.from_config(config).to(device)
    dsp = DSP.from_config(config)
    # The product of the upsample factors must equal the STFT hop length so
    # that upsampled mel frames align exactly with the raw audio samples.
    assert np.cumprod(
        config['vocoder']['model']['upsample_factors'])[-1] == dsp.hop_length

    optimizer = optim.Adam(voc_model.parameters())
    restore_checkpoint(model=voc_model,
                       optim=optimizer,
                       path=paths.voc_checkpoints / 'latest_model.pt',
                       device=device)

    voc_trainer = VocTrainer(paths=paths, dsp=dsp, config=config)
    voc_trainer.train(voc_model, optimizer, train_gta=args.gta)
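
For intuition on the upsampling assert above: the cumulative product of the upsample factors is the number of raw audio samples produced per mel frame, which must equal the STFT hop length. A small self-contained check (the factor values are an assumed example, not taken from the config):

import numpy as np

upsample_factors = [5, 5, 11]  # assumed example values
samples_per_frame = np.cumprod(upsample_factors)[-1]  # 5 * 5 * 11 = 275
assert samples_per_frame == 275  # would be compared against dsp.hop_length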
Example #5
    hifigan_parser = subparsers.add_parser('hifigan')

    args = parser.parse_args()

    assert args.vocoder in {'griffinlim', 'wavernn', 'melgan', 'hifigan'}, \
        "Please provide a valid vocoder! Choices: ['griffinlim', 'wavernn', 'melgan', 'hifigan']"

    checkpoint_path = args.checkpoint
    if checkpoint_path is None:
        config = read_config(args.config)
        paths = Paths(config['data_path'], config['voc_model_id'],
                      config['tts_model_id'])
        checkpoint_path = paths.forward_checkpoints / 'latest_model.pt'

    tts_model, config = load_forward_taco(checkpoint_path)
    dsp = DSP.from_config(config)

    voc_model, voc_dsp = None, None
    if args.vocoder == 'wavernn':
        voc_model, voc_config = load_wavernn(args.voc_checkpoint)
        voc_dsp = DSP.from_config(voc_config)

    out_path = Path('model_outputs')
    out_path.mkdir(parents=True, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tts_model.to(device)
    cleaner = Cleaner.from_config(config)
    tokenizer = Tokenizer()

    print(f'Using device: {device}\n')
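
A self-contained sketch of the subparser pattern used above, showing how args.vocoder could be populated (dest='vocoder' is an assumption about the omitted parser setup):

import argparse

parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest='vocoder', required=True)
for name in ('griffinlim', 'wavernn', 'melgan', 'hifigan'):
    subparsers.add_parser(name)

args = parser.parse_args(['hifigan'])
assert args.vocoder == 'hifigan'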
Example #6
if __name__ == "__main__":
    args = docopt(__doc__)

    in_dir = args["<in_dir>"]
    out_dir = args["<out_dir>"]
    num_workers = args["--num_workers"]
    num_workers = cpu_count() if num_workers is None else int(num_workers)
    preset = args["--preset"]
    if preset is not None:
        with open(preset) as f:
            hparams.parse_json(f.read())
    # Override hyper parameters
    hparams.parse(args["--hparams"])
    assert hparams.name == "WaveRNN"

    dsp = DSP(hparams)
    quant_path = os.path.join(out_dir, 'quant')
    mel_path = os.path.join(out_dir, 'mel')
    os.makedirs(quant_path, exist_ok=True)
    os.makedirs(mel_path, exist_ok=True)

    wav_files = get_files(in_dir)

    # This will take a while depending on the size of the dataset
    dataset_ids = []
    for path in wav_files:
        # Derive the utterance id from the file name (strip directory and extension).
        dataset_id = os.path.splitext(os.path.basename(path))[0]
        dataset_ids.append(dataset_id)
        m, x = convert_file(path)
        np.save(os.path.join(mel_path, f'{dataset_id}.npy'), m)
        np.save(os.path.join(quant_path, f'{dataset_id}.npy'), x)
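
The excerpt computes num_workers but the loop shown runs serially; a hedged sketch of parallelizing it (assuming convert_file is importable at module level and picklable, and that the path variables above are visible to the workers):

from concurrent.futures import ProcessPoolExecutor

def process(path):
    # Same per-file work as the serial loop above.
    dataset_id = os.path.splitext(os.path.basename(path))[0]
    m, x = convert_file(path)
    np.save(os.path.join(mel_path, f'{dataset_id}.npy'), m)
    np.save(os.path.join(quant_path, f'{dataset_id}.npy'), x)
    return dataset_id

with ProcessPoolExecutor(max_workers=num_workers) as ex:
    dataset_ids = list(ex.map(process, wav_files))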