Ejemplo n.º 1
0
def test_periodic():
    """Yield checks for a periodic kernel built from a stretched EQ.

    Covers the kernel's properties, equality semantics, the shared
    ``kernel_generator`` battery, how a further stretch and scaling show up
    in the properties, and construction from a list of periods.
    """
    periodic_kernel = EQ().stretch(2).periodic(3)

    # Stretch by 2, then period 3: stationary with unit variance.
    for actual, expected in ((periodic_kernel.stationary, True),
                             (periodic_kernel.length_scale, 2),
                             (periodic_kernel.period, 3),
                             (periodic_kernel.var, 1)):
        yield eq, actual, expected

    # Equality depends on both the period and the underlying kernel.
    yield eq, EQ().periodic(2), EQ().periodic(2)
    for left, right in ((EQ().periodic(2), EQ().periodic(3)),
                        (Matern12().periodic(2), EQ().periodic(2))):
        yield neq, left, right

    # Standard tests:
    for case in kernel_generator(periodic_kernel):
        yield case

    # Stretching by 5 multiplies length scale and period by 5; the factor 5
    # shows up as the variance.
    rescaled = 5 * periodic_kernel.stretch(5)
    for actual, expected in ((rescaled.stationary, True),
                             (rescaled.length_scale, 10),
                             (rescaled.period, 15),
                             (rescaled.var, 5)):
        yield eq, actual, expected

    # Check passing in a list of two periods, evaluated on two-column input.
    yield EQ().periodic([1, 2]), np.random.randn(10, 2)
Ejemplo n.º 2
0
def test_selection():
    """Check ``select``: stationarity under shared and separate selections,
    equality, standard tests, and agreement with slicing the input."""
    selected = (2 * EQ().stretch(5)).select(0)

    # Verify that the kernel has the right properties.
    assert selected.stationary

    # Equality depends on both the selected indices and the kernel.
    assert EQ().select(0) == EQ().select(0)
    assert EQ().select(0) != EQ().select(1)
    assert EQ().select(0) != Matern12().select(0)

    # Standard tests:
    standard_kernel_tests(selected)

    # Selecting the same dimensions for both inputs keeps the kernel
    # stationary, for scalar and vector stretches and vector periods.
    same_selection = [
        (2 * EQ().stretch(5)).select([2, 3]),
        (2 * EQ().stretch(np.array([1, 2, 3]))).select([0, 2]),
        (2 * EQ().periodic(np.array([1, 2, 3]))).select([1, 2]),
    ]
    for kernel in same_selection:
        assert kernel.stationary

    # Selecting different dimensions for the two inputs does not.
    split_selection = [
        (2 * EQ().stretch(np.array([1, 2, 3]))).select([0, 2], [1, 2]),
        (2 * EQ().periodic(np.array([1, 2, 3]))).select([0, 2], [1, 2]),
    ]
    for kernel in split_selection:
        assert not kernel.stationary

    # Selecting columns must match evaluating on the sliced input.
    x = B.randn(10, 3)
    approx(EQ().select([1, 2])(x), EQ()(x[:, [1, 2]]))
Ejemplo n.º 3
0
def test_input_transform():
    """Check ``transform``: a shifted linear kernel's properties, equality
    of transformed kernels, standard tests, and that transforming the inputs
    matches evaluating the plain kernel on transformed data."""
    shifted = Linear().transform(lambda x: x - 5)

    # The shifted kernel is non-stationary and exposes no scalar length
    # scale, variance or period.
    assert not shifted.stationary
    for attribute in ('length_scale', 'var', 'period'):
        with pytest.raises(RuntimeError):
            getattr(shifted, attribute)

    def identity(x):
        return x

    def square(x):
        return x ** 2

    # Equality depends on both the transform and the underlying kernel.
    assert EQ().transform(identity) == EQ().transform(identity)
    assert EQ().transform(identity) != EQ().transform(square)
    assert EQ().transform(identity) != Matern12().transform(identity)

    # Standard tests:
    standard_kernel_tests(shifted)

    # Test computation of the kernel: a single transform is applied to both
    # inputs; two transforms are applied to the first and second input.
    linear = Linear()
    x1, x2 = np.random.randn(10, 2), np.random.randn(10, 2)
    squared = linear.transform(lambda x: x ** 2)
    mixed = linear.transform(lambda x: x ** 2, lambda x: x - 5)

    # NOTE(review): the return value of ``allclose`` is discarded, so these
    # lines only check anything if ``allclose`` raises on mismatch — confirm.
    allclose(linear(x1 ** 2, x2 ** 2), squared(x1, x2))
    allclose(linear(x1 ** 2, x2 - 5), mixed(x1, x2))
