Exemplo n.º 1
0
def test_special_case_numpy_functions():
    '''
    Verify the unit-aware ``ravel``, ``diagonal``, ``trace``, ``dot`` and
    ``where`` functions, together with the ``prod``, ``setasflat`` and
    ``cumprod`` methods, against plain numpy and check the resulting units.
    '''
    from brian2.units.unitsafefunctions import ravel, diagonal, trace, dot, where

    quadratic_matrix = np.reshape(np.arange(9), (3, 3)) * mV
    # The same values passed through np.asarray; used as the numpy reference
    bare_matrix = np.asarray(quadratic_matrix)

    # Temporarily suppress warnings related to the matplotlib 1.3 bug
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # The function has to agree with the method of the same name ...
        assert_equal(ravel(quadratic_matrix), quadratic_matrix.ravel())
        # ... converting before or after the call must give the same values ...
        assert_equal(np.asarray(ravel(quadratic_matrix)), ravel(bare_matrix))
        # ... and on the converted array it has to match numpy's own function
        assert_equal(np.ravel(bare_matrix), ravel(bare_matrix))

    # The identical three checks for diagonal and trace
    one_argument_cases = (
        (diagonal, np.diagonal, quadratic_matrix.diagonal),
        (trace, np.trace, quadratic_matrix.trace),
    )
    for wrapped, numpy_func, method in one_argument_cases:
        assert_equal(wrapped(quadratic_matrix), method())
        assert_equal(np.asarray(wrapped(quadratic_matrix)),
                     wrapped(bare_matrix))
        assert_equal(numpy_func(bare_matrix), wrapped(bare_matrix))

    # ... and for dot, which takes two operands
    squared = dot(quadratic_matrix, quadratic_matrix)
    bare_squared = dot(bare_matrix, bare_matrix)
    assert_equal(squared, quadratic_matrix.dot(quadratic_matrix))
    assert_equal(np.asarray(squared), bare_squared)
    assert_equal(np.dot(bare_matrix, bare_matrix), bare_squared)

    # prod over the whole array and along the first axis
    for prod_kwargs in ({}, {'axis': 0}):
        assert_equal(np.asarray(quadratic_matrix.prod(**prod_kwargs)),
                     bare_matrix.prod(**prod_kwargs))

    # Check for correct units
    ravel_reference = 1 if use_matplotlib_units_fix else quadratic_matrix
    assert have_same_dimensions(ravel_reference, ravel(quadratic_matrix))
    for wrapped in (trace, diagonal):
        assert have_same_dimensions(quadratic_matrix,
                                    wrapped(quadratic_matrix))
    first_row = quadratic_matrix[0]
    assert have_same_dimensions(first_row ** 2, squared)
    assert have_same_dimensions(quadratic_matrix.prod(axis=0),
                                first_row ** quadratic_matrix.shape[0])

    # check the where function
    # pure numpy array
    cond = [True, False, False]
    lhs = np.array([1, 2, 3])
    rhs = np.array([4, 5, 6])
    picked = np.where(cond, lhs, rhs)
    assert_equal(np.where(cond), where(cond))
    assert_equal(picked, where(cond, lhs, rhs))

    # dimensionless quantity
    lhs_ratio = lhs * mV / mV
    rhs_ratio = rhs * mV / mV
    assert_equal(picked, np.asarray(where(cond, lhs_ratio, rhs_ratio)))

    # quantity with dimensions
    lhs_mv = lhs * mV
    rhs_mv = rhs * mV
    assert_equal(np.where(cond, np.asarray(lhs_mv), np.asarray(rhs_mv)),
                 np.asarray(where(cond, lhs_mv, rhs_mv)))

    # Check some error cases: wrong argument counts and mismatching units
    assert_raises(ValueError, where, cond, lhs_mv)
    assert_raises(TypeError, where, cond, lhs_mv, lhs_mv, rhs_mv)
    assert_raises(DimensionMismatchError, where, cond, lhs_mv, lhs_mv / ms)

    # Check setasflat (for numpy < 1.7)
    if hasattr(Quantity, 'setasflat'):
        flat_target = np.arange(10) * mV
        matching_source = np.ones(10).reshape(5, 2) * volt
        mismatching_source = np.ones(10).reshape(5, 2) * second
        assert_raises(DimensionMismatchError,
                      flat_target.setasflat, mismatching_source)
        flat_target.setasflat(matching_source)
        assert_equal(flat_target.flatten(), matching_source.flatten())

    # Check cumprod: compared with numpy for mV / mV values, TypeError for mV
    ratio_values = np.arange(1, 10) * mV / mV
    assert_equal(ratio_values.cumprod(), np.asarray(ratio_values).cumprod())
    assert_raises(TypeError, (np.arange(1, 5) * mV).cumprod)
