示例#1
0
    def test_rank3_valid(self, dt):
        """Check 'valid'-mode correlation of the rank-3 fixture, both argument orders.

        The result is compared against the ``[1:2, 2:4, 3:5]`` sub-block of the
        reference array returned by ``_setup_rank3``, and the output dtype must
        equal ``dt``.
        """
        in1, in2, ref = self._setup_rank3(dt)
        expected = ref[1:2, 2:4, 3:5]

        forward = correlate(in1, in2, mode="valid")
        assert_array_almost_equal(forward, expected)
        assert_equal(forward.dtype, dt)

        # See gh-5897: with the inputs swapped the expected block is reversed
        # along every axis.
        swapped = correlate(in2, in1, mode="valid")
        assert_array_almost_equal(swapped, expected[::-1, ::-1, ::-1])
        assert_equal(swapped.dtype, dt)
示例#2
0
    def test_rank1_valid(self, dt):
        """Check 'valid'-mode correlation of the rank-1 fixture, both argument orders.

        The result is compared against elements ``[1:4]`` of the reference
        returned by ``_setup_rank1``, and the output dtype must equal ``dt``.
        """
        in1, in2, ref = self._setup_rank1(dt)
        expected = ref[1:4]

        forward = correlate(in1, in2, mode='valid')
        assert_array_almost_equal(forward, expected)
        assert_equal(forward.dtype, dt)

        # See gh-5897: with the inputs swapped the expected slice is reversed.
        swapped = correlate(in2, in1, mode='valid')
        assert_array_almost_equal(swapped, expected[::-1])
        assert_equal(swapped.dtype, dt)
示例#3
0
    def test_rank1_valid(self, dt):
        """Check 'valid'-mode correlation of the rank-1 fixture, both argument orders.

        The reference comes from ``_setup_rank1(dt, 'valid')`` and the
        comparison precision from ``self.decimal(dt)``. The output dtype must
        equal ``dt``.
        """
        in1, in2, expected = self._setup_rank1(dt, 'valid')

        forward = correlate(in1, in2, mode='valid')
        assert_array_almost_equal(forward, expected, decimal=self.decimal(dt))
        assert_equal(forward.dtype, dt)

        # See gh-5897: with the inputs swapped the expected result is the
        # reversed, conjugated reference.
        swapped = correlate(in2, in1, mode='valid')
        expected_swapped = expected[::-1].conj()
        assert_array_almost_equal(swapped, expected_swapped,
                                  decimal=self.decimal(dt))
        assert_equal(swapped.dtype, dt)
示例#4
0
    def _setup_rank1(self, dt, mode):
        np.random.seed(9)
        a = np.random.randn(10).astype(dt)
        a += 1j * np.random.randn(10).astype(dt)
        b = np.random.randn(8).astype(dt)
        b += 1j * np.random.randn(8).astype(dt)

        y_r = (correlate(a.real, b.real, mode=mode) +
               correlate(a.imag, b.imag, mode=mode)).astype(dt)
        y_r += 1j * (-correlate(a.real, b.imag, mode=mode) +
                     correlate(a.imag, b.real, mode=mode))
        return a, b, y_r
示例#5
0
    def test_rank3(self, dt):
        """Check 'full'-mode correlation of random complex rank-3 inputs.

        The reference is assembled from four correlations of the real and
        imaginary parts of the inputs. The comparison uses one decimal less
        than ``self.decimal(dt)``, and the output dtype must equal ``dt``.
        """
        # Seed the global RNG so the inputs are the same on every run and a
        # tolerance failure can be reproduced.
        np.random.seed(9)
        a = np.random.randn(10, 8, 6).astype(dt)
        a += 1j * np.random.randn(10, 8, 6).astype(dt)
        b = np.random.randn(8, 6, 4).astype(dt)
        b += 1j * np.random.randn(8, 6, 4).astype(dt)

        # Real part of the reference: re*re + im*im.
        y_r = (correlate(a.real, b.real) +
               correlate(a.imag, b.imag)).astype(dt)
        # Imaginary part of the reference: -re*im + im*re.
        y_r += 1j * (-correlate(a.real, b.imag) + correlate(a.imag, b.real))

        y = correlate(a, b, 'full')
        assert_array_almost_equal(y, y_r, decimal=self.decimal(dt) - 1)
        assert_equal(y.dtype, dt)
