예제 #1
0
    def test_compressed(self):
        """Read ``various_compressed.sav`` with readsav and check its variables."""
        # Ignore the "warning: empty strings" message readsav may emit for
        # this file; the filter state is restored when the block exits.
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', message="warning: empty strings")
            s = readsav(path.join(DATA_PATH, 'various_compressed.sav'),
                        verbose=False)

        assert_identical(s.i8u, np.uint8(234))
        assert_identical(s.f32, np.float32(-3.1234567e+37))
        assert_identical(
            s.c64,
            np.complex128(1.1987253647623157e+112 - 5.1987258887729157e+307j))
        assert_equal(s.array5d.shape, (4, 3, 4, 6, 5))
        assert_identical(s.arrays.a[0], np.array([1, 2, 3], dtype=np.int16))
        assert_identical(s.arrays.b[0],
                         np.array([4., 5., 6., 7.], dtype=np.float32))
        assert_identical(
            s.arrays.c[0],
            np.array([np.complex64(1 + 2j),
                      np.complex64(7 + 8j)]))
        # ``np.object`` was only an alias of the builtin ``object`` and was
        # removed in NumPy 1.24; the builtin gives the same dtype.
        assert_identical(
            s.arrays.d[0],
            np.array([b"cheese", b"bacon", b"spam"], dtype=object))
예제 #2
0
    def test_arrays_replicated_3d(self):
        """Check the replicated 3-D pointer-array columns read by readsav."""
        with warnings.catch_warnings():
            # Ignore the "multi-dimensional structures" message readsav may
            # emit for this file.
            warnings.filterwarnings('ignore', message="warning: multi-dimensional structures")
            s = readsav(path.join(DATA_PATH, 'struct_pointer_arrays_replicated_3d.sav'), verbose=False)

        g = s.arrays_rep.g
        h = s.arrays_rep.h

        # Check column types
        assert_true(g.dtype.type is np.object_)
        assert_true(h.dtype.type is np.object_)

        # Check column shapes
        assert_equal(g.shape, (4, 3, 2))
        assert_equal(h.shape, (4, 3, 2))

        # Loop invariants: the expected cell contents, and the id of the
        # element every cell's items must be (the very same object).
        expected_g = np.repeat(np.float32(4.), 2).astype(np.object_)
        expected_h = np.repeat(np.float32(4.), 3).astype(np.object_)
        first_g = id(g[0, 0, 0][0])
        first_h = id(h[0, 0, 0][0])

        # Check values
        for i in range(4):
            for j in range(3):
                for k in range(2):
                    assert_array_identical(g[i, j, k], expected_g)
                    assert_array_identical(h[i, j, k], expected_h)
                    assert_true(np.all(vect_id(g[i, j, k]) == first_g))
                    assert_true(np.all(vect_id(h[i, j, k]) == first_h))
예제 #3
0
    def test_integral(self):
        x = [1,1,1,2,2,2,4,4,4]
        y = [1,2,3,1,2,3,1,2,3]
        z = array([0,7,8,3,4,7,1,3,4])

        warn_ctx = WarningManager()
        warn_ctx.__enter__()
        try:
            # This seems to fail (ier=1, see ticket 1642).
            warnings.simplefilter('ignore', UserWarning)
            lut = SmoothBivariateSpline(x, y, z, kx=1, ky=1, s=0)
        finally:
            warn_ctx.__exit__()

        tx = [1,2,4]
        ty = [1,2,3]

        tz = lut(tx, ty)
        trpz = .25*(diff(tx)[:,None]*diff(ty)[None,:]
                    * (tz[:-1,:-1]+tz[1:,:-1]+tz[:-1,1:]+tz[1:,1:])).sum()
        assert_almost_equal(lut.integral(tx[0], tx[-1], ty[0], ty[-1]), trpz)

        lut2 = SmoothBivariateSpline(x, y, z, kx=2, ky=2, s=0)
        assert_almost_equal(lut2.integral(tx[0], tx[-1], ty[0], ty[-1]), trpz,
                            decimal=0)  # the quadratures give 23.75 and 23.85

        tz = lut(tx[:-1], ty[:-1])
        trpz = .25*(diff(tx[:-1])[:,None]*diff(ty[:-1])[None,:]
                    * (tz[:-1,:-1]+tz[1:,:-1]+tz[:-1,1:]+tz[1:,1:])).sum()
        assert_almost_equal(lut.integral(tx[0], tx[-2], ty[0], ty[-2]), trpz)
예제 #4
0
def test_find():
    """Check the key lookups performed by ``find`` with ``disp=False``."""
    with warnings.catch_warnings():
        # DeprecationWarnings raised during the lookups are ignored; the
        # filter state is restored when the block exits.
        warnings.simplefilter('ignore', DeprecationWarning)

        keys = find('weak mixing', disp=False)
        assert_equal(keys, ['weak mixing angle'])

        # A fragment that matches nothing yields an empty list.
        keys = find('qwertyuiop', disp=False)
        assert_equal(keys, [])

        # Every match is expected, in sorted order.
        keys = find('natural unit', disp=False)
        assert_equal(keys, sorted(['natural unit of velocity',
                                    'natural unit of action',
                                    'natural unit of action in eV s',
                                    'natural unit of mass',
                                    'natural unit of energy',
                                    'natural unit of energy in MeV',
                                    'natural unit of mom.um',
                                    'natural unit of mom.um in MeV/c',
                                    'natural unit of length',
                                    'natural unit of time']))
예제 #5
0
def test_1d_shape():
    """Check the shape a 1-D array gets after a savemat/loadmat round trip."""
    arr = np.arange(5)

    def roundtrip_shape(**savemat_kwargs):
        # Save ``arr`` into a fresh in-memory stream and return the shape
        # of the 'oned' variable loaded back from it.
        stream = BytesIO()
        savemat(stream, {'oned': arr}, **savemat_kwargs)
        return loadmat(stream)['oned'].shape

    with warnings.catch_warnings():
        # silence warnings for tests
        warnings.simplefilter('ignore')
        # NOTE(review): the two default-behaviour checks below encode the
        # savemat ``oned_as`` default of the SciPy version this targeted;
        # newer SciPy releases may differ — confirm.
        # Current 5 behavior is 1D -> column vector
        assert_equal(roundtrip_shape(format='5'), (5, 1))
        # Current 4 behavior is 1D -> row vector
        assert_equal(roundtrip_shape(format='4'), (1, 5))
        for fmt in ('4', '5'):
            # can be explicitly 'column' for oned_as
            assert_equal(roundtrip_shape(format=fmt, oned_as='column'), (5, 1))
            # but different from 'row'
            assert_equal(roundtrip_shape(format=fmt, oned_as='row'), (1, 5))
예제 #6
0
def test_cs_graph_components():
    """Check component labels from ``csgraph.cs_graph_components``."""
    # ``np.bool`` was only an alias of the builtin ``bool`` and was removed
    # in NumPy 1.24; the builtin gives the same dtype.
    D = np.eye(4, dtype=bool)

    with warnings.catch_warnings():
        # cs_graph_components is deprecated; only that warning is ignored.
        warnings.filterwarnings("ignore",
                                message="`cs_graph_components` is deprecated")

        # Identity matrix: four separate components.
        n_comp, flag = csgraph.cs_graph_components(csr_matrix(D))
        assert_(n_comp == 4)
        assert_equal(flag, [0, 1, 2, 3])

        # Connect nodes 0 and 1.
        D[0, 1] = D[1, 0] = 1

        n_comp, flag = csgraph.cs_graph_components(csr_matrix(D))
        assert_(n_comp == 3)
        assert_equal(flag, [0, 0, 1, 2])

        # A pathological case: clearing D[2, 2] leaves row 2 with no entries
        # at all; node 2 is then expected to be labelled -2.
        D[2, 2] = 0
        n_comp, flag = csgraph.cs_graph_components(csr_matrix(D))
        assert_(n_comp == 2)
        assert_equal(flag, [0, 0, -2, 1])