Exemplo n.º 2
0
def test_special_case_numpy_functions():
    '''
    Verify the unit-aware ``ravel``, ``diagonal``, ``trace``, ``dot`` and
    ``where`` functions, together with the ``prod``, ``setasflat`` and
    ``cumprod`` methods, against plain numpy and check the resulting units.
    '''
    from brian2.units.unitsafefunctions import ravel, diagonal, trace, dot, where

    quadratic_matrix = np.reshape(np.arange(9), (3, 3)) * mV
    # The same values passed through np.asarray; used as the numpy reference
    bare_matrix = np.asarray(quadratic_matrix)

    # Temporarily suppress warnings related to the matplotlib 1.3 bug
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # The function has to agree with the method of the same name ...
        assert_equal(ravel(quadratic_matrix), quadratic_matrix.ravel())
        # ... converting before or after the call must give the same values ...
        assert_equal(np.asarray(ravel(quadratic_matrix)), ravel(bare_matrix))
        # ... and on the converted array it has to match numpy's own function
        assert_equal(np.ravel(bare_matrix), ravel(bare_matrix))

    # The identical three checks for diagonal and trace
    one_argument_cases = (
        (diagonal, np.diagonal, quadratic_matrix.diagonal),
        (trace, np.trace, quadratic_matrix.trace),
    )
    for wrapped, numpy_func, method in one_argument_cases:
        assert_equal(wrapped(quadratic_matrix), method())
        assert_equal(np.asarray(wrapped(quadratic_matrix)),
                     wrapped(bare_matrix))
        assert_equal(numpy_func(bare_matrix), wrapped(bare_matrix))

    # ... and for dot, which takes two operands
    squared = dot(quadratic_matrix, quadratic_matrix)
    bare_squared = dot(bare_matrix, bare_matrix)
    assert_equal(squared, quadratic_matrix.dot(quadratic_matrix))
    assert_equal(np.asarray(squared), bare_squared)
    assert_equal(np.dot(bare_matrix, bare_matrix), bare_squared)

    # prod over the whole array and along the first axis
    for prod_kwargs in ({}, {'axis': 0}):
        assert_equal(np.asarray(quadratic_matrix.prod(**prod_kwargs)),
                     bare_matrix.prod(**prod_kwargs))

    # Check for correct units
    ravel_reference = 1 if use_matplotlib_units_fix else quadratic_matrix
    assert have_same_dimensions(ravel_reference, ravel(quadratic_matrix))
    for wrapped in (trace, diagonal):
        assert have_same_dimensions(quadratic_matrix,
                                    wrapped(quadratic_matrix))
    first_row = quadratic_matrix[0]
    assert have_same_dimensions(first_row**2, squared)
    assert have_same_dimensions(quadratic_matrix.prod(axis=0),
                                first_row**quadratic_matrix.shape[0])

    # check the where function
    # pure numpy array
    cond = [True, False, False]
    lhs = np.array([1, 2, 3])
    rhs = np.array([4, 5, 6])
    picked = np.where(cond, lhs, rhs)
    assert_equal(np.where(cond), where(cond))
    assert_equal(picked, where(cond, lhs, rhs))

    # dimensionless quantity
    lhs_ratio = lhs * mV / mV
    rhs_ratio = rhs * mV / mV
    assert_equal(picked, np.asarray(where(cond, lhs_ratio, rhs_ratio)))

    # quantity with dimensions
    lhs_mv = lhs * mV
    rhs_mv = rhs * mV
    assert_equal(np.where(cond, np.asarray(lhs_mv), np.asarray(rhs_mv)),
                 np.asarray(where(cond, lhs_mv, rhs_mv)))

    # Check some error cases: wrong argument counts and mismatching units
    with pytest.raises(ValueError):
        where(cond, lhs_mv)
    with pytest.raises(TypeError):
        where(cond, lhs_mv, lhs_mv, rhs_mv)
    with pytest.raises(DimensionMismatchError):
        where(cond, lhs_mv, lhs_mv / ms)

    # Check setasflat (for numpy < 1.7)
    if hasattr(Quantity, 'setasflat'):
        flat_target = np.arange(10) * mV
        matching_source = np.ones(10).reshape(5, 2) * volt
        mismatching_source = np.ones(10).reshape(5, 2) * second
        with pytest.raises(DimensionMismatchError):
            flat_target.setasflat(mismatching_source)
        flat_target.setasflat(matching_source)
        assert_equal(flat_target.flatten(), matching_source.flatten())

    # Check cumprod: compared with numpy for mV / mV values, TypeError for mV
    ratio_values = np.arange(1, 10) * mV / mV
    assert_equal(ratio_values.cumprod(), np.asarray(ratio_values).cumprod())
    with pytest.raises(TypeError):
        (np.arange(1, 5) * mV).cumprod()
