def _rfft_transpose(t, fft_lengths):
  """Transpose rule for RFFT.

  The transpose of RFFT can't be expressed only in terms of IRFFT. Instead of
  manually building up larger twiddle matrices (which would increase the
  asymptotic complexity and is also rather complicated), we rely on JAX to
  transpose a naive RFFT implementation.

  Args:
    t: cotangent of the RFFT output. Every axis before the trailing
      ``len(fft_lengths)`` axes is treated as a batch axis, and its dtype is
      mapped to the matching real dtype via ``_real_dtype``.
    fft_lengths: lengths of the transformed trailing axes of the real primal
      input (any sequence; normalized to a tuple).

  Returns:
    The cotangent of the real RFFT input, with shape
    ``batch_shape + tuple(fft_lengths)`` and dtype ``_real_dtype(t.dtype)``.
  """
  fft_lengths = tuple(fft_lengths)
  # Slice by an explicit count rather than ``t.shape[:-len(fft_lengths)]``:
  # for empty ``fft_lengths`` the latter is ``t.shape[:-0] == ()``, which
  # would silently drop every batch axis.
  batch_shape = t.shape[:len(t.shape) - len(fft_lengths)]
  real_dtype = _real_dtype(t.dtype)
  # Only the shape/dtype of the primal matters for linear_transpose.
  dummy_primal = ShapeDtypeStruct(batch_shape + fft_lengths, real_dtype)
  transpose = linear_transpose(
      partial(_naive_rfft, fft_lengths=fft_lengths), dummy_primal)
  result, = transpose(t)
  # Internal invariant: transposing a real->complex map yields a real result.
  assert result.dtype == real_dtype, (result.dtype, t.dtype)
  return result
def test_linear_transpose_imag(self):
  """Transposing ``x.imag`` at a complex primal maps cotangent ``1.`` to ``-1j``."""
  def imag_part(z):
    return z.imag

  imag_transpose = api.linear_transpose(imag_part, 1.j)
  (cotangent_in,) = imag_transpose(1.)
  self.assertEqual(cotangent_in, -1.j)
def test_linear_transpose_real(self):
  """Transposing ``x.real`` at a complex primal maps cotangent ``1.`` to ``1.``."""
  def real_part(z):
    return z.real

  real_transpose = api.linear_transpose(real_part, 1.j)
  (cotangent_in,) = real_transpose(1.)
  self.assertEqual(cotangent_in, 1.)