예제 #7
0
    def test_arrays_replicated_3d(self):
        """Check the replicated 3-D pointer-array columns read by readsav."""
        with warnings.catch_warnings():
            # Ignore the "multi-dimensional structures" message readsav may
            # emit for this file.
            warnings.filterwarnings('ignore', message="warning: multi-dimensional structures")
            s = readsav(path.join(DATA_PATH, 'struct_pointer_arrays_replicated_3d.sav'), verbose=False)

        g = s.arrays_rep.g
        h = s.arrays_rep.h

        # Check column types
        assert_true(g.dtype.type is np.object_)
        assert_true(h.dtype.type is np.object_)

        # Check column shapes
        assert_equal(g.shape, (4, 3, 2))
        assert_equal(h.shape, (4, 3, 2))

        # Loop invariants: the expected cell contents, and the id of the
        # element every cell's items must be (the very same object).
        expected_g = np.repeat(np.float32(4.), 2).astype(np.object_)
        expected_h = np.repeat(np.float32(4.), 3).astype(np.object_)
        first_g = id(g[0, 0, 0][0])
        first_h = id(h[0, 0, 0][0])

        # Check values
        for i in range(4):
            for j in range(3):
                for k in range(2):
                    assert_array_identical(g[i, j, k], expected_g)
                    assert_array_identical(h[i, j, k], expected_h)
                    assert_true(np.all(vect_id(g[i, j, k]) == first_g))
                    assert_true(np.all(vect_id(h[i, j, k]) == first_h))
예제 #8
0
파일: test_codata.py 프로젝트: 87/scipy
def test_find():
    """Check the key lookups performed by ``find`` with ``disp=False``."""
    with warnings.catch_warnings():
        # DeprecationWarnings raised during the lookups are ignored; the
        # filter state is restored when the block exits.
        warnings.simplefilter('ignore', DeprecationWarning)

        keys = find('weak mixing', disp=False)
        assert_equal(keys, ['weak mixing angle'])

        # A fragment that matches nothing yields an empty list.
        keys = find('qwertyuiop', disp=False)
        assert_equal(keys, [])

        # Every match is expected, in sorted order.
        keys = find('natural unit', disp=False)
        assert_equal(keys, sorted(['natural unit of velocity',
                                    'natural unit of action',
                                    'natural unit of action in eV s',
                                    'natural unit of mass',
                                    'natural unit of energy',
                                    'natural unit of energy in MeV',
                                    'natural unit of mom.um',
                                    'natural unit of mom.um in MeV/c',
                                    'natural unit of length',
                                    'natural unit of time']))
예제 #9
0
 def test_skip_footer_with_invalid(self):
     """Check ``skip_footer`` combined with rows that lack a second value."""
     with warnings.catch_warnings():
         # Ignore every warning for this test (invalid rows are presumably
         # reported as warnings when invalid_raise=False).
         warnings.filterwarnings("ignore")
         basestr = '1 1\n2 2\n3 3\n4 4\n5  \n6  \n7  \n'
         # Footer too small to get rid of all invalid values
         assert_raises(ValueError,
                       textadapter.genfromtxt,
                       StringIO(basestr),
                       skip_footer=1)
         # ...unless invalid rows are skipped instead of raising.
         a = textadapter.genfromtxt(StringIO(basestr),
                                    skip_footer=1,
                                    invalid_raise=False)
         assert_equal(a, np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]]))
         # A footer covering all three invalid rows leaves only valid rows.
         a = textadapter.genfromtxt(StringIO(basestr), skip_footer=3)
         assert_equal(a, np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]]))
         # Invalid rows interleaved with valid ones.
         basestr = '1 1\n2  \n3 3\n4 4\n5  \n6 6\n7 7\n'
         a = textadapter.genfromtxt(StringIO(basestr),
                                    skip_footer=1,
                                    invalid_raise=False)
         assert_equal(a, np.array([[1., 1.], [3., 3.], [4., 4.], [6., 6.]]))
         a = textadapter.genfromtxt(StringIO(basestr),
                                    skip_footer=3,
                                    invalid_raise=False)
         assert_equal(a, np.array([[1., 1.], [3., 3.], [4., 4.]]))
예제 #10
0
def test_array_maskna_astype():
    """Check ``astype`` keeps an owned NA mask and the NA for every dtype pair."""
    # Source dtypes: each listed type code plus a datetime64 and a
    # timedelta64 dtype. The destination list is identical.
    # NOTE: a structured dtype ([('b', int, (1,))]) used to be listed here
    # as well but was disabled.
    dtsrc = [np.dtype(d) for d in '?bhilqpBHILQPefdgFDGSUO']
    dtsrc.append(np.dtype('datetime64[D]'))
    dtsrc.append(np.dtype('timedelta64[s]'))
    dtdst = list(dtsrc)

    with warnings.catch_warnings():
        # Complex -> real casts among these pairs may emit ComplexWarning.
        warnings.simplefilter("ignore", np.ComplexWarning)
        for dt1 in dtsrc:
            a = np.ones(2, dt1, maskna=1)
            a[1] = np.NA
            for dt2 in dtdst:
                msg = 'type %s to %s conversion' % (dt1, dt2)
                b = a.astype(dt2)
                assert_(b.flags.maskna, msg)
                assert_(b.flags.ownmaskna, msg)
                assert_(np.isna(b[1]), msg)
예제 #11
0
def test_ksone_fit_freeze():
    """Regression test for ticket #1638.

    ``stats.ksone.fit`` must return (without raising) on this sample.
    """
    d = np.array(
        [-0.18879233,  0.15734249,  0.18695107,  0.27908787, -0.248649,
         -0.2171497 ,  0.12233512,  0.15126419,  0.03119282,  0.4365294 ,
          0.08930393, -0.23509903,  0.28231224, -0.09974875, -0.25196048,
          0.11102028,  0.1427649 ,  0.10176452,  0.18754054,  0.25826724,
          0.05988819,  0.0531668 ,  0.21906056,  0.32106729,  0.2117662 ,
          0.10886442,  0.09375789,  0.24583286, -0.22968366, -0.07842391,
         -0.31195432, -0.21271196,  0.1114243 , -0.13293002,  0.01331725,
         -0.04330977, -0.09485776, -0.28434547,  0.22245721, -0.18518199,
         -0.10943985, -0.35243174,  0.06897665, -0.03553363, -0.0701746 ,
         -0.06037974,  0.37670779, -0.21684405])

    # Both context managers restore their state on exit (warning filters
    # first, then the floating-point error settings), even if fit raises.
    with np.errstate(invalid='ignore'):
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', UserWarning)
            stats.ksone.fit(d)
예제 #12
0
    def test_bilinearity(self):
        x = [1,1,1,2,2,2,3,3,3]
        y = [1,2,3,1,2,3,1,2,3]
        z = [0,7,8,3,4,7,1,3,4]
        s = 0.1
        tx = [1+s,3-s]
        ty = [1+s,3-s]
        warn_ctx = WarningManager()
        warn_ctx.__enter__()
        try:
            # This seems to fail (ier=1, see ticket 1642).
            warnings.simplefilter('ignore', UserWarning)
            lut = LSQBivariateSpline(x,y,z,tx,ty,kx=1,ky=1)
        finally:
            warn_ctx.__exit__()

        tx, ty = lut.get_knots()

        for xa, xb in zip(tx[:-1], tx[1:]):
            for ya, yb in zip(ty[:-1], ty[1:]):
                for t in [0.1, 0.5, 0.9]:
                    for s in [0.3, 0.4, 0.7]:
                        xp = xa*(1-t) + xb*t
                        yp = ya*(1-s) + yb*s
                        zp = (+ lut(xa, ya)*(1-t)*(1-s)
                              + lut(xb, ya)*t*(1-s)
                              + lut(xa, yb)*(1-t)*s
                              + lut(xb, yb)*t*s)
                        assert_almost_equal(lut(xp,yp), zp)
def test_cs_graph_components():
    """Check component labels from ``csgraph.cs_graph_components``."""
    # ``np.bool`` was only an alias of the builtin ``bool`` and was removed
    # in NumPy 1.24; the builtin gives the same dtype.
    D = np.eye(4, dtype=bool)

    with warnings.catch_warnings():
        # cs_graph_components is deprecated; only that warning is ignored.
        warnings.filterwarnings("ignore",
                                message="`cs_graph_components` is deprecated")

        # Identity matrix: four separate components.
        n_comp, flag = csgraph.cs_graph_components(csr_matrix(D))
        assert_(n_comp == 4)
        assert_equal(flag, [0, 1, 2, 3])

        # Connect nodes 0 and 1.
        D[0, 1] = D[1, 0] = 1

        n_comp, flag = csgraph.cs_graph_components(csr_matrix(D))
        assert_(n_comp == 3)
        assert_equal(flag, [0, 0, 1, 2])

        # A pathological case: clearing D[2, 2] leaves row 2 with no entries
        # at all; node 2 is then expected to be labelled -2.
        D[2, 2] = 0
        n_comp, flag = csgraph.cs_graph_components(csr_matrix(D))
        assert_(n_comp == 2)
        assert_equal(flag, [0, 0, -2, 1])
