def test_dft_3d():
    """Compare a 3-D ``ng.dft`` over axes 0, 1, 2 with ``np.fft.fftn``.

    The DFT input/output use a trailing axis of size 2 holding (real, imag)
    pairs; the NumPy reference is converted to that same layout.
    """
    runtime = get_runtime()
    # presumably float32 of shape (2, 10, 10, 2), real/imag interleaved in the
    # last axis -- TODO confirm against build_fft_input_data.
    input_data = build_fft_input_data()
    input_tensor = ng.constant(input_data)
    input_axes = ng.constant(np.array([0, 1, 2], dtype=np.int64))

    dft_node = ng.dft(input_tensor, input_axes)
    computation = runtime.computation(dft_node)
    dft_results = computation()

    # Reinterpret the interleaved float pairs as complex64 values and drop the
    # now size-1 trailing axis so NumPy sees a plain complex signal.
    complex_input = np.squeeze(input_data.view(dtype=np.complex64), axis=-1)
    # np.fft returns complex128 on NumPy < 2.0; viewing that buffer as float32
    # would reinterpret the bits (and break the reshape), so cast explicitly
    # and split real/imag into the trailing axis instead of using .view().
    np_results = np.fft.fftn(complex_input, axes=[0, 1, 2]).astype(np.complex64)
    expected_results = np.stack((np_results.real, np_results.imag), axis=-1)
    assert np.allclose(dft_results, expected_results, atol=0.0002)
def test_dft_2d_signal_size_2():
    """Compare a 2-D ``ng.dft`` with explicit signal size [4, 5] to ``np.fft.fft2``.

    The signal size crops/pads axes 1 and 2 to 4 and 5 elements respectively,
    mirroring NumPy's ``s=`` argument; results use a trailing (real, imag) axis.
    """
    runtime = get_runtime()
    # presumably float32 with real/imag interleaved in the last axis
    # -- TODO confirm against build_fft_input_data.
    input_data = build_fft_input_data()
    input_tensor = ng.constant(input_data)
    input_axes = ng.constant(np.array([1, 2], dtype=np.int64))
    input_signal_size = ng.constant(np.array([4, 5], dtype=np.int64))

    dft_node = ng.dft(input_tensor, input_axes, input_signal_size)
    computation = runtime.computation(dft_node)
    dft_results = computation()

    # View interleaved float pairs as complex64 and drop the size-1 last axis.
    complex_input = np.squeeze(input_data.view(dtype=np.complex64), axis=-1)
    # np.fft returns complex128 on NumPy < 2.0; a float32 .view() of that buffer
    # would misread the bits, so cast and split real/imag explicitly.
    np_results = np.fft.fft2(complex_input, s=[4, 5], axes=[1, 2]).astype(np.complex64)
    expected_results = np.stack((np_results.real, np_results.imag), axis=-1)
    assert np.allclose(dft_results, expected_results, atol=0.000062)
def test_dft_2d():
    """Check a 2-D ``ng.dft`` over axes 1 and 2 against ``np.fft.fft2``.

    The reference spectrum is rebuilt with a trailing (real, imag) axis so it
    can be compared element-wise with the nGraph output.
    """
    rt = get_runtime()
    data = build_fft_input_data()
    data_const = ng.constant(data)
    axes_const = ng.constant(np.array([1, 2], dtype=np.int64))

    actual = rt.computation(ng.dft(data_const, axes_const))()

    # Interleaved float pairs -> complex64 signal without the trailing axis.
    signal = np.squeeze(data.view(dtype=np.complex64), axis=-1)
    spectrum = np.fft.fft2(signal, axes=[1, 2]).astype(np.complex64)
    reference = np.stack((spectrum.real, spectrum.imag), axis=-1)

    assert np.allclose(actual, reference, atol=0.000062)