def test_fft2_float(self):
    """Round-trip real float32 data through fft2/ifft2 in every split mode.

    The (real, imag) pair of the input is used for the split-input cases;
    because the input is real, its imaginary half is all zeros. Each
    subtest divides the round-trip result by 8*4 (the extents of the last
    two axes) before comparing, i.e. it expects ``ifft2`` to be
    unnormalized.
    """
    data = np.array(np.random.randn(2, 4, 8), "float32")
    re_part, im_part = data.real.copy(), data.imag.copy()
    pair = re_part, im_part

    def check_pair(result):
        # Compare each half of a split (real, imag) result with the input halves.
        self.assert_equal(re_part, result[0] / 8 / 4)
        self.assert_equal(im_part, result[1] / 8 / 4)

    # complex input -> split spectrum -> complex output
    with self.subTest(i=0):
        spectrum = fft.fft2(data, split_out=True)
        self.assert_equal(fft.ifft2(spectrum, split_in=True) / 8 / 4, data)
    # complex input -> split spectrum -> split output
    with self.subTest(i=1):
        spectrum = fft.fft2(data, split_out=True)
        check_pair(fft.ifft2(spectrum, split_in=True, split_out=True))
    # complex input -> complex spectrum -> split output
    with self.subTest(i=2):
        spectrum = fft.fft2(data)
        check_pair(fft.ifft2(spectrum, split_out=True))
    # split input -> complex spectrum -> complex output
    with self.subTest(i=3):
        spectrum = fft.fft2(pair, split_in=True)
        self.assert_equal(fft.ifft2(spectrum) / 8 / 4, data)
    # split input -> complex spectrum -> split output
    with self.subTest(i=4):
        spectrum = fft.fft2(pair, split_in=True)
        check_pair(fft.ifft2(spectrum, split_out=True))
    # split input -> split spectrum -> split output
    with self.subTest(i=5):
        spectrum = fft.fft2(pair, split_in=True, split_out=True)
        check_pair(fft.ifft2(spectrum, split_out=True, split_in=True))
    def test_fft2(self):
        """Check ``overwrite_x`` together with ``split_in``/``split_out`` for complex fft2.

        For each complex dtype, a random array with non-zero real AND
        imaginary parts is transformed out of place once (``f``), and each
        subtest checks an overwrite variant against it or against the input.
        """
        shape = (2, 4, 8)
        for dtype in ("complex64", "complex128"):
            # randn alone produces purely real data, which left the imaginary
            # half of every split (real, imag) input identically zero; add a
            # random imaginary part so that code path is really exercised.
            a0 = (np.random.randn(*shape) + 1j * np.random.randn(*shape)).astype(dtype)

            f = fft.fft2(a0)
            # dtype is part of the subTest params because the i values repeat per dtype.
            with self.subTest(i = 0, dtype = dtype):
                # The in-place transform must match the out-of-place result.
                a = a0.copy()
                fft.fft2(a, overwrite_x = True)
                self.assert_equal(a, f)

            with self.subTest(i = 1, dtype = dtype):
                # Split in/out with overwrite: the test expects the real/imag
                # buffers to hold the split transform afterwards.
                real, imag = a0.real.copy(), a0.imag.copy()
                fr, fi = fft.fft2((real, imag), split_in = True, split_out = True)
                fft.fft2((real, imag), overwrite_x = True, split_in = True, split_out = True)
                self.assert_equal(fr, real)
                self.assert_equal(fi, imag)

            with self.subTest(i = 2, dtype = dtype):
                # With split_out=True the test expects the complex input to be
                # left untouched even though overwrite_x=True.
                a = a0.copy()
                fft.fft2(a, split_out = True, overwrite_x = True)
                self.assert_equal(a, a0)

            with self.subTest(i = 3, dtype = dtype):
                # Split input, complex output, with overwrite: the test expects
                # the real/imag buffers to receive the real/imag parts of the result.
                real, imag = a0.real.copy(), a0.imag.copy()
                fc = fft.fft2((real, imag), split_in = True)
                fft.fft2((real, imag), overwrite_x = True, split_in = True)
                self.assert_equal(real, fc.real)
                self.assert_equal(imag, fc.imag)
 def test_fft2_double_axes(self):
     """Compare complex128 fft2 over axes (0, 1) against ``numpy.fft.fft2``.

     ``b`` holds the same data reshaped from (2, 4, 8) to (2, 8, 4), so the
     transformed axes have extents (2, 4) for ``a`` and (2, 8) for ``b``.
     """
     shape = (2, 4, 8)
     # randn alone gives purely real data; add an imaginary part so the
     # complex input path is actually compared against numpy.
     a = (np.random.randn(*shape) + 1j * np.random.randn(*shape)).astype("complex128")
     b = a.reshape(2, 8, 4)
     self.assert_equal(fft.fft2(a, axes = (0, 1)), np.fft.fft2(a, axes = (0, 1)))
     self.assert_equal(fft.fft2(b, axes = (0, 1)), np.fft.fft2(b, axes = (0, 1)))
 def test_wrong_fft2_shape(self):
     """fft2 must reject a (2, 3) nested-list input with ValueError.

     NOTE(review): presumably the length-3 axis is an unsupported transform
     size for this fft module (other tests only use 2, 4 and 8) — confirm.
     """
     bad_input = [[1, 2, 3], [1, 2, 3]]
     self.assertRaises(ValueError, fft.fft2, bad_input)
 def test_fft2_double(self):
     """Compare complex128 fft2 over the default axes against ``numpy.fft.fft2``.

     ``b`` holds the same data reshaped from (2, 4, 8) to (2, 8, 4), so both
     orderings of the last two extents (4x8 and 8x4) are checked.
     """
     shape = (2, 4, 8)
     # randn alone gives purely real data; add an imaginary part so the
     # complex input path is actually compared against numpy.
     a = (np.random.randn(*shape) + 1j * np.random.randn(*shape)).astype("complex128")
     b = a.reshape(2, 8, 4)
     self.assert_equal(fft.fft2(a), np.fft.fft2(a))
     self.assert_equal(fft.fft2(b), np.fft.fft2(b))