def pre_batch_transform(inputs, return_keys=None):
    """Build the per-example feature dict from raw audio signals.

    Only the features that are actually requested are computed, so asking
    for cheap entries (e.g. ``'example_id'`` or the raw signals) does not
    trigger the STFTs.

    Args:
        inputs: Mapping providing ``inputs['example_id']`` and, when any
            audio-derived key is requested, ``inputs['audio_data']`` with
            the entries ``'speech_source'`` and ``'observation'``.
            NOTE(review): the rearrange patterns below imply that the STFT
            of the observation is 2-D and the STFT of the speech sources is
            3-D (source axis first) -- confirm against the data pipeline.
        return_keys: Container of key names to include in the result, or
            ``None`` to include every key.

    Returns:
        dict with (a subset of) the keys ``'example_id'``, ``'s'``, ``'y'``,
        ``'Y'``, ``'X_abs'``, ``'Y_abs'``, ``'num_frames'``,
        ``'cos_phase_difference'`` and ``'target_mask'``, inserted in that
        order. Array entries are C-contiguous; real-valued ones are
        ``float32`` and ``'Y'`` is ``complex64``.
    """
    all_keys = (
        'example_id', 's', 'y', 'Y', 'X_abs', 'Y_abs', 'num_frames',
        'cos_phase_difference', 'target_mask',
    )
    # One membership test per key, in the order the entries are emitted.
    wanted = {
        key: return_keys is None or key in return_keys for key in all_keys
    }

    # Which intermediate spectrograms are required by the requested keys.
    need_Y = (
        wanted['Y'] or wanted['Y_abs'] or wanted['num_frames']
        or wanted['cos_phase_difference']
    )
    need_S = (
        wanted['X_abs'] or wanted['cos_phase_difference']
        or wanted['target_mask']
    )
    need_s = wanted['s'] or need_S
    need_y = wanted['y'] or need_Y

    return_dict = dict()

    if wanted['example_id']:
        return_dict['example_id'] = inputs['example_id']

    if not (need_s or need_y):
        # Nothing audio-related was requested; do not touch the signals.
        return return_dict

    audio_data = inputs['audio_data']
    s = audio_data['speech_source'] if need_s else None
    y = audio_data['observation'] if need_y else None

    if wanted['s']:
        return_dict['s'] = np.ascontiguousarray(s, np.float32)
    if wanted['y']:
        return_dict['y'] = np.ascontiguousarray(y, np.float32)

    Y = None
    if need_Y:
        Y = stft(y, 512, 128)
        # Identity pattern: leaves the data unchanged but rejects inputs
        # that are not exactly two-dimensional.
        Y = einops.rearrange(Y, 't f -> t f')

    S = None
    if need_S:
        S = stft(s, 512, 128)
        # Move the source axis behind the frame axis.
        S = einops.rearrange(S, 'k t f -> t k f')
    X = S  # Same for MERL database

    if wanted['Y']:
        return_dict['Y'] = np.ascontiguousarray(Y, np.complex64)
    if wanted['X_abs']:
        return_dict['X_abs'] = np.ascontiguousarray(np.abs(X), np.float32)
    if wanted['Y_abs']:
        return_dict['Y_abs'] = np.ascontiguousarray(np.abs(Y), np.float32)
    if wanted['num_frames']:
        return_dict['num_frames'] = Y.shape[0]
    if wanted['cos_phase_difference']:
        # Broadcast the observation phase against every source.
        return_dict['cos_phase_difference'] = np.ascontiguousarray(
            np.cos(np.angle(Y[:, None, :]) - np.angle(X)), np.float32
        )
    if wanted['target_mask']:
        return_dict['target_mask'] = np.ascontiguousarray(
            ideal_binary_mask(S, source_axis=-2), np.float32
        )

    return return_dict