Ejemplo n.º 4
0
def test_stretched():
    """Yield checks for stretched EQ kernels.

    Covers a single stretch (properties, equality, ``kernel_generator``
    battery), a different stretch per input whose length scale and period
    must raise ``RuntimeError``, and construction from a list of stretches.
    """
    k = EQ().stretch(2)

    yield eq, k.stationary, True
    yield eq, k.length_scale, 2
    yield eq, k.period, np.inf
    yield eq, k.var, 1

    # Test equality.
    yield eq, EQ().stretch(2), EQ().stretch(2)
    yield neq, EQ().stretch(2), EQ().stretch(3)
    yield neq, EQ().stretch(2), Matern12().stretch(2)

    # Standard tests:
    for x in kernel_generator(k):
        yield x

    # Different stretches for the two inputs break stationarity.
    k = EQ().stretch(1, 2)

    yield eq, k.stationary, False
    # Bind ``k`` as a default argument: ``k`` is rebound below, and a plain
    # closure would inspect whatever ``k`` refers to when the yielded check
    # is finally run rather than this non-stationary kernel.
    yield raises, RuntimeError, lambda k=k: k.length_scale
    yield raises, RuntimeError, lambda k=k: k.period
    yield eq, k.var, 1

    # Check passing in a list.
    k = EQ().stretch([1, 2])
    yield k, np.random.randn(10, 2)
Ejemplo n.º 5
0
def test_periodic():
    """Check ``periodic``: string form, stationarity, equality, standard
    tests, stretching and scaling, vector periods, and the zero kernel."""
    base = EQ().stretch(2).periodic(3)

    # Verify that the kernel has the right properties.
    assert str(base) == "(EQ() > 2) per 3"
    assert base.stationary

    # Equal periods of equal kernels compare equal; anything else differs.
    assert EQ().periodic(2) == EQ().periodic(2)
    for left, right in ((EQ().periodic(2), EQ().periodic(3)),
                        (Matern12().periodic(2), EQ().periodic(2))):
        assert left != right

    # Standard tests:
    standard_kernel_tests(base)

    # Stretching and scaling keep the periodic kernel stationary.
    assert (5 * base.stretch(5)).stationary

    # A vector of two periods must accept two-column inputs.
    EQ().periodic(np.array([1, 2]))(B.randn(10, 2))

    # Making the zero kernel periodic returns the very same object.
    zero = ZeroKernel()
    assert zero.periodic(3) is zero
Ejemplo n.º 6
0
def test_exp():
    """Check the properties, string form, equality and standard behaviour
    of ``Matern12``."""
    kernel = Matern12()

    # Stationary, with unit variance and length scale and no period.
    assert kernel.stationary
    for name, expected in (('var', 1),
                           ('length_scale', 1),
                           ('period', np.inf)):
        assert getattr(kernel, name) == expected
    assert str(kernel) == 'Exp()'

    # Test equality.
    assert Matern12() == Matern12()
    assert Matern12() != Linear()

    # Standard tests:
    standard_kernel_tests(kernel)
Ejemplo n.º 7
0
def test_periodic():
    """Check a stretched, periodic EQ: properties, equality, standard tests,
    the effect of a further stretch and scaling, and vector periods."""
    kernel = EQ().stretch(2).periodic(3)

    assert kernel.stationary
    for name, expected in (('length_scale', 2), ('period', 3), ('var', 1)):
        assert getattr(kernel, name) == expected

    # Test equality.
    assert EQ().periodic(2) == EQ().periodic(2)
    assert EQ().periodic(2) != EQ().periodic(3)
    assert Matern12().periodic(2) != EQ().periodic(2)

    # Standard tests:
    standard_kernel_tests(kernel)

    # Stretching by 5 multiplies length scale and period by 5; the factor 5
    # shows up as the variance.
    rescaled = 5 * kernel.stretch(5)
    assert rescaled.stationary
    for name, expected in (('length_scale', 10), ('period', 15), ('var', 5)):
        assert getattr(rescaled, name) == expected

    # Check passing in a vector of periods.
    EQ().periodic(np.array([1, 2]))(np.random.randn(10, 2))