예제 #14
0
    def test_integral(self):
        x = [1,1,1,2,2,2,4,4,4]
        y = [1,2,3,1,2,3,1,2,3]
        z = array([0,7,8,3,4,7,1,3,4])

        warn_ctx = WarningManager()
        warn_ctx.__enter__()
        try:
            # This seems to fail (ier=1, see ticket 1642).
            warnings.simplefilter('ignore', UserWarning)
            lut = SmoothBivariateSpline(x, y, z, kx=1, ky=1, s=0)
        finally:
            warn_ctx.__exit__()

        tx = [1,2,4]
        ty = [1,2,3]

        tz = lut(tx, ty)
        trpz = .25*(diff(tx)[:,None]*diff(ty)[None,:]
                    * (tz[:-1,:-1]+tz[1:,:-1]+tz[:-1,1:]+tz[1:,1:])).sum()
        assert_almost_equal(lut.integral(tx[0], tx[-1], ty[0], ty[-1]), trpz)

        lut2 = SmoothBivariateSpline(x, y, z, kx=2, ky=2, s=0)
        assert_almost_equal(lut2.integral(tx[0], tx[-1], ty[0], ty[-1]), trpz,
                            decimal=0)  # the quadratures give 23.75 and 23.85

        tz = lut(tx[:-1], ty[:-1])
        trpz = .25*(diff(tx[:-1])[:,None]*diff(ty[:-1])[None,:]
                    * (tz[:-1,:-1]+tz[1:,:-1]+tz[:-1,1:]+tz[1:,1:])).sum()
        assert_almost_equal(lut.integral(tx[0], tx[-2], ty[0], ty[-2]), trpz)
예제 #15
0
    def test_bilinearity(self):
        x = [1,1,1,2,2,2,3,3,3]
        y = [1,2,3,1,2,3,1,2,3]
        z = [0,7,8,3,4,7,1,3,4]
        s = 0.1
        tx = [1+s,3-s]
        ty = [1+s,3-s]
        warn_ctx = WarningManager()
        warn_ctx.__enter__()
        try:
            # This seems to fail (ier=1, see ticket 1642).
            warnings.simplefilter('ignore', UserWarning)
            lut = LSQBivariateSpline(x,y,z,tx,ty,kx=1,ky=1)
        finally:
            warn_ctx.__exit__()

        tx, ty = lut.get_knots()

        for xa, xb in zip(tx[:-1], tx[1:]):
            for ya, yb in zip(ty[:-1], ty[1:]):
                for t in [0.1, 0.5, 0.9]:
                    for s in [0.3, 0.4, 0.7]:
                        xp = xa*(1-t) + xb*t
                        yp = ya*(1-s) + yb*s
                        zp = (+ lut(xa, ya)*(1-t)*(1-s)
                              + lut(xb, ya)*t*(1-s)
                              + lut(xa, yb)*(1-t)*s
                              + lut(xb, yb)*t*s)
                        assert_almost_equal(lut(xp,yp), zp)
예제 #16
0
파일: test_maskna.py 프로젝트: ejmvar/numpy
def test_array_maskna_astype():
    """Check ``astype`` keeps an owned NA mask and the NA for every dtype pair."""
    # Source dtypes: each listed type code plus a datetime64 and a
    # timedelta64 dtype. The destination list is identical.
    # NOTE: a structured dtype ([('b', int, (1,))]) used to be listed here
    # as well but was disabled.
    dtsrc = [np.dtype(d) for d in '?bhilqpBHILQPefdgFDGSUO']
    dtsrc.append(np.dtype('datetime64[D]'))
    dtsrc.append(np.dtype('timedelta64[s]'))
    dtdst = list(dtsrc)

    with warnings.catch_warnings():
        # Complex -> real casts among these pairs may emit ComplexWarning.
        warnings.simplefilter("ignore", np.ComplexWarning)
        for dt1 in dtsrc:
            a = np.ones(2, dt1, maskna=1)
            a[1] = np.NA
            for dt2 in dtdst:
                msg = 'type %s to %s conversion' % (dt1, dt2)
                b = a.astype(dt2)
                assert_(b.flags.maskna, msg)
                assert_(b.flags.ownmaskna, msg)
                assert_(np.isna(b[1]), msg)
예제 #17
0
class _DeprecationAccept:
    """Test helper that ignores DeprecationWarning for the duration of a test.

    ``setUp`` saves the current warnings filter state and adds an "ignore"
    filter for DeprecationWarning; ``tearDown`` restores the saved state.
    """

    def setUp(self):
        # The stdlib catch_warnings saves/restores the filter list exactly
        # as the old WarningManager shim did.
        self.mgr = warnings.catch_warnings()
        self.mgr.__enter__()
        warnings.simplefilter("ignore", DeprecationWarning)

    def tearDown(self):
        self.mgr.__exit__()
예제 #18
0
class _DeprecationAccept:
    """Test helper that ignores DeprecationWarning for the duration of a test.

    ``setUp`` saves the current warnings filter state and adds an "ignore"
    filter for DeprecationWarning; ``tearDown`` restores the saved state.
    """

    def setUp(self):
        # The stdlib catch_warnings saves/restores the filter list exactly
        # as the old WarningManager shim did.
        self.mgr = warnings.catch_warnings()
        self.mgr.__enter__()
        warnings.simplefilter("ignore", DeprecationWarning)

    def tearDown(self):
        self.mgr.__exit__()
예제 #19
0
 def test_complex_scalar_warning(self):
     for tp in [np.csingle, np.cdouble, np.clongdouble]:
         x = tp(1+2j)
         assert_warns(np.ComplexWarning, float, x)
         ctx = WarningManager()
         ctx.__enter__()
         warnings.simplefilter('ignore')
         assert_equal(float(x), float(x.real))
         ctx.__exit__()
예제 #20
0
 def test_complex_scalar_warning(self):
     for tp in [np.csingle, np.cdouble, np.clongdouble]:
         x = tp(1 + 2j)
         assert_warns(np.ComplexWarning, float, x)
         ctx = WarningManager()
         ctx.__enter__()
         warnings.simplefilter('ignore')
         assert_equal(float(x), float(x.real))
         ctx.__exit__()
예제 #21
0
 def test_summary(self):
     """Smoke test: building the summary of ``self.model.fit()`` must not raise."""
     with warnings.catch_warnings():
         # The "kurtosistest only valid for n>=20" warning is expected for
         # this model's data and is ignored.
         warnings.filterwarnings("ignore",
                                 "kurtosistest only valid for n>=20")
         self.model.fit().summary()
예제 #22
0
 def test_summary(self):
     """Smoke test: building the summary of ``self.model.fit()`` must not raise."""
     with warnings.catch_warnings():
         # The "kurtosistest only valid for n>=20" warning is expected for
         # this model's data and is ignored.
         warnings.filterwarnings("ignore",
                                 "kurtosistest only valid for n>=20")
         self.model.fit().summary()
예제 #23
0
 def test_empty_file(self):
     "Test that an empty file yields an empty array (its warning is ignored)."
     with warnings.catch_warnings():
         # The "Empty input file" warning is expected here and is silenced;
         # only the returned value is checked.
         warnings.filterwarnings("ignore", message="genfromtxt: Empty input file:")
         data = StringIO()
         test = iopro.genfromtxt(data)
         assert_equal(test, np.array([]))
예제 #24
0
    def test_ndmin_keyword(self):
        """Check the ``ndmin`` argument of ``textadapter.loadtxt``."""
        c = StringIO()
        c.write('1,2,3\n4,5,6')
        # ndmin values other than 0, 1 or 2 (here 3 and 1.5) are rejected.
        c.seek(0)
        assert_raises(textadapter.DataTypeError,
                      textadapter.loadtxt,
                      c,
                      ndmin=3)
        c.seek(0)
        assert_raises(textadapter.DataTypeError,
                      textadapter.loadtxt,
                      c,
                      ndmin=1.5)
        c.seek(0)
        x = textadapter.loadtxt(c, dtype=int, delimiter=',', ndmin=1)
        a = np.array([[1, 2, 3], [4, 5, 6]])
        assert_array_equal(x, a)

        def loaded_shape(buf, ndmin):
            # Rewind ``buf`` and return the shape loadtxt produces for it.
            buf.seek(0)
            return textadapter.loadtxt(buf, dtype=int, delimiter=',',
                                       ndmin=ndmin).shape

        # A single row: ndmin=2 keeps it 2-D, ndmin<=1 gives a 1-D array.
        d = StringIO()
        d.write('0,1,2')
        assert_(loaded_shape(d, 2) == (1, 3))
        assert_(loaded_shape(d, 1) == (3, ))
        assert_(loaded_shape(d, 0) == (3, ))
        # A single column: ndmin=2 gives a column vector.
        e = StringIO()
        e.write('0\n1\n2')
        assert_(loaded_shape(e, 2) == (3, 1))
        assert_(loaded_shape(e, 1) == (3, ))
        assert_(loaded_shape(e, 0) == (3, ))

        # Test ndmin kw with empty file.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore",
                                    message="loadtxt: Empty input file:")
            f = StringIO()
            assert_(textadapter.loadtxt(f, ndmin=2).shape == (0, 1))
            assert_(textadapter.loadtxt(f, ndmin=1).shape == (0, ))