示例#6
0
    def test_method(self, dt):
        if dt == Decimal:
            method = choose_conv_method([Decimal(4)], [Decimal(3)])
            assert_equal(method, 'direct')
        else:
            a, b, y_r = self._setup_rank3(dt)
            y_fft = correlate(a, b, method='fft')
            y_direct = correlate(a, b, method='direct')

            assert_array_almost_equal(y_r, y_fft)
            assert_array_almost_equal(y_r, y_direct)
            assert_equal(y_fft.dtype, dt)
            assert_equal(y_direct.dtype, dt)
示例#7
0
 def test_swap_full(self, dt):
     d = np.array([0. + 0.j, 1. + 1.j, 2. + 2.j], dtype=dt)
     k = np.array([1. + 3.j, 2. + 4.j, 3. + 5.j, 4. + 6.j], dtype=dt)
     y = correlate(d, k)
     assert_equal(
         y,
         [0. + 0.j, 10. - 2.j, 28. - 6.j, 22. - 6.j, 16. - 6.j, 8. - 4.j])
示例#8
0
 def test_swap_same(self, dt):
     d = [0. + 0.j, 1. + 1.j, 2. + 2.j]
     k = [1. + 3.j, 2. + 4.j, 3. + 5.j, 4. + 6.j]
     y = correlate(d, k, mode="same")
     assert_equal(y, [10. - 2.j, 28. - 6.j, 22. - 6.j])
示例#9
0
 def test_rank1_full(self, dt):
     """Check 'full'-mode correlation of the rank-1 fixture.

     The reference comes from ``_setup_rank1(dt, 'full')`` and the comparison
     precision from ``self.decimal(dt)``. The output dtype must equal ``dt``.
     """
     in1, in2, expected = self._setup_rank1(dt, 'full')
     result = correlate(in1, in2, mode='full')
     assert_array_almost_equal(result, expected, decimal=self.decimal(dt))
     assert_equal(result.dtype, dt)
示例#10
0
 def test_rank3_all(self, dt):
     """Check default-mode correlation of the rank-3 fixture.

     The result is compared against the whole reference array returned by
     ``_setup_rank3``, and the output dtype must equal ``dt``.
     """
     in1, in2, expected = self._setup_rank3(dt)
     result = correlate(in1, in2)
     assert_array_almost_equal(result, expected)
     assert_equal(result.dtype, dt)
示例#11
0
 def test_rank3_same(self, dt):
     """Check 'same'-mode correlation of the rank-3 fixture.

     The result is compared against the ``[0:-1, 1:-1, 1:-2]`` sub-block of
     the reference returned by ``_setup_rank3``, and the output dtype must
     equal ``dt``.
     """
     in1, in2, ref = self._setup_rank3(dt)
     expected = ref[0:-1, 1:-1, 1:-2]
     result = correlate(in1, in2, mode="same")
     assert_array_almost_equal(result, expected)
     assert_equal(result.dtype, dt)
示例#12
0
 def test_rank1_same(self, dt):
     """Check 'same'-mode correlation of the rank-1 fixture.

     The result is compared against the reference returned by
     ``_setup_rank1`` with its last element dropped, and the output dtype
     must equal ``dt``.
     """
     in1, in2, ref = self._setup_rank1(dt)
     expected = ref[:-1]
     result = correlate(in1, in2, mode='same')
     assert_array_almost_equal(result, expected)
     assert_equal(result.dtype, dt)