Ejemplo n.º 1
0
def pre_batch_transform(inputs, return_keys=None):
    """Derive time- and STFT-domain arrays from a single example.

    The ``speech_source`` and ``observation`` entries of
    ``inputs['audio_data']`` are each passed through ``stft`` with the
    positional arguments ``512`` and ``128`` (presumably size and shift --
    TODO confirm against the ``stft`` signature). The source transform is
    rearranged from ``'k t f'`` to ``'t k f'``; the observation transform is
    passed through the identity pattern ``'t f -> t f'``.

    Every output except ``target_mask`` is computed regardless of
    ``return_keys``; ``target_mask`` is only computed when it is requested.

    Args:
        inputs: Mapping with the keys ``'example_id'`` and ``'audio_data'``;
            the latter must provide ``'speech_source'`` and
            ``'observation'``.
        return_keys: Container of output keys to keep. ``None`` keeps all.

    Returns:
        dict holding the requested subset of ``example_id``, ``s``, ``y``,
        ``Y``, ``X_abs``, ``Y_abs``, ``num_frames``,
        ``cos_phase_difference`` and ``target_mask``.
    """
    audio = inputs['audio_data']
    speech_source = audio['speech_source']
    observation = audio['observation']

    source_stft = stft(speech_source, 512, 128)
    observation_stft = stft(observation, 512, 128)
    # The pattern leaves the axes where they are; it is kept because the
    # call still goes through einops' pattern handling for the input.
    observation_stft = einops.rearrange(observation_stft, 't f -> t f')
    source_stft = einops.rearrange(source_stft, 'k t f -> t k f')
    target_stft = source_stft  # Same for MERL database
    num_frames = observation_stft.shape[0]

    # All candidates are evaluated up front, then filtered by key.
    candidates = {
        'example_id': inputs['example_id'],
        's': np.ascontiguousarray(speech_source, dtype=np.float32),
        'y': np.ascontiguousarray(observation, dtype=np.float32),
        'Y': np.ascontiguousarray(observation_stft, dtype=np.complex64),
        'X_abs': np.ascontiguousarray(
            np.abs(target_stft), dtype=np.float32
        ),
        'Y_abs': np.ascontiguousarray(
            np.abs(observation_stft), dtype=np.float32
        ),
        'num_frames': num_frames,
        'cos_phase_difference': np.ascontiguousarray(
            np.cos(
                np.angle(observation_stft[:, None, :])
                - np.angle(target_stft)
            ),
            dtype=np.float32,
        ),
    }

    keep_all = return_keys is None
    return_dict = {
        key: value
        for key, value in candidates.items()
        if keep_all or key in return_keys
    }

    # The mask is the only entry that is skipped entirely when unrequested.
    if keep_all or 'target_mask' in return_keys:
        mask = ideal_binary_mask(source_stft, source_axis=-2)
        return_dict['target_mask'] = np.ascontiguousarray(
            mask, dtype=np.float32
        )

    return return_dict
Ejemplo n.º 2
0
 def test_compare_stft_to_numpy(self):
     # The module under test is compared against the reference ``stft``;
     # the reference's real and imaginary parts are concatenated along the
     # last axis before the comparison.
     reference = stft(
         self.time_signal,
         size=self.size,
         shift=self.shift,
         window_length=self.window_length,
         window=self.window,
         fading=self.fading,
     )
     real_part = np.real(reference)
     imag_part = np.imag(reference)
     expected = np.concatenate([real_part, imag_part], axis=-1)

     actual = self.stft(self.torch_signal).numpy()
     tc.assert_almost_equal(actual, expected)
Ejemplo n.º 3
0
 def test_restore_time_signal_from_numpy_stft_and_torch_istft(self):
     # Round trip: forward transform with the reference ``stft``, inverse
     # with the module under test, then compare to the original signal.
     spectrum = stft(
         self.time_signal,
         size=self.size,
         shift=self.shift,
         window_length=self.window_length,
         window=self.window,
         fading=self.fading,
     )
     restored = self.stft.inverse(torch.from_numpy(spectrum)).numpy()

     # Trim the last axis to the length of the original signal.
     num_samples = self.time_signal.shape[-1]
     restored = restored[..., :num_samples]
     tc.assert_almost_equal(restored, self.time_signal)
Ejemplo n.º 4
0
    def test_stft_frame_count(self):
        # Checks the output shape for signal lengths 1023, 1024 and 1025,
        # first with fading disabled and then with fading enabled.
        # Each entry: (fading flag, ((signal length, expected frames), ...)).
        transform = self.stft
        expectations = (
            (False, ((1023, 1), (1024, 1), (1025, 2))),
            (True, ((1023, 7), (1024, 7), (1025, 8))),
        )

        for fading, cases in expectations:
            transform.fading = fading
            for num_samples, expected_frames in cases:
                signal = torch.rand(size=[num_samples])
                spectrum = transform(signal)
                tc.assert_equal(
                    spectrum.shape, (expected_frames, self.num_features)
                )
Ejemplo n.º 5
0
    def test_stft_frame_count(self):
        # Checks the output shape for signal lengths 1019, 1020 and 1021,
        # first with fading disabled and then with fading enabled.
        # Each entry: (fading flag, ((signal length, expected frames), ...)).
        transform = self.stft
        expectations = (
            (False, ((1019, 50), (1020, 50), (1021, 51))),
            (True, ((1019, 52), (1020, 52), (1021, 53))),
        )

        for fading, cases in expectations:
            transform.fading = fading
            for num_samples, expected_frames in cases:
                signal = torch.rand(size=[num_samples])
                spectrum = transform(signal)
                tc.assert_equal(
                    spectrum.shape, (expected_frames, self.fbins * 2)
                )
Ejemplo n.º 6
0
 def test_compare_stft_to_numpy(self):
     # The output of the module under test is compared directly against the
     # reference ``stft`` called with the same parameters.
     expected = stft(
         self.time_signal,
         size=self.size,
         shift=self.shift,
         window_length=self.window_length,
         window=self.window,
         fading=self.fading,
     )
     actual = self.stft(self.torch_signal).numpy()
     tc.assert_almost_equal(actual, expected)