def test_sum():
    """Check structural equality and inequality of sums of means."""
    mean_f1 = TensorProductMean(f1)
    mean_f2 = TensorProductMean(f2)

    # Two sums built separately from the same terms compare equal.
    assert mean_f1 + mean_f2 == mean_f1 + mean_f2
    # Swapping the order of the terms keeps the sums equal.
    assert mean_f1 + mean_f2 == mean_f2 + mean_f1

    # Swapping either term for ZeroMean() yields a sum that differs.
    assert mean_f1 + mean_f2 != ZeroMean() + mean_f2
    assert mean_f1 + mean_f2 != mean_f1 + ZeroMean()
def test_function_mean():
    """Check means formed by adding a plain callable to OneMean/ZeroMean.

    Written with plain ``assert`` statements rather than nose-style
    ``yield ok, ...`` generators: pytest 4.0 removed support for yield tests,
    so the generator form never actually executed these checks.
    """
    m1 = 5 * OneMean() + (lambda x: x**2)
    m2 = (lambda x: x**2) + 5 * OneMean()
    m3 = (lambda x: x**2) + ZeroMean()
    m4 = ZeroMean() + (lambda x: x**2)
    x = np.random.randn(10, 1)

    # The callable may appear on either side of ``+``.
    assert np.allclose(m1(x), 5 + x**2)
    assert np.allclose(m2(x), 5 + x**2)
    assert np.allclose(m3(x), x**2)
    assert np.allclose(m4(x), x**2)

    # A TensorProductMean wrapping a named function prints as that name.
    def my_function(x):
        pass

    assert str(TensorProductMean(my_function)) == 'my_function'
def test_basic_arithmetic():
    """Check products, sums, scaling and shifting of TensorProductMean objects.

    Each combined mean is evaluated on a (10, 2) array and on a scalar and
    compared with the same arithmetic applied to the individual outputs.
    Uses plain ``assert`` (with the original labels as messages) instead of
    nose-style ``yield ok, ...``; pytest 4.0 removed support for yield tests.
    """
    # NOTE(review): another ``test_basic_arithmetic`` is defined later in this
    # module; the later definition shadows this one at import time, so pytest
    # only collects the later one. Rename or remove one of them.
    dispatch = Dispatcher()

    # f1 / f2: scalars map to a 1x1 array; arrays map to a column of row-wise
    # sums of squares / cubes. Re-defining the name under @dispatch presumably
    # registers an extra overload with the dispatcher (plum-style) — confirm.
    @dispatch(Number)
    def f1(x):
        return np.array([[x**2]])

    @dispatch(object)
    def f1(x):
        return np.sum(x**2, axis=1)[:, None]

    @dispatch(Number)
    def f2(x):
        return np.array([[x**3]])

    @dispatch(object)
    def f2(x):
        return np.sum(x**3, axis=1)[:, None]

    m1 = TensorProductMean(f1)
    m2 = TensorProductMean(f2)
    m3 = ZeroMean()
    x1 = np.random.randn(10, 2)
    x2 = np.random.randn()

    assert np.allclose((m1 * m2)(x1), m1(x1) * m2(x1)), 'prod'
    assert np.allclose((m1 * m2)(x2), m1(x2) * m2(x2)), 'prod 2'
    assert np.allclose((m1 + m3)(x1), m1(x1) + m3(x1)), 'sum'
    assert np.allclose((m1 + m3)(x2), m1(x2) + m3(x2)), 'sum 2'
    assert np.allclose((5. * m1)(x1), 5. * m1(x1)), 'prod 3'
    assert np.allclose((5. * m1)(x2), 5. * m1(x2)), 'prod 4'
    assert np.allclose((5. + m1)(x1), 5. + m1(x1)), 'sum 3'
    assert np.allclose((5. + m1)(x2), 5. + m1(x2)), 'sum 4'
def test_tensor_product():
    """Check means formed by adding a plain callable to a mean, on either side."""
    with_one_left = 5 * OneMean() + (lambda x: x ** 2)
    with_one_right = (lambda x: x ** 2) + 5 * OneMean()
    with_zero_left = (lambda x: x ** 2) + ZeroMean()
    with_zero_right = ZeroMean() + (lambda x: x ** 2)

    x = B.randn(10, 1)

    # Adding 5 * OneMean() offsets the callable by 5.
    for combined in (with_one_left, with_one_right):
        assert np.allclose(combined(x), 5 + x ** 2)
    # Adding ZeroMean() leaves the callable's values unchanged.
    for combined in (with_zero_left, with_zero_right):
        assert np.allclose(combined(x), x ** 2)

    # The string form of a TensorProductMean is the wrapped function's name.
    def my_function(x):
        pass

    assert str(TensorProductMean(my_function)) == "my_function"