def test_compare_stft_to_numpy(self):
    """Check that the torch STFT matches the module-level ``stft`` reference.

    The reference output is split into its real and imaginary parts, which
    are concatenated along the last axis before being compared with the
    output of ``self.stft``.
    """
    stft_kwargs = dict(
        size=self.size,
        shift=self.shift,
        window_length=self.window_length,
        window=self.window,
        fading=self.fading,
    )
    reference = stft(self.time_signal, **stft_kwargs)

    # Lay out the reference as [real | imag] along the last axis.
    real_part = np.real(reference)
    imag_part = np.imag(reference)
    expected = np.concatenate([real_part, imag_part], axis=-1)

    actual = self.stft(self.torch_signal).numpy()
    tc.assert_almost_equal(actual, expected)
def test_restore_time_signal_from_numpy_stft_and_torch_istft(self):
    """Round-trip: module-level ``stft`` forward, ``self.stft.inverse`` back.

    The reconstructed signal is cropped along its last axis to the length
    of ``self.time_signal`` before the comparison.
    """
    stft_kwargs = dict(
        size=self.size,
        shift=self.shift,
        window_length=self.window_length,
        window=self.window,
        fading=self.fading,
    )
    spectrogram = stft(self.time_signal, **stft_kwargs)

    restored = self.stft.inverse(torch.from_numpy(spectrogram)).numpy()

    # Drop any samples beyond the original signal length.
    num_samples = self.time_signal.shape[-1]
    restored = restored[..., :num_samples]

    tc.assert_almost_equal(restored, self.time_signal)
def test_stft_frame_count(self):
    """Check the number of frames produced around a frame boundary.

    For signals of 1023, 1024 and 1025 samples the expected output shapes
    are ``(1, F)``, ``(1, F)``, ``(2, F)`` without fading and ``(7, F)``,
    ``(7, F)``, ``(8, F)`` with fading, where ``F`` is ``self.num_features``.

    The test has to toggle ``self.stft.fading``; the value found on entry
    is restored afterwards (also when an assertion fails) so that other
    tests using the same ``self.stft`` object are not affected.
    """
    transform = self.stft

    # (fading, ((signal length, expected number of frames), ...))
    expectations = (
        (False, ((1023, 1), (1024, 1), (1025, 2))),
        (True, ((1023, 7), (1024, 7), (1025, 8))),
    )

    fading_on_entry = transform.fading
    try:
        for fading, cases in expectations:
            transform.fading = fading
            for num_samples, num_frames in cases:
                X = transform(torch.rand(size=[num_samples]))
                tc.assert_equal(X.shape, (num_frames, self.num_features))
    finally:
        # Do not leak the modified setting into other tests.
        transform.fading = fading_on_entry
def test_stft_frame_count(self):
    """Check the number of frames produced around a frame boundary.

    For signals of 1019, 1020 and 1021 samples the expected output shapes
    are ``(50, F)``, ``(50, F)``, ``(51, F)`` without fading and
    ``(52, F)``, ``(52, F)``, ``(53, F)`` with fading, where ``F`` is
    ``self.fbins * 2``.

    The test has to toggle ``self.stft.fading``; the value found on entry
    is restored afterwards (also when an assertion fails) so that other
    tests using the same ``self.stft`` object are not affected.
    """
    transform = self.stft

    # (fading, ((signal length, expected number of frames), ...))
    expectations = (
        (False, ((1019, 50), (1020, 50), (1021, 51))),
        (True, ((1019, 52), (1020, 52), (1021, 53))),
    )

    fading_on_entry = transform.fading
    try:
        for fading, cases in expectations:
            transform.fading = fading
            for num_samples, num_frames in cases:
                X = transform(torch.rand(size=[num_samples]))
                tc.assert_equal(X.shape, (num_frames, self.fbins * 2))
    finally:
        # Do not leak the modified setting into other tests.
        transform.fading = fading_on_entry
def test_compare_stft_to_numpy(self):
    """Check that ``self.stft`` agrees with the module-level ``stft``.

    Both are applied to the same signal (``self.time_signal`` and its torch
    counterpart ``self.torch_signal``) and the outputs are compared
    directly.
    """
    stft_kwargs = dict(
        size=self.size,
        shift=self.shift,
        window_length=self.window_length,
        window=self.window,
        fading=self.fading,
    )
    expected = stft(self.time_signal, **stft_kwargs)

    actual = self.stft(self.torch_signal).numpy()

    tc.assert_almost_equal(actual, expected)