예제 #25
0
 def test_set_fields(self):
     "Tests setting fields."
     base = self.base.copy()
     mbase = base.view(mrecarray)
     mbase = mbase.copy()
     mbase.fill_value = (999999, 1e20, 'N/A')
     # Change the data, the mask should be conserved
     mbase.a._data[:] = 5
     assert_equal(mbase['a']._data, [5, 5, 5, 5, 5])
     assert_equal(mbase['a']._mask, [0, 1, 0, 0, 1])
     # Change the elements, and the mask will follow
     mbase.a = 1
     assert_equal(mbase['a']._data, [1] * 5)
     assert_equal(ma.getmaskarray(mbase['a']), [0] * 5)
     # Used to be _mask, now it's recordmask
     assert_equal(mbase.recordmask, [False] * 5)
     assert_equal(
         mbase._mask.tolist(),
         np.array([(0, 0, 0), (0, 1, 1), (0, 0, 0), (0, 0, 0), (0, 1, 1)],
                  dtype=bool))
     # Set a field to mask ........................
     mbase.c = masked
     # Used to be mask, and now it's still mask!
     assert_equal(mbase.c.mask, [1] * 5)
     assert_equal(mbase.c.recordmask, [1] * 5)
     assert_equal(ma.getmaskarray(mbase['c']), [1] * 5)
     assert_equal(ma.getdata(mbase['c']), [asbytes('N/A')] * 5)
     assert_equal(
         mbase._mask.tolist(),
         np.array([(0, 0, 1), (0, 1, 1), (0, 0, 1), (0, 0, 1), (0, 1, 1)],
                  dtype=bool))
     # Set fields by slices .......................
     mbase = base.view(mrecarray).copy()
     mbase.a[3:] = 5
     assert_equal(mbase.a, [1, 2, 3, 5, 5])
     assert_equal(mbase.a._mask, [0, 1, 0, 0, 0])
     mbase.b[3:] = masked
     assert_equal(mbase.b, base['b'])
     assert_equal(mbase.b._mask, [0, 1, 0, 1, 1])
     # Set fields globally..........................
     ndtype = [('alpha', '|S1'), ('num', int)]
     data = ma.array([('a', 1), ('b', 2), ('c', 3)], dtype=ndtype)
     rdata = data.view(MaskedRecords)
     val = ma.array([10, 20, 30], mask=[1, 0, 0])
     # All warnings are ignored while the masked field is assigned and
     # checked; the filter state is restored when the block exits.
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         rdata['num'] = val
         assert_equal(rdata.num, val)
         assert_equal(rdata.num.mask, [1, 0, 0])
예제 #26
0
def test_read_1():
    """Read a 1-channel 4-byte WAV file and check its rate, dtype and shape."""
    with warnings.catch_warnings():
        # wavfile.read may emit WavFileWarning for this file; ignore it.
        warnings.simplefilter('ignore', wavfile.WavFileWarning)
        rate, data = wavfile.read(datafile('test-44100-le-1ch-4bytes.wav'))

    assert_equal(rate, 44100)
    assert_(np.issubdtype(data.dtype, np.int32))
    assert_equal(data.shape, (4410,))
예제 #27
0
 def test_set_fields(self):
     "Tests setting fields."
     base = self.base.copy()
     mbase = base.view(mrecarray)
     mbase = mbase.copy()
     mbase.fill_value = (999999,1e20,'N/A')
     # Change the data, the mask should be conserved
     mbase.a._data[:] = 5
     assert_equal(mbase['a']._data, [5,5,5,5,5])
     assert_equal(mbase['a']._mask, [0,1,0,0,1])
     # Change the elements, and the mask will follow
     mbase.a = 1
     assert_equal(mbase['a']._data, [1]*5)
     assert_equal(ma.getmaskarray(mbase['a']), [0]*5)
     # Used to be _mask, now it's recordmask
     assert_equal(mbase.recordmask, [False]*5)
     assert_equal(mbase._mask.tolist(),
                  np.array([(0,0,0),(0,1,1),(0,0,0),(0,0,0),(0,1,1)],
                           dtype=bool))
     # Set a field to mask ........................
     mbase.c = masked
     # Used to be mask, and now it's still mask!
     assert_equal(mbase.c.mask, [1]*5)
     assert_equal(mbase.c.recordmask, [1]*5)
     assert_equal(ma.getmaskarray(mbase['c']), [1]*5)
     assert_equal(ma.getdata(mbase['c']), [asbytes('N/A')]*5)
     assert_equal(mbase._mask.tolist(),
                  np.array([(0,0,1),(0,1,1),(0,0,1),(0,0,1),(0,1,1)],
                           dtype=bool))
     # Set fields by slices .......................
     mbase = base.view(mrecarray).copy()
     mbase.a[3:] = 5
     assert_equal(mbase.a, [1,2,3,5,5])
     assert_equal(mbase.a._mask, [0,1,0,0,0])
     mbase.b[3:] = masked
     assert_equal(mbase.b, base['b'])
     assert_equal(mbase.b._mask, [0,1,0,1,1])
     # Set fields globally..........................
     ndtype = [('alpha','|S1'),('num',int)]
     data = ma.array([('a',1),('b',2),('c',3)], dtype=ndtype)
     rdata = data.view(MaskedRecords)
     val = ma.array([10,20,30], mask=[1,0,0])
     # All warnings are ignored while the masked field is assigned and
     # checked; the filter state is restored when the block exits.
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         rdata['num'] = val
         assert_equal(rdata.num, val)
         assert_equal(rdata.num.mask, [1,0,0])
예제 #28
0
    def test_blas(self):
        a = array([[1,1,1]])
        b = array([[1],[1],[1]])

        # get_blas_funcs is deprecated, silence the warning
        warn_ctx = WarningManager()
        warn_ctx.__enter__()
        try:
            warnings.simplefilter('ignore', DeprecationWarning)
            gemm, = get_blas_funcs(('gemm',),(a,b))
        finally:
            warn_ctx.__exit__()

        assert_array_almost_equal(gemm(1,a,b),[[3]],15)
예제 #29
0
 def test_empty_file(self):
     """Check that loadtxt on an empty file returns an empty 1-D array."""
     with warnings.catch_warnings():
         # The "Empty input file" warning is expected here and is silenced.
         warnings.filterwarnings("ignore",
                                 message="loadtxt: Empty input file:")
         c = StringIO()
         x = textadapter.loadtxt(c)
         assert_equal(x.shape, (0, ))
         # A requested dtype is honoured even for empty input.
         x = textadapter.loadtxt(c, dtype=np.int64)
         assert_equal(x.shape, (0, ))
         assert_(x.dtype == np.int64)
예제 #30
0
    def test_blas(self):
        a = array([[1, 1, 1]])
        b = array([[1], [1], [1]])

        # get_blas_funcs is deprecated, silence the warning
        warn_ctx = WarningManager()
        warn_ctx.__enter__()
        try:
            warnings.simplefilter('ignore', DeprecationWarning)
            gemm, = get_blas_funcs(('gemm', ), (a, b))
        finally:
            warn_ctx.__exit__()

        assert_array_almost_equal(gemm(1, a, b), [[3]], 15)
예제 #31
0
def test_warnings():
    """Check which MAT-file calls emit warnings, with warnings made errors."""
    fname = pjoin(test_data_path, 'testdouble_7.1_GLNX86.mat')
    with warnings.catch_warnings():
        # Turn every warning into an exception so that any unexpected
        # warning fails the test; filters are restored on exit.
        warnings.simplefilter('error')
        # This should not generate a warning
        loadmat(fname, struct_as_record=True)
        # This neither
        loadmat(fname, struct_as_record=False)
        # This should - because of deprecated system path search
        assert_raises(DeprecationWarning, find_mat_file, fname)
예제 #32
0
def test_read_1():
    """Read a 1-channel 4-byte WAV file, with and without mmap, and check it."""
    for mmap in [False, True]:
        with warnings.catch_warnings():
            # wavfile.read may emit WavFileWarning for this file; ignore it.
            warnings.simplefilter("ignore", wavfile.WavFileWarning)
            rate, data = wavfile.read(datafile("test-44100-le-1ch-4bytes.wav"), mmap=mmap)

        assert_equal(rate, 44100)
        assert_(np.issubdtype(data.dtype, np.int32))
        assert_equal(data.shape, (4410,))

        # Drop the array reference (presumably so a memory-mapped result
        # releases the underlying file before the next iteration).
        del data
