Exemple #1
0
    def _assert_consistency(self, transform, tensor, *args):
        """Check that a scripted ``transform`` matches its eager counterpart.

        ``tensor`` and ``transform`` are both moved to ``self.device`` /
        ``self.dtype`` before scripting; ``args`` are forwarded unchanged to
        the eager and the scripted call.
        """
        tensor = tensor.to(device=self.device, dtype=self.dtype)
        transform = transform.to(device=self.device, dtype=self.dtype)

        scripted = torch_script(transform)

        expected = transform(tensor, *args)
        actual = scripted(tensor, *args)
        self.assertEqual(actual, expected)

    def _assert_consistency_complex(self, func, tensor):
        """Compare eager and scripted ``func`` on a complex-valued ``tensor``.

        ``tensor`` must already be complex; it is cast to ``self.device`` /
        ``self.complex_dtype``. The global RNG is re-seeded with the same
        value before each call so any randomness inside ``func`` is replayed
        identically for both versions.
        """
        assert tensor.is_complex()
        tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
        scripted_func = torch_script(func)

        def _run_seeded(fn):
            # Reset the RNG so eager and scripted runs see the same stream.
            torch.random.manual_seed(40)
            return fn(tensor)

        expected = _run_seeded(func)
        actual = _run_seeded(scripted_func)

        self.assertEqual(actual, expected)

    def _assert_consistency(self, func, tensor, shape_only=False):
        """Compare eager and scripted ``func`` on ``tensor``.

        ``tensor`` is cast to ``self.device`` / ``self.dtype``. The RNG is
        re-seeded with the same value before each call. When ``shape_only``
        is true, only the output shapes are compared (useful when outputs
        are not expected to match element-wise).
        """
        tensor = tensor.to(device=self.device, dtype=self.dtype)
        scripted_func = torch_script(func)

        def _run_seeded(fn):
            # Identical seed for both runs keeps random ops reproducible.
            torch.random.manual_seed(40)
            return fn(tensor)

        expected = _run_seeded(func)
        actual = _run_seeded(scripted_func)

        if shape_only:
            self.assertEqual(actual.shape, expected.shape)
        else:
            self.assertEqual(actual, expected)
Exemple #4
0
    def test_torchscript_consistency_forward(self):
        r"""Verify that scripting Emformer does not change the behavior of method `forward`."""
        input_dim, batch_size, num_frames = 128, 10, 400
        right_context_length = 1

        emformer = self._gen_model(input_dim, right_context_length)
        # ``inputs`` rather than ``input`` to avoid shadowing the builtin.
        inputs, lengths = self._gen_inputs(
            input_dim, batch_size, num_frames, right_context_length)
        scripted = torch_script(emformer)

        expected_out, expected_len = emformer(inputs, lengths)
        actual_out, actual_len = scripted(inputs, lengths)

        self.assertEqual(expected_out, actual_out)
        self.assertEqual(expected_len, actual_len)
Exemple #5
0
    def _test_torchscript(self, model):
        """Check that the scripted ``model`` reproduces the eager ``model``.

        The model is switched to eval mode, then both versions are fed the
        same seeded batch of random waveforms and random per-sample lengths;
        both returned tensors (output and lengths) must match.
        """
        model.eval()

        batch_size = 3
        num_frames = 1024

        torch.manual_seed(0)
        waveforms = torch.randn(batch_size, num_frames)
        lengths = torch.randint(low=0, high=num_frames, size=[batch_size])

        ref_out, ref_len = model(waveforms, lengths)

        scripted = torch_script(model)
        hyp_out, hyp_len = scripted(waveforms, lengths)

        self.assertEqual(hyp_out, ref_out)
        self.assertEqual(hyp_len, ref_len)
