示例#1
0
def _rfft_transpose(t, fft_lengths):
  """Apply the transpose (adjoint) of RFFT to ``t``.

  RFFT's transpose cannot be expressed purely in terms of IRFFT. Rather than
  hand-building larger twiddle matrices (which would raise the asymptotic
  complexity and is fiddly to get right), we rely on JAX to transpose a naive
  RFFT implementation.

  Args:
    t: array to push back through RFFT. Its dtype is mapped through
      ``_real_dtype``, so it is presumably complex-valued (confirm against
      callers). Its leading ``t.ndim - len(fft_lengths)`` axes are treated as
      batch axes.
    fft_lengths: tuple of lengths of the transformed trailing axes of the
      (real) primal input.

  Returns:
    The transposed value, whose shape is the batch shape of ``t`` followed by
    ``fft_lengths``. Its dtype is checked to be ``_real_dtype(t.dtype)``.
  """
  real_dtype = _real_dtype(t.dtype)
  batch_shape = t.shape[:-len(fft_lengths)]
  # linear_transpose only uses the primal's shape/dtype, so an abstract
  # ShapeDtypeStruct stands in for a concrete real-valued input.
  primal_spec = ShapeDtypeStruct(batch_shape + fft_lengths, real_dtype)
  naive_rfft = partial(_naive_rfft, fft_lengths=fft_lengths)
  transposed_fn = linear_transpose(naive_rfft, primal_spec)
  # One primal argument in, so exactly one cotangent comes back.
  (out,) = transposed_fn(t)
  assert out.dtype == real_dtype, (out.dtype, t.dtype)
  return out
示例#2
0
 def test_linear_transpose_imag(self):
   f = lambda x: x.imag
   transpose = api.linear_transpose(f, 1.j)
   actual, = transpose(1.)
   expected = -1.j
   self.assertEqual(actual, expected)
示例#3
0
 def test_linear_transpose_real(self):
   f = lambda x: x.real
   transpose = api.linear_transpose(f, 1.j)
   actual, = transpose(1.)
   expected = 1.
   self.assertEqual(actual, expected)