예제 #33
0
    def test_compressed(self):
        """Read ``various_compressed.sav`` with readsav and check its variables."""
        # Ignore the "warning: empty strings" message readsav may emit for
        # this file; the filter state is restored when the block exits.
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', message="warning: empty strings")
            s = readsav(path.join(DATA_PATH, 'various_compressed.sav'), verbose=False)

        assert_identical(s.i8u, np.uint8(234))
        assert_identical(s.f32, np.float32(-3.1234567e+37))
        assert_identical(s.c64, np.complex128(1.1987253647623157e+112-5.1987258887729157e+307j))
        assert_equal(s.array5d.shape, (4, 3, 4, 6, 5))
        assert_identical(s.arrays.a[0], np.array([1, 2, 3], dtype=np.int16))
        assert_identical(s.arrays.b[0], np.array([4., 5., 6., 7.], dtype=np.float32))
        assert_identical(s.arrays.c[0], np.array([np.complex64(1+2j), np.complex64(7+8j)]))
        # ``np.object`` was only an alias of the builtin ``object`` and was
        # removed in NumPy 1.24; the builtin gives the same dtype.
        assert_identical(s.arrays.d[0], np.array([b"cheese", b"bacon", b"spam"], dtype=object))
예제 #34
0
    def test_approx(self):
        ramsay = np.array((111, 107, 100, 99, 102, 106, 109, 108, 104, 99, 101,
                           96, 97, 102, 107, 113, 116, 113, 110, 98))
        parekh = np.array((107, 108, 106, 98, 105, 103, 110, 105, 104, 100, 96,
                           108, 103, 104, 114, 114, 113, 108, 106, 99))

        warn_ctx = WarningManager()
        warn_ctx.__enter__()
        try:
            warnings.filterwarnings(
                'ignore', message="Ties preclude use of exact statistic.")
            W, pval = stats.ansari(ramsay, parekh)
        finally:
            warn_ctx.__exit__()

        assert_almost_equal(W, 185.5, 11)
        assert_almost_equal(pval, 0.18145819972867083, 11)
예제 #35
0
    def test_approx(self):
        ramsay = np.array((111, 107, 100, 99, 102, 106, 109, 108, 104, 99,
                           101, 96, 97, 102, 107, 113, 116, 113, 110, 98))
        parekh = np.array((107, 108, 106, 98, 105, 103, 110, 105, 104,
                           100, 96, 108, 103, 104, 114, 114, 113, 108, 106, 99))

        warn_ctx = WarningManager()
        warn_ctx.__enter__()
        try:
            warnings.filterwarnings('ignore',
                        message="Ties preclude use of exact statistic.")
            W, pval = stats.ansari(ramsay, parekh)
        finally:
            warn_ctx.__exit__()

        assert_almost_equal(W,185.5,11)
        assert_almost_equal(pval,0.18145819972867083,11)
예제 #36
0
파일: test_vq.py 프로젝트: rblomberg/scipy
    def test_kmeans2_init(self):
        """Testing that kmeans2 init methods work."""
        data = np.fromfile(DATAFILE1, sep=", ").reshape((200, 2))

        kmeans2(data, 3, minit='points')
        kmeans2(data[:, :1], 3, minit='points')  # special case (1-D)

        # minit='random' can give warnings, filter those
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore',
                                    message="One of the clusters is empty. Re-run")
            kmeans2(data, 3, minit='random')
            kmeans2(data[:, :1], 3, minit='random')  # special case (1-D)
예제 #37
0
    def test_kmeans_lost_cluster(self):
        """This will cause kmean to have a cluster with no points."""
        data = np.fromfile(DATAFILE1, sep=", ").reshape((200, 2))
        initk = np.array([[-1.8127404, -0.67128041], [2.04621601, 0.07401111],
                          [-2.31149087, -0.05160469]])

        # kmeans must run with these initial centroids; only completion is checked.
        kmeans(data, initk)

        # missing='warn' is expected to report the empty cluster as a
        # UserWarning, which is ignored here ...
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', UserWarning)
            kmeans2(data, initk, missing='warn')

        # ... while missing='raise' must turn it into a ClusterError.
        assert_raises(ClusterError, kmeans2, data, initk, missing='raise')
예제 #38
0
파일: test_vq.py 프로젝트: beiko-lab/gengis
    def test_kmeans2_init(self):
        """Testing that kmeans2 init methods work."""
        data = np.fromfile(DATAFILE1, sep=", ").reshape((200, 2))

        kmeans2(data, 3, minit='points')
        kmeans2(data[:, :1], 3, minit='points')  # special case (1-D)

        # minit='random' can give warnings, filter those
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore',
                                    message="One of the clusters is empty. Re-run")
            kmeans2(data, 3, minit='random')
            kmeans2(data[:, :1], 3, minit='random')  # special case (1-D)
예제 #39
0
def test_mat4_3d():
    """Check how 3-D arrays are handled when writing MATLAB 4 files.

    Writing a 3-D array in format 4 must emit a DeprecationWarning (checked
    by escalating warnings to errors).  With the warning silenced the array
    is, for now, saved as 2-D: ``loadmat`` returns it reshaped to (6, 4).
    """
    stream = BytesIO()
    arr = np.arange(24).reshape((2, 3, 4))

    # catch_warnings restores the previous filters when the block exits.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        assert_raises(DeprecationWarning, savemat_future, stream, {"a": arr}, True, "4")
        # For now, we save a 3D array as 2D
        warnings.simplefilter("ignore")
        savemat_future(stream, {"a": arr}, format="4")

    d = loadmat(stream)
    assert_array_equal(d["a"], arr.reshape((6, 4)))
예제 #40
0
def test_mat4_3d():
    """Check how 3-D arrays are handled when writing MATLAB 4 files.

    Writing a 3-D array in format 4 must emit a DeprecationWarning (checked
    by escalating warnings to errors).  With the warning silenced the array
    is, for now, saved as 2-D: ``loadmat`` returns it reshaped to (6, 4).
    """
    stream = BytesIO()
    arr = np.arange(24).reshape((2, 3, 4))

    # catch_warnings restores the previous filters when the block exits.
    with warnings.catch_warnings():
        warnings.simplefilter('error')
        assert_raises(DeprecationWarning, savemat_future,
                      stream, {'a': arr}, True, '4')
        # For now, we save a 3D array as 2D
        warnings.simplefilter('ignore')
        savemat_future(stream, {'a': arr}, format='4')

    d = loadmat(stream)
    assert_array_equal(d['a'], arr.reshape((6, 4)))
예제 #41
0
파일: test_vq.py 프로젝트: beiko-lab/gengis
    def test_kmeans_lost_cluster(self):
        """Check kmeans/kmeans2 when one initial centroid gets no points.

        With the initial guesses in ``initk`` one of the clusters ends up
        empty.  ``kmeans`` must simply run, ``kmeans2(missing='warn')`` must
        run (its warnings are silenced here), and ``kmeans2(missing='raise')``
        must raise ``ClusterError``.
        """
        data = np.fromfile(DATAFILE1, sep=", ")
        data = data.reshape((200, 2))
        initk = np.array([[-1.8127404, -0.67128041],
                          [2.04621601, 0.07401111],
                          [-2.31149087, -0.05160469]])

        # Only checks that kmeans completes; the result is not inspected.
        kmeans(data, initk)

        # missing='warn' may emit a UserWarning for the empty cluster; silence
        # it and restore the previous warning filters when the block exits.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', UserWarning)
            kmeans2(data, initk, missing='warn')

        assert_raises(ClusterError, kmeans2, data, initk, missing='raise')