Ejemplo n.º 8
0
def test_stretched():
    """Check stretched EQ kernels: properties, equality, standard tests,
    a different stretch per input, and a vector of stretches."""
    stretched = EQ().stretch(2)

    assert stretched.stationary
    for name, expected in (('length_scale', 2),
                           ('period', np.inf),
                           ('var', 1)):
        assert getattr(stretched, name) == expected

    # Test equality.
    assert EQ().stretch(2) == EQ().stretch(2)
    assert EQ().stretch(2) != EQ().stretch(3)
    assert EQ().stretch(2) != Matern12().stretch(2)

    # Standard tests:
    standard_kernel_tests(stretched)

    # A different stretch per input: non-stationary, so length scale and
    # period are undefined, while the variance is unchanged.
    asymmetric = EQ().stretch(1, 2)
    assert not asymmetric.stationary
    for attribute in ('length_scale', 'period'):
        with pytest.raises(RuntimeError):
            getattr(asymmetric, attribute)
    assert asymmetric.var == 1

    # Check passing in a vector of stretches.
    EQ().stretch(np.array([1, 2]))(np.random.randn(10, 2))
Ejemplo n.º 9
0
def test_basic_arithmetic():
    """Yield checks that kernel arithmetic agrees with arithmetic on the
    evaluated matrices, for both matrix and scalar inputs."""
    eq_k, rq_k = EQ(), RQ(1e-1)
    m12, m32, m52 = Matern12(), Matern32(), Matern52()
    delta, linear = Delta(), Linear()
    matrices = np.random.randn(10, 2), np.random.randn(20, 2)
    scalars = np.random.randn(), np.random.randn()

    # A single argument is shorthand for passing the same input twice.
    yield ok, allclose(delta(matrices[0]),
                       delta(matrices[0], matrices[0])), 'dispatch'

    # (composite kernel, equivalent computation on the parts,
    #  label for the matrix inputs, label for the scalar inputs)
    cases = [
        (eq_k * rq_k, lambda a, b: eq_k(a, b) * rq_k(a, b),
         'prod', 'prod 2'),
        (m12 + m32, lambda a, b: m12(a, b) + m32(a, b),
         'sum', 'sum 2'),
        (5. * m52, lambda a, b: 5. * m52(a, b),
         'prod 3', 'prod 4'),
        (5. + linear, lambda a, b: 5. + linear(a, b),
         'sum 3', 'sum 4'),
        # Stretching divides both inputs by the stretch factor.
        (eq_k.stretch(2.), lambda a, b: eq_k(a / 2., b / 2.),
         'stretch', 'stretch 2'),
        # With period 1, shifting the second input by 5 changes nothing.
        (eq_k.periodic(1.), lambda a, b: eq_k.periodic(1.)(a, b + 5.),
         'periodic', 'periodic 2'),
    ]
    for composite, reference, matrix_label, scalar_label in cases:
        for (a, b), label in ((matrices, matrix_label),
                              (scalars, scalar_label)):
            yield ok, allclose(composite(a, b), reference(a, b)), label
Ejemplo n.º 10
0
def test_input_transform():
    """Yield checks for input-transformed kernels.

    Covers the properties of a shifted linear kernel (non-stationary;
    length scale, variance and period raise ``RuntimeError``), equality of
    transformed kernels, the ``kernel_generator`` battery, and that
    transforming inputs matches evaluating the plain kernel on transformed
    data. Transforms here take two arguments ``(x, c)``.
    """
    k = Linear().transform(lambda x, c: x - 5)

    yield eq, k.stationary, False
    # Bind ``k`` as a default argument: ``k`` is rebound to ``Linear()``
    # below, and a plain closure would inspect that kernel instead if the
    # yielded check only runs after the generator has moved on.
    yield raises, RuntimeError, lambda k=k: k.length_scale
    yield raises, RuntimeError, lambda k=k: k.var
    yield raises, RuntimeError, lambda k=k: k.period

    def f1(x):
        return x

    def f2(x):
        return x**2

    # Test equality.
    yield eq, EQ().transform(f1), EQ().transform(f1)
    yield neq, EQ().transform(f1), EQ().transform(f2)
    yield neq, EQ().transform(f1), Matern12().transform(f1)

    # Standard tests:
    for x in kernel_generator(k):
        yield x

    # Test computation of the kernel: one transform applies to both inputs,
    # two transforms apply to the first and second input respectively.
    k = Linear()
    x1, x2 = np.random.randn(10, 2), np.random.randn(10, 2)

    k2 = k.transform(lambda x, c: x**2)
    k3 = k.transform(lambda x, c: x**2, lambda x, c: x - 5)

    yield assert_allclose, k(x1**2, x2**2), k2(x1, x2)
    yield assert_allclose, k(x1**2, x2 - 5), k3(x1, x2)
