# Example #1
# 0
def test_ScaledArray_dot():
    """Check ``ScaledCenterArray.dot`` against a dense NumPy reference.

    For every array shape (N rows, P columns), right-hand-side width K,
    squeezed / unsqueezed ``x``, and ``fit`` called with or without ``x``,
    ``ScaledCenterArray.dot(x)`` must match ``B.dot(x)`` where ``B`` is the
    explicitly transformed input ``A``:

    * no scale, no center:  B = A
    * scale, no center:     B = A D
    * center, no scale:     B = A - U
    * center and scale:     B = (A - U) D

    ``U`` is the vector of column means of ``A`` (broadcast over rows) and
    ``D`` is the diagonal matrix of reciprocal column standard deviations.
    """
    # Local generator: makes failures reproducible and leaves the global
    # NumPy random state untouched.
    rng = np.random.RandomState(0)
    for N in range(2, 5):
        for P in range(2, 5):
            array = rng.rand(N, P) + 1
            # D = diag(1 / column std)
            inv_std = np.diag(1 / np.std(array, axis=0))
            # U = column means; subtracting broadcasts it across every row.
            mu = np.mean(array, axis=0)
            centered = array - mu
            # (scale, center, reference B). These depend only on `array`,
            # so they are built once per shape rather than per x.
            cases = (
                (False, False, array),
                (True, False, array.dot(inv_std)),
                (False, True, centered),
                (True, True, centered.dot(inv_std)),
            )
            for K in range(1, 5):
                for squeeze in [True, False]:
                    x = rng.rand(P, K)
                    if squeeze:
                        # For K == 1 this produces a 1-D vector of length P;
                        # otherwise the shape is unchanged.
                        x = np.squeeze(x)
                    for fit_x in [x, None]:
                        for scale, center, b_array in cases:
                            sarray = ScaledCenterArray(scale=scale, center=center)
                            sarray.fit(da.array(array), x=fit_x)
                            np.testing.assert_array_almost_equal(
                                b_array.dot(x),
                                sarray.dot(x),
                                err_msg="scale={}, center={}, N={}, P={}, K={}, "
                                "squeeze={}, fit_with_x={}".format(
                                    scale, center, N, P, K, squeeze,
                                    fit_x is not None,
                                ),
                            )