예제 #42
0
    def test_safe_casting(self):
        """Check the deprecation path for unsafe in-place casting.

        In-place int += float must warn (DeprecationWarning).  Passing
        ``casting="unsafe"`` must not warn.  When the warning is escalated
        to an error it cannot propagate from where it is raised, so it must
        be reported on stderr instead of surfacing later.
        """
        # In old versions of numpy, in-place operations used the 'unsafe'
        # casting rules. In some future version, 'same_kind' will become the
        # default.
        a = np.array([1, 2, 3], dtype=int)
        # Non-in-place addition is fine
        assert_array_equal(assert_no_warnings(np.add, a, 1.1), [2.1, 3.1, 4.1])
        assert_warns(DeprecationWarning, np.add, a, 1.1, out=a)
        assert_array_equal(a, [2, 3, 4])

        def add_inplace(a, b):
            a += b

        assert_warns(DeprecationWarning, add_inplace, a, 1.1)
        assert_array_equal(a, [3, 4, 5])
        # Make sure that explicitly overriding the warning is allowed:
        assert_no_warnings(np.add, a, 1.1, out=a, casting="unsafe")
        assert_array_equal(a, [4, 5, 6])

        # There's no way to propagate exceptions from the place where we issue
        # this deprecation warning, so we must throw the exception away
        # entirely rather than cause it to be raised at some other point, or
        # trigger some other unsuspecting if (PyErr_Occurred()) { ...} at some
        # other location entirely.
        if sys.version_info[0] >= 3:
            from io import StringIO
        else:
            from StringIO import StringIO

        # catch_warnings restores the warning filters on exit.
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            # Save stderr before redirecting it, so the finally clause can
            # always restore it (the original bound it inside the try).
            old_stderr = sys.stderr
            sys.stderr = StringIO()
            try:
                # No error, but dumps to stderr
                a += 1.1
                # No error on the next bit of code executed either
                1 + 1
                assert_("Implicitly casting" in sys.stderr.getvalue())
            finally:
                sys.stderr = old_stderr
예제 #43
0
    def test_safe_casting(self):
        """Check the deprecation path for unsafe in-place casting.

        In-place int += float must warn (DeprecationWarning).  Passing
        ``casting="unsafe"`` must not warn.  When the warning is escalated
        to an error it cannot propagate from where it is raised, so it must
        be reported on stderr instead of surfacing later.
        """
        # In old versions of numpy, in-place operations used the 'unsafe'
        # casting rules. In some future version, 'same_kind' will become the
        # default.
        a = np.array([1, 2, 3], dtype=int)
        # Non-in-place addition is fine
        assert_array_equal(assert_no_warnings(np.add, a, 1.1),
                           [2.1, 3.1, 4.1])
        assert_warns(DeprecationWarning, np.add, a, 1.1, out=a)
        assert_array_equal(a, [2, 3, 4])

        def add_inplace(a, b):
            a += b

        assert_warns(DeprecationWarning, add_inplace, a, 1.1)
        assert_array_equal(a, [3, 4, 5])
        # Make sure that explicitly overriding the warning is allowed:
        assert_no_warnings(np.add, a, 1.1, out=a, casting="unsafe")
        assert_array_equal(a, [4, 5, 6])

        # There's no way to propagate exceptions from the place where we issue
        # this deprecation warning, so we must throw the exception away
        # entirely rather than cause it to be raised at some other point, or
        # trigger some other unsuspecting if (PyErr_Occurred()) { ...} at some
        # other location entirely.
        if sys.version_info[0] >= 3:
            from io import StringIO
        else:
            # Python 2's io.StringIO only accepts unicode, but byte strings
            # get written to sys.stderr there; use the classic StringIO.
            from StringIO import StringIO

        # catch_warnings restores the warning filters on exit.
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            # Save stderr before redirecting it, so the finally clause can
            # always restore it (the original bound it inside the try).
            old_stderr = sys.stderr
            sys.stderr = StringIO()
            try:
                # No error, but dumps to stderr
                a += 1.1
                # No error on the next bit of code executed either
                1 + 1
                assert_("Implicitly casting" in sys.stderr.getvalue())
            finally:
                sys.stderr = old_stderr
예제 #44
0
    def test_coercion(self):
        """Check the result dtype of ``np.add`` across promotion cases.

        ComplexWarning is silenced for the duration of the checks.  The
        previous warning filters are restored even if an assertion fails
        (the original skipped the restore on failure, leaking the filter
        into later tests).
        """
        def res_type(a, b):
            return np.add(a, b).dtype

        with warnings.catch_warnings():
            warnings.simplefilter('ignore', np.ComplexWarning)

            self.check_promotion_cases(res_type)

            f64 = float64(0)
            c64 = complex64(0)
            # A complex-valued complex64 scalar promotes a float64 array to
            # complex128.
            assert_equal(res_type(complex64(3j), array([f64])),
                         np.dtype(complex128))

            # Scalars do coerce to complex even if the value is real
            # This is so "a+0j" can be reliably used to make something complex.
            assert_equal(res_type(c64, array([f64])), np.dtype(complex128))
 def _deprecated_imp(*args, **kwargs):
     """Call ``f`` and assert that its first warning is a DeprecationWarning.

     ``f`` is taken from the enclosing scope.  Every warning is recorded
     (the filter is set to 'always' so repeats are not suppressed).  The
     previous warning state is restored afterwards, even if ``f`` raises.

     Raises
     ------
     AssertionError
         If ``f`` emits no warning, or its first warning is not a
         DeprecationWarning.
     """
     with warnings.catch_warnings(record=True) as log:
         warnings.simplefilter('always')
         f(*args, **kwargs)
         if not log:
             raise AssertionError("No warning raised when calling %s"
                                  % f.__name__)
         if log[0].category is not DeprecationWarning:
             raise AssertionError("First warning for %s is not a "
                                  "DeprecationWarning( is %s)"
                                  % (f.__name__, log[0]))
예제 #46
0
 def _deprecated_imp(*args, **kwargs):
     """Call ``f`` and assert that its first warning is a DeprecationWarning.

     ``f`` is taken from the enclosing scope.  Every warning is recorded
     (the filter is set to 'always' so repeats are not suppressed).  The
     previous warning state is restored afterwards, even if ``f`` raises.

     Raises
     ------
     AssertionError
         If ``f`` emits no warning, or its first warning is not a
         DeprecationWarning.
     """
     with warnings.catch_warnings(record=True) as log:
         warnings.simplefilter('always')
         f(*args, **kwargs)
         if not log:
             raise AssertionError("No warning raised when calling %s" %
                                  f.__name__)
         if log[0].category is not DeprecationWarning:
             raise AssertionError("First warning for %s is not a "
                                  "DeprecationWarning( is %s)"
                                  % (f.__name__, log[0]))
예제 #47
0
==================== =========================================================
Internal functions
==============================================================================
get_state            Get tuple representing internal state of generator.
set_state            Set state of generator.
==================== =========================================================