Exemplo n.º 3
0
def test_special_case_numpy_functions():
    '''
    Verify the unit-aware ``ravel``, ``diagonal``, ``trace``, ``dot`` and
    ``where`` functions and the ``prod`` method against plain numpy and
    check the resulting units.
    '''
    from brian2.units.unitsafefunctions import ravel, diagonal, trace, dot, where

    quadratic_matrix = np.reshape(np.arange(9), (3, 3)) * mV
    # The same values passed through np.asarray; used as the numpy reference
    bare_matrix = np.asarray(quadratic_matrix)

    # ravel, diagonal and trace all get the same three checks
    one_argument_cases = (
        (ravel, np.ravel, quadratic_matrix.ravel),
        (diagonal, np.diagonal, quadratic_matrix.diagonal),
        (trace, np.trace, quadratic_matrix.trace),
    )
    for wrapped, numpy_func, method in one_argument_cases:
        # The function has to agree with the method of the same name ...
        assert_equal(wrapped(quadratic_matrix), method())
        # ... converting before or after the call must give the same values ...
        assert_equal(np.asarray(wrapped(quadratic_matrix)),
                     wrapped(bare_matrix))
        # ... and on the converted array it has to match numpy's own function
        assert_equal(numpy_func(bare_matrix), wrapped(bare_matrix))

    # The same checks for dot, which takes two operands
    squared = dot(quadratic_matrix, quadratic_matrix)
    bare_squared = dot(bare_matrix, bare_matrix)
    assert_equal(squared, quadratic_matrix.dot(quadratic_matrix))
    assert_equal(np.asarray(squared), bare_squared)
    assert_equal(np.dot(bare_matrix, bare_matrix), bare_squared)

    # prod over the whole array and along the first axis
    for prod_kwargs in ({}, {'axis': 0}):
        assert_equal(np.asarray(quadratic_matrix.prod(**prod_kwargs)),
                     bare_matrix.prod(**prod_kwargs))

    # Check for correct units
    for wrapped in (ravel, trace, diagonal):
        assert have_same_dimensions(quadratic_matrix,
                                    wrapped(quadratic_matrix))
    first_row = quadratic_matrix[0]
    assert have_same_dimensions(first_row ** 2, squared)
    assert have_same_dimensions(quadratic_matrix.prod(axis=0),
                                first_row ** quadratic_matrix.shape[0])

    # check the where function
    # pure numpy array
    cond = [True, False, False]
    lhs = np.array([1, 2, 3])
    rhs = np.array([4, 5, 6])
    picked = np.where(cond, lhs, rhs)
    assert_equal(np.where(cond), where(cond))
    assert_equal(picked, where(cond, lhs, rhs))

    # dimensionless quantity
    lhs_ratio = lhs * mV / mV
    rhs_ratio = rhs * mV / mV
    assert_equal(picked, np.asarray(where(cond, lhs_ratio, rhs_ratio)))

    # quantity with dimensions
    lhs_mv = lhs * mV
    rhs_mv = rhs * mV
    assert_equal(np.where(cond, np.asarray(lhs_mv), np.asarray(rhs_mv)),
                 np.asarray(where(cond, lhs_mv, rhs_mv)))

    # Check some error cases: wrong argument counts and mismatching units
    assert_raises(ValueError, where, cond, lhs_mv)
    assert_raises(TypeError, where, cond, lhs_mv, lhs_mv, rhs_mv)
    assert_raises(DimensionMismatchError, where, cond, lhs_mv, lhs_mv / ms)