Ejemplo n.º 11
0
def test_input_transform():
    """Check ``transform``: stationarity of a shifted linear kernel,
    equality of transformed kernels, standard tests, and that transforming
    the inputs matches evaluating the plain kernel on transformed data."""
    shifted = Linear().transform(lambda x: x - 5)

    # Shifting the inputs of a linear kernel breaks stationarity.
    assert not shifted.stationary

    def identity(x):
        return x

    def square(x):
        return x**2

    # Equality depends on both the transform and the underlying kernel.
    assert EQ().transform(identity) == EQ().transform(identity)
    assert EQ().transform(identity) != EQ().transform(square)
    assert EQ().transform(identity) != Matern12().transform(identity)

    # Standard tests:
    standard_kernel_tests(shifted)

    # One transform applies to both inputs; two transforms apply to the
    # first and second input respectively.
    linear = Linear()
    x1, x2 = B.randn(10, 2), B.randn(10, 2)
    squared = linear.transform(lambda x: x**2)
    mixed = linear.transform(lambda x: x**2, lambda x: x - 5)

    for transformed, (left, right) in ((squared, (x1**2, x2**2)),
                                       (mixed, (x1**2, x2 - 5))):
        approx(linear(left, right), transformed(x1, x2))
Ejemplo n.º 12
0
def test_derivative():
    """Check ``diff``: properties, equality, standard tests for each way of
    specifying the differentiated input, and the error raised when no
    derivative is specified."""
    derivative = EQ().diff(0)

    # Differentiating yields a non-stationary kernel whose length scale,
    # variance and period are undefined.
    assert not derivative.stationary
    for attribute in ('length_scale', 'var', 'period'):
        with pytest.raises(RuntimeError):
            getattr(derivative, attribute)

    # Equality depends on the differentiated index and the base kernel.
    assert EQ().diff(0) == EQ().diff(0)
    assert EQ().diff(0) != EQ().diff(1)
    assert Matern12().diff(0) != EQ().diff(0)

    # Standard tests, run with ``tf.float64``.
    variants = (EQ().diff(0), EQ().diff(None, 0), EQ().diff(0, None))
    for variant in variants:
        standard_kernel_tests(variant, dtype=tf.float64)

    # Both full and elementwise evaluation require a derivative.
    with pytest.raises(RuntimeError):
        EQ().diff(None, None)(np.array([1.0]))
    with pytest.raises(RuntimeError):
        EQ().diff(None, None).elwise(np.array([1.0]))