"""
# To get sub-modules
from .info import __doc__, __all__

import warnings
from numpy.testing.utils import WarningManager

# Import mtrand with the "numpy.ndarray size changed" warning silenced;
# catch_warnings restores the previous warning filters afterwards.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", message="numpy.ndarray size changed")
    from .mtrand import *

# Some aliases:
ranf = random = sample = random_sample
__all__.extend(['ranf', 'random', 'sample'])

def __RandomState_ctor():
    """Return a RandomState instance.

    This function exists solely to assist (un)pickling.
    """
예제 #48
0
파일: test_einsum.py 프로젝트: yuj18/numpy
    def check_einsum_sums(self, dtype):
        """Compare ``np.einsum`` results with equivalent numpy operations.

        Covers sums, trace, multiply, inner/outer products, matrix products,
        tensordot, and scalar operands, in both subscript-string and
        operand-list forms.  Many sizes are used to exercise unrolled loops.

        Parameters
        ----------
        dtype : dtype-like
            Element type of the test arrays; anything ``np.dtype`` accepts.
            Some larger products are skipped for float16, presumably because
            its precision is too low for exact comparison.
        """
        # Check various sums.  Does many sizes to exercise unrolled loops.

        # Normalise once so every spelling of float16 ('f2', 'e', np.float16)
        # takes the float16 branches below (previously some checks compared
        # the raw argument against the string 'f2').
        is_half = np.dtype(dtype) == np.dtype('f2')

        # sum(a, axis=-1)
        for n in range(1,17):
            a = np.arange(n, dtype=dtype)
            assert_equal(np.einsum("i->", a), np.sum(a, axis=-1).astype(dtype))
            assert_equal(np.einsum(a, [0], []),
                         np.sum(a, axis=-1).astype(dtype))

        for n in range(1,17):
            a = np.arange(2*3*n, dtype=dtype).reshape(2,3,n)
            assert_equal(np.einsum("...i->...", a),
                         np.sum(a, axis=-1).astype(dtype))
            assert_equal(np.einsum(a, [Ellipsis,0], [Ellipsis]),
                         np.sum(a, axis=-1).astype(dtype))

        # sum(a, axis=0)
        for n in range(1,17):
            a = np.arange(2*n, dtype=dtype).reshape(2,n)
            assert_equal(np.einsum("i...->...", a),
                         np.sum(a, axis=0).astype(dtype))
            assert_equal(np.einsum(a, [0,Ellipsis], [Ellipsis]),
                         np.sum(a, axis=0).astype(dtype))

        for n in range(1,17):
            a = np.arange(2*3*n, dtype=dtype).reshape(2,3,n)
            assert_equal(np.einsum("i...->...", a),
                         np.sum(a, axis=0).astype(dtype))
            assert_equal(np.einsum(a, [0,Ellipsis], [Ellipsis]),
                         np.sum(a, axis=0).astype(dtype))

        # trace(a)
        for n in range(1,17):
            a = np.arange(n*n, dtype=dtype).reshape(n,n)
            assert_equal(np.einsum("ii", a), np.trace(a).astype(dtype))
            assert_equal(np.einsum(a, [0,0]), np.trace(a).astype(dtype))

        # multiply(a, b)
        for n in range(1,17):
            a = np.arange(3*n, dtype=dtype).reshape(3,n)
            b = np.arange(2*3*n, dtype=dtype).reshape(2,3,n)
            assert_equal(np.einsum("..., ...", a, b), np.multiply(a, b))
            assert_equal(np.einsum(a, [Ellipsis], b, [Ellipsis]),
                         np.multiply(a, b))

        # inner(a,b)
        for n in range(1,17):
            a = np.arange(2*3*n, dtype=dtype).reshape(2,3,n)
            b = np.arange(n, dtype=dtype)
            assert_equal(np.einsum("...i, ...i", a, b), np.inner(a, b))
            assert_equal(np.einsum(a, [Ellipsis,0], b, [Ellipsis,0]),
                         np.inner(a, b))

        for n in range(1,11):
            a = np.arange(n*3*2, dtype=dtype).reshape(n,3,2)
            b = np.arange(n, dtype=dtype)
            assert_equal(np.einsum("i..., i...", a, b), np.inner(a.T, b.T).T)
            assert_equal(np.einsum(a, [0,Ellipsis], b, [0,Ellipsis]),
                         np.inner(a.T, b.T).T)

        # outer(a,b)
        for n in range(1,17):
            a = np.arange(3, dtype=dtype)+1
            b = np.arange(n, dtype=dtype)+1
            assert_equal(np.einsum("i,j", a, b), np.outer(a, b))
            assert_equal(np.einsum(a, [0], b, [1]), np.outer(a, b))

        # Suppress the complex warnings for the 'as f8' tests; the previous
        # warning filters are restored when the with-block exits.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', np.ComplexWarning)

            # matvec(a,b) / a.dot(b) where a is matrix, b is vector
            for n in range(1,17):
                a = np.arange(4*n, dtype=dtype).reshape(4,n)
                b = np.arange(n, dtype=dtype)
                assert_equal(np.einsum("ij, j", a, b), np.dot(a, b))
                assert_equal(np.einsum(a, [0,1], b, [1]), np.dot(a, b))

                c = np.arange(4, dtype=dtype)
                np.einsum("ij,j", a, b, out=c,
                            dtype='f8', casting='unsafe')
                assert_equal(c,
                            np.dot(a.astype('f8'),
                                   b.astype('f8')).astype(dtype))
                c[...] = 0
                np.einsum(a, [0,1], b, [1], out=c,
                            dtype='f8', casting='unsafe')
                assert_equal(c,
                            np.dot(a.astype('f8'),
                                   b.astype('f8')).astype(dtype))

            for n in range(1,17):
                a = np.arange(4*n, dtype=dtype).reshape(4,n)
                b = np.arange(n, dtype=dtype)
                assert_equal(np.einsum("ji,j", a.T, b.T), np.dot(b.T, a.T))
                assert_equal(np.einsum(a.T, [1,0], b.T, [1]), np.dot(b.T, a.T))

                c = np.arange(4, dtype=dtype)
                np.einsum("ji,j", a.T, b.T, out=c, dtype='f8', casting='unsafe')
                assert_equal(c,
                        np.dot(b.T.astype('f8'),
                               a.T.astype('f8')).astype(dtype))
                c[...] = 0
                np.einsum(a.T, [1,0], b.T, [1], out=c,
                            dtype='f8', casting='unsafe')
                assert_equal(c,
                        np.dot(b.T.astype('f8'),
                               a.T.astype('f8')).astype(dtype))

            # matmat(a,b) / a.dot(b) where a is matrix, b is matrix
            for n in range(1,17):
                if n < 8 or not is_half:
                    a = np.arange(4*n, dtype=dtype).reshape(4,n)
                    b = np.arange(n*6, dtype=dtype).reshape(n,6)
                    assert_equal(np.einsum("ij,jk", a, b), np.dot(a, b))
                    assert_equal(np.einsum(a, [0,1], b, [1,2]), np.dot(a, b))

            for n in range(1,17):
                a = np.arange(4*n, dtype=dtype).reshape(4,n)
                b = np.arange(n*6, dtype=dtype).reshape(n,6)
                c = np.arange(24, dtype=dtype).reshape(4,6)
                np.einsum("ij,jk", a, b, out=c, dtype='f8', casting='unsafe')
                assert_equal(c,
                            np.dot(a.astype('f8'),
                                   b.astype('f8')).astype(dtype))
                c[...] = 0
                np.einsum(a, [0,1], b, [1,2], out=c,
                                dtype='f8', casting='unsafe')
                assert_equal(c,
                            np.dot(a.astype('f8'),
                                   b.astype('f8')).astype(dtype))

            # matrix triple product (note this is not currently an efficient
            # way to multiply 3 matrices)
            a = np.arange(12, dtype=dtype).reshape(3,4)
            b = np.arange(20, dtype=dtype).reshape(4,5)
            c = np.arange(30, dtype=dtype).reshape(5,6)
            if not is_half:
                assert_equal(np.einsum("ij,jk,kl", a, b, c),
                                    a.dot(b).dot(c))
                assert_equal(np.einsum(a, [0,1], b, [1,2], c, [2,3]),
                                    a.dot(b).dot(c))

            d = np.arange(18, dtype=dtype).reshape(3,6)
            np.einsum("ij,jk,kl", a, b, c, out=d,
                                dtype='f8', casting='unsafe')
            assert_equal(d, a.astype('f8').dot(b.astype('f8')
                        ).dot(c.astype('f8')).astype(dtype))
            d[...] = 0
            np.einsum(a, [0,1], b, [1,2], c, [2,3], out=d,
                                dtype='f8', casting='unsafe')
            assert_equal(d, a.astype('f8').dot(b.astype('f8')
                        ).dot(c.astype('f8')).astype(dtype))

            # tensordot(a, b)
            if not is_half:
                a = np.arange(60, dtype=dtype).reshape(3,4,5)
                b = np.arange(24, dtype=dtype).reshape(4,3,2)
                assert_equal(np.einsum("ijk, jil -> kl", a, b),
                                np.tensordot(a,b, axes=([1,0],[0,1])))
                assert_equal(np.einsum(a, [0,1,2], b, [1,0,3], [2,3]),
                                np.tensordot(a,b, axes=([1,0],[0,1])))

                c = np.arange(10, dtype=dtype).reshape(5,2)
                np.einsum("ijk,jil->kl", a, b, out=c,
                                        dtype='f8', casting='unsafe')
                assert_equal(c, np.tensordot(a.astype('f8'), b.astype('f8'),
                                        axes=([1,0],[0,1])).astype(dtype))
                c[...] = 0
                np.einsum(a, [0,1,2], b, [1,0,3], [2,3], out=c,
                                        dtype='f8', casting='unsafe')
                assert_equal(c, np.tensordot(a.astype('f8'), b.astype('f8'),
                                        axes=([1,0],[0,1])).astype(dtype))

        # logical_and(logical_and(a!=0, b!=0), c!=0)
        a = np.array([1,   3,   -2,   0,   12,  13,   0,   1], dtype=dtype)
        b = np.array([0,   3.5, 0.,   -2,  0,   1,    3,   12], dtype=dtype)
        c = np.array([True,True,False,True,True,False,True,True])
        assert_equal(np.einsum("i,i,i->i", a, b, c,
                                dtype='?', casting='unsafe'),
                            np.logical_and(np.logical_and(a!=0, b!=0), c!=0))
        assert_equal(np.einsum(a, [0], b, [0], c, [0], [0],
                                dtype='?', casting='unsafe'),
                            np.logical_and(np.logical_and(a!=0, b!=0), c!=0))

        a = np.arange(9, dtype=dtype)
        assert_equal(np.einsum(",i->", 3, a), 3*np.sum(a))
        assert_equal(np.einsum(3, [], a, [0], []), 3*np.sum(a))
        assert_equal(np.einsum("i,->", a, 3), 3*np.sum(a))
        assert_equal(np.einsum(a, [0], 3, [], []), 3*np.sum(a))

        # Various stride0, contiguous, and SSE aligned variants
        for n in range(1,25):
            a = np.arange(n, dtype=dtype)
            if np.dtype(dtype).itemsize > 1:
                assert_equal(np.einsum("...,...",a,a), np.multiply(a,a))
                assert_equal(np.einsum("i,i", a, a), np.dot(a,a))
                assert_equal(np.einsum("i,->i", a, 2), 2*a)
                assert_equal(np.einsum(",i->i", 2, a), 2*a)
                assert_equal(np.einsum("i,->", a, 2), 2*np.sum(a))
                assert_equal(np.einsum(",i->", 2, a), 2*np.sum(a))

                assert_equal(np.einsum("...,...",a[1:],a[:-1]),
                             np.multiply(a[1:],a[:-1]))
                assert_equal(np.einsum("i,i", a[1:], a[:-1]),
                             np.dot(a[1:],a[:-1]))
                assert_equal(np.einsum("i,->i", a[1:], 2), 2*a[1:])
                assert_equal(np.einsum(",i->i", 2, a[1:]), 2*a[1:])
                assert_equal(np.einsum("i,->", a[1:], 2), 2*np.sum(a[1:]))
                assert_equal(np.einsum(",i->", 2, a[1:]), 2*np.sum(a[1:]))

        # An object array, summed as the data type
        a = np.arange(9, dtype=object)

        b = np.einsum("i->", a, dtype=dtype, casting='unsafe')
        assert_equal(b, np.sum(a))
        assert_equal(b.dtype, np.dtype(dtype))

        b = np.einsum(a, [0], [], dtype=dtype, casting='unsafe')
        assert_equal(b, np.sum(a))
        assert_equal(b.dtype, np.dtype(dtype))

        # A case which was failing (ticket #1885)
        p = np.arange(2) + 1
        q = np.arange(4).reshape(2,2) + 3
        r = np.arange(4).reshape(2,2) + 7
        assert_equal(np.einsum('z,mz,zm->', p, q, r), 253)
예제 #49
0
class TestSolvers(object):
    """Solve small sparse linear systems through the UMFPACK wrappers ``um``."""

    def setUp(self):
        # 5x5 sparse matrix built from a main diagonal (offset 0) and the
        # first superdiagonal (offset 1).
        self.a = spdiags([[1, 2, 3, 4, 5], [6, 5, 8, 9, 10]], [0, 1], 5, 5)
        self.b = np.array([1, 2, 3, 4, 5], dtype=np.float64)
        self.b2 = np.array([5, 4, 3, 2, 1], dtype=np.float64)

        # Snapshot the warning state here; tearDown puts it back.
        self.mgr = WarningManager()
        self.mgr.__enter__()

        for category in (DeprecationWarning, SparseEfficiencyWarning):
            warnings.simplefilter("ignore", category)

    def tearDown(self):
        # Undo the filters installed in setUp.
        self.mgr.__exit__()

    def test_solve_complex_umfpack(self):
        # Solve with UMFPACK: double precision complex
        mat = self.a.astype('D')
        sol = um.spsolve(mat, self.b)
        assert_allclose(mat*sol, self.b)

    def test_solve_umfpack(self):
        # Solve with UMFPACK: double precision
        mat = self.a.astype('d')
        sol = um.spsolve(mat, self.b)
        assert_allclose(mat*sol, self.b)

    def test_solve_sparse_rhs(self):
        # Solve with UMFPACK: double precision, sparse rhs
        mat = self.a.astype('d')
        rhs = csc_matrix(self.b).T
        sol = um.spsolve(mat, rhs)
        assert_allclose(mat*sol, self.b)

    def test_splu_solve(self):
        # Factorize once with UMFPACK, then reuse the factors for each rhs.
        mat = self.a.astype('d')
        factor = um.splu(mat)

        for rhs in (self.b, self.b2):
            sol = factor.solve(rhs)
            assert_allclose(mat*sol, rhs)

    def test_splu_solve_sparse(self):
        # Factorize once with UMFPACK, then solve for a sparse block of rhs.
        A = self.a.astype('d')
        factor = um.splu(A)

        col_b = csc_matrix(self.b.reshape(self.b.shape[0], 1))
        col_b2 = csc_matrix(self.b2.reshape(self.b2.shape[0], 1))
        B = hstack((col_b, col_b2))

        X = factor.solve_sparse(B)
        AX = A*X
        assert dense_norm((AX - B).todense()) < 1e-14
        assert_allclose(AX.todense(), B.todense())

    def test_splu_lu(self):
        # Rebuild A from the factorization pieces and compare.
        A = csc_matrix([[1,2,0,4],[1,0,0,1],[1,0,2,1],[2,2,1,0.]])
        factor = um.splu(A)

        idx = np.arange(4)

        # Permutation matrices from the row and column permutations.
        Pr = np.zeros((4, 4))
        Pr[factor.perm_r, idx] = 1
        Pr = csc_matrix(Pr)

        Pc = np.zeros((4, 4))
        Pc[idx, factor.perm_c] = 1
        Pc = csc_matrix(Pc)

        # Diagonal matrix from lu.R (presumably row scaling factors).
        R = csc_matrix((4, 4))
        R.setdiag(factor.R)

        A2 = (R * Pr.T * (factor.L * factor.U) * Pc.T).A
        assert_allclose(A2, A.A, atol=1e-13)
예제 #50
0
def safe_eval(source):
    """
    Protected string evaluation.

    Evaluate a string containing a Python literal expression without
    allowing the execution of arbitrary non-literal code.

    Parameters
    ----------
    source : str
        The string to evaluate.

    Returns
    -------
    obj : object
       The result of evaluating `source`.

    Raises
    ------
    SyntaxError
        If the code has invalid Python syntax, or if it contains non-literal
        code.

    Examples
    --------
    >>> np.safe_eval('1')
    1
    >>> np.safe_eval('[1, 2, 3]')
    [1, 2, 3]
    >>> np.safe_eval('{"foo": ("bar", 10.0)}')
    {'foo': ('bar', 10.0)}

    >>> np.safe_eval('import os')
    Traceback (most recent call last):
      ...
    SyntaxError: invalid syntax

    >>> np.safe_eval('open("/home/user/.ssh/id_dsa").read()')
    Traceback (most recent call last):
      ...
    SyntaxError: Unsupported source construct: compiler.ast.CallFunc

    """
    # Local imports to speed up numpy's import time.
    import warnings

    # catch_warnings restores the previous filters after the import.
    with warnings.catch_warnings():
        # compiler package is deprecated for 3.x, which is already solved here
        warnings.simplefilter('ignore', DeprecationWarning)
        try:
            import compiler
        except ImportError:
            import ast as compiler

    walker = SafeEval()
    # A SyntaxError from parsing or from the walker (non-literal code)
    # propagates to the caller unchanged, as documented above; the former
    # ``except SyntaxError as err: raise`` wrappers were no-ops.
    tree = compiler.parse(source, mode="eval")
    return walker.visit(tree)
예제 #51
0
파일: utils.py 프로젝트: manav95/ddfdf
def safe_eval(source):
    """
    Protected string evaluation.

    Evaluate a string containing a Python literal expression without
    allowing the execution of arbitrary non-literal code.

    Parameters
    ----------
    source : str
        The string to evaluate.

    Returns
    -------
    obj : object
       The result of evaluating `source`.

    Raises
    ------
    SyntaxError
        If the code has invalid Python syntax, or if it contains non-literal
        code.

    Examples
    --------
    >>> np.safe_eval('1')
    1
    >>> np.safe_eval('[1, 2, 3]')
    [1, 2, 3]
    >>> np.safe_eval('{"foo": ("bar", 10.0)}')
    {'foo': ('bar', 10.0)}

    >>> np.safe_eval('import os')
    Traceback (most recent call last):
      ...
    SyntaxError: invalid syntax

    >>> np.safe_eval('open("/home/user/.ssh/id_dsa").read()')
    Traceback (most recent call last):
      ...
    SyntaxError: Unsupported source construct: compiler.ast.CallFunc

    """
    # Local imports to speed up numpy's import time.
    import warnings
    from numpy.testing.utils import WarningManager
    warn_ctx = WarningManager()
    warn_ctx.__enter__()
    try:
        # compiler package is deprecated for 3.x, which is already solved here
        warnings.simplefilter('ignore', DeprecationWarning)
        try:
            import compiler
        except ImportError:
            import ast as compiler
    finally:
        warn_ctx.__exit__()

    walker = SafeEval()
    try:
        ast = compiler.parse(source, mode="eval")
    except SyntaxError, err:
        raise