Exemplo n.º 1
0
    def test_torchscript_infer(self):
        """Check that TorchScript ``infer`` reproduces the eager-mode result.

        The global RNG is reseeded to the same value immediately before each
        ``infer`` call, so any randomness consumed during inference is
        identical for the eager and the scripted run.
        """

        upsample_scales = [5, 5, 8]
        hop_length, kernel_size = 200, 5
        n_batch, n_time, n_freq = 2, 50, 25
        n_rnn = n_fc = n_classes = 128
        n_output, n_res_block, n_hidden = 64, 2, 32

        # Arguments are passed positionally, so their order matters.
        model = WaveRNN(
            upsample_scales,
            n_classes,
            hop_length,
            n_res_block,
            n_rnn,
            n_fc,
            kernel_size,
            n_freq,
            n_hidden,
            n_output,
        )
        model.eval()

        spec = torch.rand(n_batch, n_freq, n_time)

        torch.random.manual_seed(0)
        eager_result = model.infer(spec)

        # Reseed before scripting + inference, exactly mirroring the eager run.
        torch.random.manual_seed(0)
        scripted_result = torch_script(model).infer(spec)

        self.assertEqual(eager_result, scripted_result)
Exemplo n.º 2
0
    def test_infer_waveform(self):
        """Validate the output dimensions of a WaveRNN model's infer method.

        A batch of two inputs is used, the second declared as only
        ``n_time // 2`` frames long.

        The test checks two things:

        * The padded output covers the longest item: ``hop_length * n_time``
          samples.
        * Each returned waveform length equals its input length (in frames)
          times ``hop_length``.
        """

        upsample_scales = [5, 5, 8]
        n_rnn = 128
        n_fc = 128
        n_classes = 128
        hop_length = 200
        n_batch = 2
        n_time = 50
        n_freq = 25
        n_output = 64
        n_res_block = 2
        n_hidden = 32
        kernel_size = 5

        model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block,
                        n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output)
        # ``infer`` is an inference API: run it with the model in eval mode.
        model.eval()

        x = torch.rand(n_batch, n_freq, n_time)
        short_len = n_time // 2
        lengths = torch.tensor([n_time, short_len])
        # Only shapes are checked, so skip recording an autograd graph.
        with torch.no_grad():
            out, waveform_lengths = model.infer(x, lengths)

        assert out.size() == (n_batch, 1, hop_length * n_time)
        assert waveform_lengths[0] == hop_length * n_time
        # Scale the same frame count that was fed in. Writing
        # ``hop_length * n_time // 2`` halves the *product* instead, which
        # only agrees with this when n_time is even.
        assert waveform_lengths[1] == hop_length * short_len
Exemplo n.º 3
0
    def test_waveform(self):
        """Check the shape of a WaveRNN forward pass.

        The model receives a waveform of ``n_samples`` samples together with a
        4-D conditioning tensor. The result must have shape
        ``(n_batch, 1, n_samples, n_classes)``: one ``n_classes``-wide vector
        per waveform sample.
        """

        upsample_scales = [5, 5, 8]
        hop_length, kernel_size = 200, 5
        n_batch, n_time, n_freq = 2, 200, 100
        n_rnn = n_fc = n_classes = 512
        n_output, n_res_block, n_hidden = 64 * 4, 10, 128

        model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block,
                        n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output)

        # (n_time - kernel_size + 1) frames of hop_length samples each.
        # Presumably these are the frames left after an unpadded
        # kernel_size-wide convolution -- TODO confirm against WaveRNN.
        n_samples = hop_length * (n_time - kernel_size + 1)

        waveform = torch.rand(n_batch, 1, n_samples)
        conditioning = torch.rand(n_batch, 1, n_freq, n_time)
        result = model(waveform, conditioning)

        expected_shape = (n_batch, 1, n_samples, n_classes)
        assert result.size() == expected_shape
Exemplo n.º 4
0
 def _get_wavernn(self, *, dl_kwargs=None):
     """Instantiate WaveRNN and load its pretrained weights.

     Args:
         dl_kwargs: Optional keyword arguments forwarded unchanged to
             ``load_state_dict_from_url``. ``None`` means no extra arguments.

     Returns:
         A WaveRNN built from ``self._wavernn_params``, with the state dict
         downloaded from ``_BASE_URL/<self._wavernn_path>`` loaded, and
         switched to evaluation mode.
     """
     wavernn = WaveRNN(**self._wavernn_params)
     checkpoint_url = f'{_BASE_URL}/{self._wavernn_path}'
     fetch_options = dl_kwargs if dl_kwargs is not None else {}
     weights = load_state_dict_from_url(checkpoint_url, **fetch_options)
     wavernn.load_state_dict(weights)
     wavernn.eval()
     return wavernn