Ejemplo n.º 13
0
def test_selection():
    """Check ``select``: properties under shared and separate selections,
    equality, standard tests, and agreement with column slicing."""
    def assert_scaled_stretched_properties(kernel):
        # A stretch of 5 and scale of 2 survive a shared selection.
        assert kernel.stationary
        assert kernel.length_scale == 5
        assert kernel.period == np.inf
        assert kernel.var == 2

    single = (2 * EQ().stretch(5)).select(0)
    assert_scaled_stretched_properties(single)

    # Test equality.
    assert EQ().select(0) == EQ().select(0)
    assert EQ().select(0) != EQ().select(1)
    assert EQ().select(0) != Matern12().select(0)

    # Standard tests:
    standard_kernel_tests(single)

    assert_scaled_stretched_properties((2 * EQ().stretch(5)).select([2, 3]))

    # NOTE(review): the return value of ``allclose`` is discarded below, so
    # those lines only check anything if ``allclose`` raises on mismatch.

    # Vector stretch: the selected dimensions' length scales are kept.
    stretched = (2 * EQ().stretch(np.array([1, 2, 3]))).select([0, 2])
    assert stretched.stationary
    allclose(stretched.length_scale, [1, 3])
    allclose(stretched.period, [np.inf, np.inf])
    assert stretched.var == 2

    # Vector period: the selected dimensions' periods are kept.
    periodic = (2 * EQ().periodic(np.array([1, 2, 3]))).select([1, 2])
    assert periodic.stationary
    allclose(periodic.length_scale, [1, 1])
    allclose(periodic.period, [2, 3])
    assert periodic.var == 2

    # Different selections for the two inputs break stationarity.
    stretched_split = (2 * EQ().stretch(np.array([1, 2, 3]))).select(
        [0, 2], [1, 2])
    assert not stretched_split.stationary
    for attribute in ('length_scale', 'period'):
        with pytest.raises(RuntimeError):
            getattr(stretched_split, attribute)
    assert stretched_split.var == 2

    periodic_split = (2 * EQ().periodic(np.array([1, 2, 3]))).select(
        [0, 2], [1, 2])
    assert not periodic_split.stationary
    assert periodic_split.length_scale == 1
    with pytest.raises(RuntimeError):
        periodic_split.period
    assert periodic_split.var == 2

    # Test that computation is valid.
    x = np.random.randn(10, 3)
    allclose(EQ().select([1, 2])(x), EQ()(x[:, [1, 2]]))
Ejemplo n.º 14
0
def test_exp():
    """Yield property, string, equality and standard checks for
    ``Matern12``."""
    kernel = Matern12()

    # Verify that the kernel has the right properties.
    expectations = [
        (kernel.stationary, True),
        (kernel.var, 1),
        (kernel.length_scale, 1),
        (kernel.period, np.inf),
        (str(kernel), 'Exp()'),
    ]
    for actual, expected in expectations:
        yield eq, actual, expected

    # Test equality.
    yield eq, Matern12(), Matern12()
    yield neq, Matern12(), Linear()

    # Standard tests:
    for case in kernel_generator(kernel):
        yield case
Ejemplo n.º 15
0
def test_selection():
    """Yield checks for input selection with ``select``.

    Covers properties under a shared selection and under different
    selections for the two inputs, equality, the ``kernel_generator``
    battery, and agreement with evaluating on sliced columns.
    """
    k = (2 * EQ().stretch(5)).select(0)

    yield eq, k.stationary, True
    yield eq, k.length_scale, 5
    yield eq, k.period, np.inf
    yield eq, k.var, 2

    # Test equality.
    yield eq, EQ().select(0), EQ().select(0)
    yield neq, EQ().select(0), EQ().select(1)
    yield neq, EQ().select(0), Matern12().select(0)

    # Standard tests:
    for x in kernel_generator(k):
        yield x

    k = (2 * EQ().stretch(5)).select([2, 3])

    yield eq, k.stationary, True
    yield eq, k.length_scale, 5
    yield eq, k.period, np.inf
    yield eq, k.var, 2

    k = (2 * EQ().stretch([1, 2, 3])).select([0, 2])

    yield eq, k.stationary, True
    yield assert_allclose, k.length_scale, [1, 3]
    yield assert_allclose, k.period, [np.inf, np.inf]
    yield eq, k.var, 2

    k = (2 * EQ().periodic([1, 2, 3])).select([1, 2])

    yield eq, k.stationary, True
    yield assert_allclose, k.length_scale, [1, 1]
    yield assert_allclose, k.period, [2, 3]
    yield eq, k.var, 2

    k = (2 * EQ().stretch([1, 2, 3])).select([0, 2], [1, 2])

    yield eq, k.stationary, False
    # Bind ``k`` as a default argument: ``k`` is rebound just below to a
    # kernel whose length scale does not raise, and a plain closure would
    # inspect that kernel if the yielded check only runs later.
    yield raises, RuntimeError, lambda k=k: k.length_scale
    yield raises, RuntimeError, lambda k=k: k.period
    yield eq, k.var, 2

    k = (2 * EQ().periodic([1, 2, 3])).select([0, 2], [1, 2])

    yield eq, k.stationary, False
    yield eq, k.length_scale, 1
    yield raises, RuntimeError, lambda k=k: k.period
    yield eq, k.var, 2

    # Test that computation is valid.
    k1 = EQ().select([1, 2])
    k2 = EQ()
    x = np.random.randn(10, 3)
    yield assert_allclose, k1(x), k2(x[:, [1, 2]])