Exemple #6
0
    def test_info_wav(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.info` is torchscript-able and returns the same result"""
        audio_path = self.get_temp_path(
            f'{dtype}_{sample_rate}_{num_channels}.wav')
        # One second of audio: ``sample_rate`` frames.
        data = get_wav_data(dtype,
                            num_channels,
                            normalize=False,
                            num_frames=1 * sample_rate)
        save_wav(audio_path, data, sample_rate)

        ts_info_func = torch_script(py_info_func)

        py_info = py_info_func(audio_path)
        ts_info = ts_info_func(audio_path)

        # Use TestCase assertions instead of bare ``assert``: bare asserts are
        # stripped under ``python -O`` (making the test vacuous) and do not
        # report the mismatching values on failure.
        self.assertEqual(py_info.sample_rate, ts_info.sample_rate)
        self.assertEqual(py_info.num_frames, ts_info.num_frames)
        self.assertEqual(py_info.num_channels, ts_info.num_channels)
Exemple #7
0
    def test_save_wav(self, dtype, sample_rate, num_channels):
        """Saving via the scripted ``py_save_func`` matches the eager version.

        The same data is written once through ``py_save_func`` and once
        through its scripted form. Both files are reloaded, and both the
        sample rate and the samples are checked against the originals.
        """
        ts_save_func = torch_script(py_save_func)

        expected = get_wav_data(dtype, num_channels, normalize=False)
        py_path = self.get_temp_path(
            f'test_save_py_{dtype}_{sample_rate}_{num_channels}.wav')
        ts_path = self.get_temp_path(
            f'test_save_ts_{dtype}_{sample_rate}_{num_channels}.wav')
        enc, bps = get_enc_params(dtype)

        # Identical positional arguments (after the path) for both savers.
        save_args = (expected, sample_rate, True, None, enc, bps)
        py_save_func(py_path, *save_args)
        ts_save_func(ts_path, *save_args)

        py_data, py_sr = load_wav(py_path, normalize=False)
        ts_data, ts_sr = load_wav(ts_path, normalize=False)

        for loaded_sr in (py_sr, ts_sr):
            self.assertEqual(sample_rate, loaded_sr)
        for loaded_data in (py_data, ts_data):
            self.assertEqual(expected, loaded_data)
Exemple #8
0
    def test_torchscript_consistency_infer(self):
        r"""Verify that scripting Emformer does not change the behavior of method `infer`."""
        input_dim, batch_size, num_frames = 128, 10, 400
        right_context_length = 1

        emformer = self._gen_model(input_dim, right_context_length).eval()
        scripted = torch_script(emformer).eval()

        # Each model carries its own state across successive ``infer`` calls;
        # the states must stay in lockstep.
        ref_state = scripted_state = None
        for _step in range(3):
            # The fourth ``_gen_inputs`` argument is 0 here.
            inputs, lengths = self._gen_inputs(
                input_dim, batch_size, num_frames, 0)

            ref_out, ref_len, ref_state = emformer.infer(
                inputs, lengths, ref_state)
            scripted_out, scripted_len, scripted_state = scripted.infer(
                inputs, lengths, scripted_state)

            self.assertEqual(ref_out, scripted_out)
            self.assertEqual(ref_len, scripted_len)
            self.assertEqual(ref_state, scripted_state)
Exemple #9
0
    def test_apply_effects_tensor(self, args):
        """Scripted ``SoxEffectTensorTransform`` matches ``apply_effects_tensor``.

        ``args`` supplies ``effects`` and optionally ``num_channels``
        (default 2) and ``input_sample_rate`` (default 8000). Both the
        resulting sample rate and the waveform must agree.
        """
        effects = args['effects']
        channels_first = True
        num_channels = args.get("num_channels", 2)
        input_sr = args.get("input_sample_rate", 8000)

        trans = SoxEffectTensorTransform(effects, input_sr, channels_first)

        trans = torch_script(trans)

        wav = get_sinusoid(frequency=800,
                           sample_rate=input_sr,
                           n_channels=num_channels,
                           dtype='float32',
                           channels_first=channels_first)
        found, sr_found = trans(wav)
        expected, sr_expected = sox_effects.apply_effects_tensor(
            wav, input_sr, effects, channels_first)

        # TestCase assertion instead of bare ``assert``: it survives
        # ``python -O`` and reports both values, consistent with the line below.
        self.assertEqual(sr_found, sr_expected)
        self.assertEqual(expected, found)
Exemple #10
0
    def test_load_wav(self, dtype, sample_rate, num_channels, normalize,
                      channels_first):
        """`sox_io_backend.load` is torchscript-able and returns the same result"""
        audio_path = self.get_temp_path(
            f'test_load_{dtype}_{sample_rate}_{num_channels}_{normalize}.wav')
        # Write one second (``sample_rate`` frames) of un-normalized audio.
        data = get_wav_data(dtype,
                            num_channels,
                            normalize=False,
                            num_frames=sample_rate)
        save_wav(audio_path, data, sample_rate)

        ts_load_func = torch_script(py_load_func)

        # Both loaders receive exactly the same keyword arguments.
        load_kwargs = {"normalize": normalize, "channels_first": channels_first}
        py_data, py_sr = py_load_func(audio_path, **load_kwargs)
        ts_data, ts_sr = ts_load_func(audio_path, **load_kwargs)

        self.assertEqual(py_sr, ts_sr)
        self.assertEqual(py_data, ts_data)