def test_derivative():
    """Check the string representations of differentiated kernels and means.

    Uses plain ``assert`` instead of nose-style ``yield eq, ...``; pytest 4.0
    removed support for yield tests, so the generator form never ran.
    """
    # NOTE(review): another ``test_derivative`` is defined later in this
    # module and shadows this one at import time; rename or remove one.
    assert str(EQ().diff(0)) == 'd(0) EQ()'
    assert str(EQ().diff(0, 1)) == 'd(0, 1) EQ()'
    # Derivatives of constant kernels/means print as '0'.
    assert str(ZeroKernel().diff(0)) == '0'
    assert str(OneKernel().diff(0)) == '0'
    assert str(ZeroMean().diff(0)) == '0'
    assert str(OneMean().diff(0)) == '0'
def test_derivative():
    """Check the string representations of differentiated kernels and means."""
    expectations = [
        (EQ().diff(0), 'd(0) EQ()'),
        (EQ().diff(0, 1), 'd(0, 1) EQ()'),
        # Derivatives of constant kernels/means print as '0'.
        (ZeroKernel().diff(0), '0'),
        (OneKernel().diff(0), '0'),
        (ZeroMean().diff(0), '0'),
        (OneMean().diff(0), '0'),
    ]
    for differentiated, text in expectations:
        assert str(differentiated) == text
def test_basic_arithmetic():
    """Check arithmetic on means against arithmetic on their evaluations.

    Every combination is evaluated on a (10, 2) input and on a scalar input.
    """
    mean_a = TensorProductMean(f1)
    mean_b = TensorProductMean(f2)
    zero = ZeroMean()
    inputs = (B.randn(10, 2), B.randn())

    # Pairs of (combined mean, function computing the expected value).
    cases = [
        (mean_a * mean_b, lambda x: mean_a(x) * mean_b(x)),
        (mean_a + zero, lambda x: mean_a(x) + zero(x)),
        (5.0 * mean_a, lambda x: 5.0 * mean_a(x)),
        (5.0 + mean_a, lambda x: 5.0 + mean_a(x)),
    ]
    for combined, reference in cases:
        for x in inputs:
            approx(combined(x), reference(x))
def test_shifting():
    """Check the string representations of shifted kernels and means.

    Uses plain ``assert`` instead of nose-style ``yield eq, ...``; pytest 4.0
    removed support for yield tests, so the generator form never ran.
    """
    # NOTE(review): another ``test_shifting`` is defined later in this module
    # and shadows this one at import time; rename or remove one.

    # Kernels:
    assert str(ZeroKernel().shift(5)) == '0'
    assert str(EQ().shift(5)) == 'EQ()'
    assert str(Linear().shift(5)) == 'Linear() shift 5'
    assert str((5 * EQ()).shift(5)) == '5 * EQ()'
    assert str((5 * Linear()).shift(5)) == '(5 * Linear()) shift 5'

    # Means: the function must be named ``mean`` because the expected strings
    # embed the wrapped function's name.
    def mean(x):
        return x

    m = TensorProductMean(mean)
    assert str(ZeroMean().shift(5)) == '0'
    assert str(m.shift(5)) == 'mean shift 5'
    # Consecutive shifts are expected to print as one combined shift.
    assert str(m.shift(5).shift(5)) == 'mean shift 10'
    assert str((5 * m).shift(5)) == '(5 * mean) shift 5'
def test_shifting():
    """Check the string representations of shifted kernels and means."""
    # Kernels:
    kernel_cases = [
        (ZeroKernel().shift(5), '0'),
        (EQ().shift(5), 'EQ()'),
        (Linear().shift(5), 'Linear() shift 5'),
        ((5 * EQ()).shift(5), '5 * EQ()'),
        ((5 * Linear()).shift(5), '(5 * Linear()) shift 5'),
    ]
    for shifted, expected in kernel_cases:
        assert str(shifted) == expected

    # Means: the wrapped function must be called ``mean`` since its name
    # appears in the expected strings.
    def mean(x):
        return x

    wrapped = TensorProductMean(mean)
    mean_cases = [
        (ZeroMean().shift(5), '0'),
        (wrapped.shift(5), 'mean shift 5'),
        (wrapped.shift(5).shift(5), 'mean shift 10'),
        ((5 * wrapped).shift(5), '(5 * mean) shift 5'),
    ]
    for shifted, expected in mean_cases:
        assert str(shifted) == expected
def test_ones_zeros():
    """Check that ZeroMean/OneMean evaluations are cached when given a Cache.

    Equal-shaped inputs must return the very same object; differently shaped
    inputs must return distinct objects; scalar inputs must work too.

    Nothing to check for kernels: ones and zeros are represented in a
    structured way.
    """
    c = Cache()

    # Uses ``is`` rather than comparing ``id()`` of temporaries: two objects
    # with non-overlapping lifetimes may share an id, so an uncached result
    # freed before the next call could make the identity check pass
    # spuriously. With ``is`` both operands are alive during the comparison.
    # Plain asserts replace nose-style ``yield eq/neq``, which pytest >= 4.0
    # no longer runs.
    for m in (ZeroMean(), OneMean()):
        assert m(np.random.randn(10, 10), c) is m(np.random.randn(10, 10), c)
        assert m(np.random.randn(10, 10), c) is not \
            m(np.random.randn(5, 10), c)
        assert m(1, c) is m(1, c)