Ejemplo n.º 16
0
def test_scaled():
    """Check stationarity, equality and standard behaviour of ``2 * EQ()``."""
    scaled = 2 * EQ()

    # Scaling a stationary kernel keeps it stationary.
    assert scaled.stationary

    # Equality depends on both the scale factor and the kernel.
    assert 2 * EQ() == 2 * EQ()
    for other in (3 * EQ(), 2 * Matern12()):
        assert 2 * EQ() != other

    # Standard tests:
    standard_kernel_tests(scaled)
Ejemplo n.º 17
0
def test_scaled():
    """Verify that scaling ``EQ`` by 2 gives a stationary kernel with the
    expected equality semantics that passes the standard kernel tests."""
    def doubled_eq():
        return 2 * EQ()

    kernel = doubled_eq()

    # Verify that the kernel has the right properties.
    assert kernel.stationary

    # Equality depends on both the scale factor and the kernel.
    assert doubled_eq() == doubled_eq()
    assert doubled_eq() != 3 * EQ()
    assert doubled_eq() != 2 * Matern12()

    # Standard tests:
    standard_kernel_tests(kernel)
Ejemplo n.º 18
0
def test_scaled():
    """Check the properties, equality and standard behaviour of
    ``2 * EQ()``."""
    scaled = 2 * EQ()

    # The scale factor becomes the variance; length scale and period are
    # those of the plain EQ kernel.
    assert scaled.stationary
    for name, expected in (('length_scale', 1),
                           ('period', np.inf),
                           ('var', 2)):
        assert getattr(scaled, name) == expected

    # Test equality.
    assert 2 * EQ() == 2 * EQ()
    assert 2 * EQ() != 3 * EQ()
    assert 2 * EQ() != 2 * Matern12()

    # Standard tests:
    standard_kernel_tests(scaled)
Ejemplo n.º 19
0
def test_scaled():
    """Yield property, equality and standard checks for ``2 * EQ()``."""
    scaled = 2 * EQ()

    # The scale factor becomes the variance.
    for actual, expected in ((scaled.stationary, True),
                             (scaled.length_scale, 1),
                             (scaled.period, np.inf),
                             (scaled.var, 2)):
        yield eq, actual, expected

    # Test equality.
    yield eq, 2 * EQ(), 2 * EQ()
    for other in (3 * EQ(), 2 * Matern12()):
        yield neq, 2 * EQ(), other

    # Standard tests:
    for case in kernel_generator(scaled):
        yield case
Ejemplo n.º 20
0
def test_periodic():
    """Check a stretched, periodic EQ: stationarity, equality, standard
    tests, stretching and scaling, and a vector of periods."""
    base = EQ().stretch(2).periodic(3)
    assert base.stationary

    # Equality depends on both the period and the underlying kernel.
    assert EQ().periodic(2) == EQ().periodic(2)
    for left, right in ((EQ().periodic(2), EQ().periodic(3)),
                        (Matern12().periodic(2), EQ().periodic(2))):
        assert left != right

    # Standard tests:
    standard_kernel_tests(base)

    # Stretching and scaling keep the kernel stationary.
    rescaled = 5 * base.stretch(5)
    assert rescaled.stationary

    # A vector of two periods must accept two-column inputs.
    vector_periodic = EQ().periodic(np.array([1, 2]))
    vector_periodic(np.random.randn(10, 2))
Ejemplo n.º 21
0
def test_stretched():
    """Check stretched EQ kernels: stationarity, equality, standard tests,
    a different stretch per input, and a vector of stretches."""
    stretched = EQ().stretch(2)
    assert stretched.stationary

    # Equality depends on both the stretch and the underlying kernel.
    assert EQ().stretch(2) == EQ().stretch(2)
    for other in (EQ().stretch(3), Matern12().stretch(2)):
        assert EQ().stretch(2) != other

    # Standard tests:
    standard_kernel_tests(stretched)

    # A different stretch per input breaks stationarity.
    asymmetric = EQ().stretch(1, 2)
    assert not asymmetric.stationary

    # A vector of two stretches must accept two-column inputs.
    vector_stretched = EQ().stretch(np.array([1, 2]))
    vector_stretched(np.random.randn(10, 2))
