Example #1
    def evaluate(self, model: WaveRNN, val_set: Dataset) -> float:
        model.eval()
        val_loss = 0
        device = next(model.parameters()).device
        for i, (x, y, m) in enumerate(val_set, 1):
            x, m, y = x.to(device), m.to(device), y.to(device)
            with torch.no_grad():
                y_hat = model(x, m)
                if model.mode == 'RAW':
                    y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
                elif model.mode == 'MOL':
                    y = y.float()
                y = y.unsqueeze(-1)
                loss = self.loss_func(y_hat, y)
                val_loss += loss.item()
        return val_loss / len(val_set)
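The self.loss_func used above is not shown on this page. As one possibility (an assumption, not taken from these examples), a plain cross-entropy loss is compatible with the shapes produced by the 'RAW' branch, where y_hat ends up as (batch, classes, time, 1) and y as (batch, time, 1) integer labels:

import torch
import torch.nn.functional as F

def raw_loss(y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Hypothetical loss_func for mode == 'RAW': multi-dimensional cross entropy
    # over the class dimension of y_hat against integer labels y.
    return F.cross_entropy(y_hat, y)

# Shape check with dummy tensors (batch=2, time=8, 512 classes = 9-bit audio).
y_hat = torch.randn(2, 512, 8, 1)
y = torch.randint(0, 512, (2, 8, 1))
print(raw_loss(y_hat, y).item())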
Example #2
    def evaluate(self, model: WaveRNN, val_set: Dataset) -> float:
        model.eval()
        val_loss = 0
        device = next(model.parameters()).device
        for i, batch in enumerate(val_set, 1):
            batch = to_device(batch, device=device)
            x, y, m = batch['x'], batch['y'], batch['mel']
            with torch.no_grad():
                y_hat = model(x, m)
                if model.mode == 'RAW':
                    y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
                elif model.mode == 'MOL':
                    y = y.float()
                y = y.unsqueeze(-1)
                loss = self.loss_func(y_hat, y)
                val_loss += loss.item()
        return val_loss / len(val_set)
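Example #2 additionally relies on a to_device() helper that is not included here. A minimal sketch of such a helper, assuming the batch is a plain dict of tensors (the real implementation may differ):

from typing import Dict

import torch

def to_device(batch: Dict[str, torch.Tensor],
              device: torch.device) -> Dict[str, torch.Tensor]:
    # Hypothetical helper: move every tensor in the batch dict to the target
    # device and leave any non-tensor entries untouched.
    return {key: val.to(device) if torch.is_tensor(val) else val
            for key, val in batch.items()}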
Example #3
    def generate_samples(self, model: WaveRNN,
                         session: VocSession) -> Tuple[float, list]:
        """
        Generates audio samples to cherry-pick models. To evaluate audio quality
        we calculate the l1 distance between mels of predictions and targets.
        """
        model.eval()
        mel_losses = []
        gen_wavs = []
        device = next(model.parameters()).device
        for i, sample in enumerate(session.val_set_samples, 1):
            m, x = sample['mel'], sample['x']
            if i > self.train_cfg['num_gen_samples']:
                break
            x = x[0].numpy()
            bits = 16 if self.dsp.voc_mode == 'MOL' else self.dsp.bits
            if self.dsp.mu_law and self.dsp.voc_mode != 'MOL':
                x = DSP.decode_mu_law(x, 2**bits, from_labels=True)
            else:
                x = DSP.label_2_float(x, bits)
            gen_wav = model.generate(mels=m,
                                     batched=self.train_cfg['gen_batched'],
                                     target=self.train_cfg['target'],
                                     overlap=self.train_cfg['overlap'],
                                     mu_law=self.dsp.mu_law,
                                     silent=True)

            gen_wavs.append(gen_wav)
            y_mel = self.dsp.wav_to_mel(x.squeeze(), normalize=False)
            y_mel = torch.tensor(y_mel).to(device)
            y_hat_mel = self.dsp.wav_to_mel(gen_wav, normalize=False)
            y_hat_mel = torch.tensor(y_hat_mel).to(device)
            loss = F.l1_loss(y_hat_mel, y_mel)
            mel_losses.append(loss.item())

            self.writer.add_audio(tag=f'Validation_Samples/target_{i}',
                                  snd_tensor=x,
                                  global_step=model.step,
                                  sample_rate=self.dsp.sample_rate)
            self.writer.add_audio(tag=f'Validation_Samples/generated_{i}',
                                  snd_tensor=gen_wav,
                                  global_step=model.step,
                                  sample_rate=self.dsp.sample_rate)

        return sum(mel_losses) / len(mel_losses), gen_wavs[0]
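The DSP.label_2_float and DSP.decode_mu_law calls above convert integer labels back to floating-point audio before the mel comparison. They are not defined on this page; the following is a sketch of how such conversions are commonly implemented in WaveRNN-style codebases, not necessarily the exact code used here:

import math

import numpy as np

def label_2_float(x: np.ndarray, bits: int) -> np.ndarray:
    # Map integer labels in [0, 2**bits - 1] to floats in [-1, 1].
    return 2 * x / (2 ** bits - 1.0) - 1.0

def decode_mu_law(y: np.ndarray, mu: int, from_labels: bool = True) -> np.ndarray:
    # Invert mu-law companding; optionally convert from integer labels first.
    if from_labels:
        y = label_2_float(y, int(math.log2(mu)))
    mu = mu - 1
    return np.sign(y) / mu * ((1 + mu) ** np.abs(y) - 1)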
Example #4
    def generate_samples(self, model: WaveRNN,
                         session: VocSession) -> Tuple[float, list]:
        """
        Generates audio samples to cherry-pick models. To evaluate audio quality
        we calculate the l1 distance between mels of predictions and targets.
        """
        model.eval()
        mel_losses = []
        gen_wavs = []
        device = next(model.parameters()).device
        for i, (m, x) in enumerate(session.val_set_samples, 1):
            if i > hp.voc_gen_num_samples:
                break
            x = x[0].numpy()
            bits = 16 if hp.voc_mode == 'MOL' else hp.bits
            if hp.mu_law and hp.voc_mode != 'MOL':
                x = decode_mu_law(x, 2**bits, from_labels=True)
            else:
                x = label_2_float(x, bits)
            gen_wav = model.generate(mels=m,
                                     save_path=None,
                                     batched=hp.voc_gen_batched,
                                     target=hp.voc_target,
                                     overlap=hp.voc_overlap,
                                     mu_law=hp.mu_law,
                                     silent=True)

            gen_wavs.append(gen_wav)
            y_mel = raw_melspec(x.squeeze())
            y_mel = torch.tensor(y_mel).to(device)
            y_hat_mel = raw_melspec(gen_wav)
            y_hat_mel = torch.tensor(y_hat_mel).to(device)
            loss = F.l1_loss(y_hat_mel, y_mel)
            mel_losses.append(loss.item())

            self.writer.add_audio(tag=f'Validation_Samples/target_{i}',
                                  snd_tensor=x,
                                  global_step=model.step,
                                  sample_rate=hp.sample_rate)
            self.writer.add_audio(tag=f'Validation_Samples/generated_{i}',
                                  snd_tensor=gen_wav,
                                  global_step=model.step,
                                  sample_rate=hp.sample_rate)

        return sum(mel_losses) / len(mel_losses), gen_wavs[0]
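Example #4 measures quality with raw_melspec(), which is also not shown. A plausible stand-in computes an un-normalized mel spectrogram so that the L1 distance between target and generated audio stays comparable; the parameter values below are assumptions rather than values read from hp:

import librosa
import numpy as np

def raw_melspec(wav: np.ndarray,
                sample_rate: int = 22050,
                n_fft: int = 2048,
                hop_length: int = 275,
                n_mels: int = 80) -> np.ndarray:
    # Hypothetical mel extraction without normalization, mirroring the
    # wav_to_mel(..., normalize=False) call in Example #3.
    mel = librosa.feature.melspectrogram(y=wav, sr=sample_rate, n_fft=n_fft,
                                         hop_length=hop_length, n_mels=n_mels)
    return mel.astype(np.float32)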
Example #5
                    upsample_factors=hp.voc_upsample_factors,
                    feat_dims=hp.num_mels,
                    compute_dims=hp.voc_compute_dims,
                    res_out_dims=hp.voc_res_out_dims,
                    res_blocks=hp.voc_res_blocks,
                    hop_length=hp.hop_length,
                    sample_rate=hp.sample_rate,
                    pad_val=hp.voc_pad_val,
                    mode=hp.voc_mode).cuda()

    paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

    restore_path = args.weights if args.weights else paths.voc_latest_weights

    model.restore(restore_path)
    model.eval()
    if hp.amp:
        model, _ = amp.initialize(model, [], opt_level='O3')

    simple_table([('Generation Mode', 'Batched' if batched else 'Unbatched'),
                  ('Target Samples', target if batched else 'N/A'),
                  ('Overlap Samples', overlap if batched else 'N/A')])

    k = model.get_step() // 1000

    for file_name in os.listdir(args.dir):
        if file_name.endswith('.npy'):
            mel = np.load(os.path.join(args.dir, file_name))
            mel = torch.tensor(mel).unsqueeze(0)

            batch_str = f'gen_batched_target{target}_overlap{overlap}' if batched else 'gen_NOT_BATCHED'
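The fragment above stops before the actual synthesis step. A plausible continuation of the loop, assuming the keyword-style model.generate() API shown in Examples #3 and #4 and a made-up output naming scheme, is sketched below:

            # Hypothetical continuation, not part of the scraped fragment:
            # synthesize audio from the loaded mel and write it next to the
            # input file; the naming scheme is an assumption.
            save_path = os.path.join(
                args.dir, f'{file_name[:-4]}__{k}k_steps_{batch_str}.wav')
            _ = model.generate(mels=mel,
                               save_path=save_path,
                               batched=batched,
                               target=target,
                               overlap=overlap,
                               mu_law=hp.mu_law)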
Example #6
def main():
    # Parse Arguments
    parser = argparse.ArgumentParser(description='TTS Generator')

    parser.add_argument('--mel',
                        type=str,
                        help='[string/path] path to test mel file')

    parser.add_argument('--hp_file',
                        metavar='FILE',
                        default='hparams.py',
                        help='The file to use for the hyperparameters')

    parser.add_argument('--batched',
                        '-b',
                        dest='batched',
                        action='store_true',
                        help='Fast Batched Generation')

    parser.add_argument(
        '--voc_weights',
        type=str,
        help='[string/path] Load in different WaveRNN (vocoder) weights',
        default="pretrained/wave_800K.pyt")

    args = parser.parse_args()

    if not os.path.exists('onnx'):
        os.mkdir('onnx')

    hp.configure(args.hp_file)

    device = torch.device('cpu')
    print('Using device:', device)

    #####
    print('\nInitialising WaveRNN Model...\n')
    # Instantiate WaveRNN Model
    voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                        fc_dims=hp.voc_fc_dims,
                        bits=hp.bits,
                        pad=hp.voc_pad,
                        upsample_factors=hp.voc_upsample_factors,
                        feat_dims=hp.num_mels,
                        compute_dims=hp.voc_compute_dims,
                        res_out_dims=hp.voc_res_out_dims,
                        res_blocks=hp.voc_res_blocks,
                        hop_length=hp.hop_length,
                        sample_rate=hp.sample_rate,
                        mode=hp.voc_mode).to(device)

    voc_load_path = args.voc_weights
    voc_model.load(voc_load_path)

    voc_upsampler = WaveRNNUpsamplerONNX(voc_model, args.batched,
                                         hp.voc_target, hp.voc_overlap)
    voc_infer = WaveRNNONNX(voc_model)

    voc_model.eval()
    voc_upsampler.eval()
    voc_infer.eval()

    opset_version = 11

    with torch.no_grad():
        mels = np.load(args.mel)
        mels = torch.from_numpy(mels)
        mels = mels.unsqueeze(0)
        mels = voc_upsampler.pad_tensor(mels)

        mels_onnx = mels.clone()

        torch.onnx.export(voc_upsampler,
                          mels_onnx,
                          "./onnx/wavernn_upsampler.onnx",
                          opset_version=opset_version,
                          do_constant_folding=True,
                          input_names=["mels"],
                          output_names=["upsample_mels", "aux"])

        mels, aux = voc_upsampler(mels)
        mels = mels[:, 550:-550, :]

        mels, aux = voc_upsampler.fold(mels, aux)

        h1, h2, x = voc_infer.get_initial_parameters(mels)

        aux_split = voc_infer.split_aux(aux)

        b_size, seq_len, _ = mels.size()

        if seq_len:
            m_t = mels[:, 0, :]

            a1_t, a2_t, a3_t, a4_t = \
                (a[:, 0, :] for a in aux_split)

            rnn_input = (m_t, a1_t, a2_t, a3_t, a4_t, h1, h2, x)
            torch.onnx.export(voc_infer,
                              rnn_input,
                              "./onnx/wavernn_rnn.onnx",
                              opset_version=opset_version,
                              do_constant_folding=True,
                              input_names=[
                                  "m_t", "a1_t", "a2_t", "a3_t", "a4_t", "h1",
                                  "h2", "x"
                              ],
                              output_names=["logits", "h1", "h2"])

    print('Done!')
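One way to sanity-check the exported graphs (not part of the script above) is to load them with onnxruntime. A minimal sketch, assuming the same input/output names passed to torch.onnx.export() and a mel file that has already been padded the way pad_tensor() does:

import numpy as np
import onnxruntime as ort

# Hypothetical smoke test for the exported upsampler graph; 'mel_padded.npy'
# is a placeholder for a mel array padded the same way as in the script above.
sess = ort.InferenceSession('./onnx/wavernn_upsampler.onnx')
mels = np.load('mel_padded.npy').astype(np.float32)[np.newaxis, ...]
upsample_mels, aux = sess.run(['upsample_mels', 'aux'], {'mels': mels})
print(upsample_mels.shape, aux.shape)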