Ejemplo n.º 22
0
def test_stretched():
    """Verify stretched EQ kernels: stationarity, equality, standard tests,
    a different stretch per input, and a vector of stretches."""
    def stretched_eq(*stretches):
        return EQ().stretch(*stretches)

    kernel = stretched_eq(2)

    # Verify that the kernel has the right properties.
    assert kernel.stationary

    # Test equality.
    assert stretched_eq(2) == stretched_eq(2)
    assert stretched_eq(2) != stretched_eq(3)
    assert stretched_eq(2) != Matern12().stretch(2)

    # Standard tests:
    standard_kernel_tests(kernel)

    # A different stretch per input makes the kernel non-stationary.
    assert not stretched_eq(1, 2).stationary

    # Check passing in a vector of stretches.
    stretched_eq(np.array([1, 2]))(B.randn(10, 2))
Ejemplo n.º 23
0
def test_selection():
    """Check ``select``: stationarity under shared and separate selections,
    equality, standard tests, and agreement with column slicing."""
    picked = (2 * EQ().stretch(5)).select(0)
    assert picked.stationary

    # Equality depends on both the selected indices and the kernel.
    assert EQ().select(0) == EQ().select(0)
    for different in (EQ().select(1), Matern12().select(0)):
        assert EQ().select(0) != different

    # Standard tests:
    standard_kernel_tests(picked)

    # (kernel, whether it should be stationary): selecting the same
    # dimensions for both inputs keeps stationarity, different ones do not.
    cases = [
        ((2 * EQ().stretch(5)).select([2, 3]), True),
        ((2 * EQ().stretch(np.array([1, 2, 3]))).select([0, 2]), True),
        ((2 * EQ().periodic(np.array([1, 2, 3]))).select([1, 2]), True),
        ((2 * EQ().stretch(np.array([1, 2, 3]))).select([0, 2], [1, 2]),
         False),
        ((2 * EQ().periodic(np.array([1, 2, 3]))).select([0, 2], [1, 2]),
         False),
    ]
    for kernel, stationary in cases:
        assert bool(kernel.stationary) is stationary

    # Test that computation is valid.
    # NOTE(review): the return value of ``allclose`` is discarded; this only
    # checks anything if ``allclose`` raises on mismatch — confirm.
    x = np.random.randn(10, 3)
    allclose(EQ().select([1, 2])(x), EQ()(x[:, [1, 2]]))
Ejemplo n.º 24
0
def test_basic_arithmetic():
    """Check that kernel arithmetic agrees with arithmetic on the evaluated
    matrices, for both matrix and scalar inputs."""
    eq_kernel, rq_kernel = EQ(), RQ(1e-1)
    m12, m32, m52 = Matern12(), Matern32(), Matern52()
    delta, linear = Delta(), Linear()
    matrix_pair = B.randn(10, 2), B.randn(20, 2)
    scalar_pair = B.randn(), B.randn()

    # A single argument is shorthand for passing the same input twice.
    approx(delta(matrix_pair[0]), delta(matrix_pair[0], matrix_pair[0]))

    # (composite kernel, equivalent computation on the parts)
    cases = (
        (eq_kernel * rq_kernel,
         lambda a, b: eq_kernel(a, b) * rq_kernel(a, b)),
        (m12 + m32, lambda a, b: m12(a, b) + m32(a, b)),
        (5.0 * m52, lambda a, b: 5.0 * m52(a, b)),
        (5.0 + linear, lambda a, b: 5.0 + linear(a, b)),
        # Stretching divides both inputs by the stretch factor.
        (eq_kernel.stretch(2.0),
         lambda a, b: eq_kernel(a / 2.0, b / 2.0)),
        # With period 1, shifting the second input by 5 changes nothing.
        (eq_kernel.periodic(1.0),
         lambda a, b: eq_kernel.periodic(1.0)(a, b + 5.0)),
    )
    for composite, reference in cases:
        for a, b in (matrix_pair, scalar_pair):
            approx(composite(a, b), reference(a, b))
Ejemplo n.º 25
0
def test_derivative():
    """Yield checks for derivative kernels built with ``diff``.

    First checks the properties and equality of ``EQ().diff(0)`` and that a
    derivative must be specified. It then compares ``diff`` of ``EQ`` and
    ``Linear`` against closed-form references evaluated with the TensorFlow
    backend in a ``B.Session``.
    """
    # First, check properties.
    k = EQ().diff(0)

    yield eq, k.stationary, False
    # Bind ``k`` as a default argument: ``k`` is rebound to ``EQ()`` and
    # ``Linear()`` below, and a plain closure would inspect whatever ``k``
    # refers to when the yielded check is finally run.
    yield raises, RuntimeError, lambda k=k: k.length_scale
    yield raises, RuntimeError, lambda k=k: k.var
    yield raises, RuntimeError, lambda k=k: k.period

    # Second, test equality.
    yield eq, EQ().diff(0), EQ().diff(0)
    yield neq, EQ().diff(0), EQ().diff(1)
    yield neq, Matern12().diff(0), EQ().diff(0)

    # A derivative must be specified.
    yield raises, RuntimeError, lambda: EQ().diff(None, None)(1)

    # Third, check computation. Switching the backend and opening a session
    # are global side effects. Undo them even if a reference computation
    # raises or the generator is closed before it is exhausted, so later
    # tests do not run on the TensorFlow backend.
    B.backend_to_tf()
    try:
        s = B.Session()
        try:
            # Test derivative of kernel EQ.
            k = EQ()
            x1 = B.array(np.random.randn(10, 1))
            x2 = B.array(np.random.randn(5, 1))

            # Test derivative with respect to first input.
            ref = s.run(-dense(k(x1, x2)) * (x1 - B.transpose(x2)))
            yield assert_allclose, s.run(dense(k.diff(0, None)(x1, x2))), ref
            ref = s.run(-dense(k(x1)) * (x1 - B.transpose(x1)))
            yield assert_allclose, s.run(dense(k.diff(0, None)(x1))), ref

            # Test derivative with respect to second input.
            ref = s.run(-dense(k(x1, x2)) * (B.transpose(x2) - x1))
            yield assert_allclose, s.run(dense(k.diff(None, 0)(x1, x2))), ref
            ref = s.run(-dense(k(x1)) * (B.transpose(x1) - x1))
            yield assert_allclose, s.run(dense(k.diff(None, 0)(x1))), ref

            # Test derivative with respect to both inputs; ``diff(0)`` is
            # expected to match ``diff(0, 0)``.
            ref = s.run(dense(k(x1, x2)) * (1 - (x1 - B.transpose(x2))**2))
            yield assert_allclose, s.run(dense(k.diff(0, 0)(x1, x2))), ref
            yield assert_allclose, s.run(dense(k.diff(0)(x1, x2))), ref
            ref = s.run(dense(k(x1)) * (1 - (x1 - B.transpose(x1))**2))
            yield assert_allclose, s.run(dense(k.diff(0, 0)(x1))), ref
            yield assert_allclose, s.run(dense(k.diff(0)(x1))), ref

            # Test derivative of kernel Linear.
            k = Linear()
            x1 = B.array(np.random.randn(10, 1))
            x2 = B.array(np.random.randn(5, 1))

            # Test derivative with respect to first input.
            ref = s.run(B.ones((10, 5), dtype=np.float64) * B.transpose(x2))
            yield assert_allclose, s.run(dense(k.diff(0, None)(x1, x2))), ref
            ref = s.run(B.ones((10, 10), dtype=np.float64) * B.transpose(x1))
            yield assert_allclose, s.run(dense(k.diff(0, None)(x1))), ref

            # Test derivative with respect to second input.
            ref = s.run(B.ones((10, 5), dtype=np.float64) * x1)
            yield assert_allclose, s.run(dense(k.diff(None, 0)(x1, x2))), ref
            ref = s.run(B.ones((10, 10), dtype=np.float64) * x1)
            yield assert_allclose, s.run(dense(k.diff(None, 0)(x1))), ref

            # Test derivative with respect to both inputs.
            ref = s.run(B.ones((10, 5), dtype=np.float64))
            yield assert_allclose, s.run(dense(k.diff(0, 0)(x1, x2))), ref
            yield assert_allclose, s.run(dense(k.diff(0)(x1, x2))), ref
            ref = s.run(B.ones((10, 10), dtype=np.float64))
            yield assert_allclose, s.run(dense(k.diff(0, 0)(x1))), ref
            yield assert_allclose, s.run(dense(k.diff(0)(x1))), ref
        finally:
            s.close()
    finally:
        B.backend_to_np()