def test_dtype_kron_promotion():
    """Check the dtype reported for Kronecker products of int/float factors.

    Two integer factors must give ``np.int64``; a float factor combined with
    an integer factor, in either position, must give ``np.float64``.
    """
    cases = [
        (int, int, np.int64),
        (float, int, np.float64),
        (int, float, np.float64),
    ]
    for left_dtype, right_dtype, expected_dtype in cases:
        product = Kronecker(B.ones(left_dtype, 5, 5), B.ones(right_dtype, 5, 5))
        assert B.dtype(product) == expected_dtype
def test_kronecker_formatting():
    """Check ``str`` and ``repr`` of a Kronecker product of two diagonal matrices."""
    diag_left = Diagonal(B.ones(2))
    diag_right = Diagonal(B.ones(3))

    expected_str = "<Kronecker product: shape=6x6, dtype=float64>"
    # NOTE(review): the leading whitespace inside these expected lines may
    # have been collapsed in this copy of the file; confirm against the
    # actual repr output.
    expected_repr = (
        "<Kronecker product: shape=6x6, dtype=float64\n"
        " left=<diagonal matrix: shape=2x2, dtype=float64\n"
        " diag=[1. 1.]>\n"
        " right=<diagonal matrix: shape=3x3, dtype=float64\n"
        " diag=[1. 1. 1.]>>"
    )

    assert str(Kronecker(diag_left, diag_right)) == expected_str
    assert repr(Kronecker(diag_left, diag_right)) == expected_repr
def construct_model(vs):
    """Build an OILMM whose mixing matrix is a Kronecker product.

    The mixing matrix has two factors:

    * a simulator factor: an orthogonal basis from ``vs.orth`` and a bounded
      vector of scales from ``vs.bnd``;
    * a spatial factor: the leading ``m_r`` singular vectors, and the square
      roots of the leading ``m_r`` singular values, of a ``Mat52`` covariance
      evaluated at ``loc``.

    Relies on module-level names defined elsewhere: ``args``, ``m``, ``p_s``,
    ``m_s``, ``m_r``, ``loc``, ``u_s_init``, ``s_sqrt_s_init`` and
    ``scales_init``.

    Args:
        vs: Variable container used to create all model parameters.

    Returns:
        An ``OILMM`` built from the kernels, mixing matrix and noises.
    """
    # Latent-process kernels.
    if args.separable:
        # A single kernel object, repeated for all ``m`` latent processes.
        shared_kernel = Mat52().stretch(vs.bnd(6 * 30, lower=60, name="k_scale"))
        kernels = [shared_kernel] * m
    else:
        # An independently parametrised kernel per latent process.
        kernels = []
        for idx in range(m):
            kernels.append(
                Mat52().stretch(vs.bnd(6 * 30, lower=60, name=f"{idx}/k_scale"))
            )

    noise = vs.bnd(1e-2, name="noise")
    latent_noises = vs.bnd(1e-2 * B.ones(m), name="latent_noises")

    # Simulator factor of the mixing matrix.
    basis_sims = vs.orth(init=u_s_init, shape=(p_s, m_s), name="sims/u")
    scales_sims = vs.bnd(init=s_sqrt_s_init, shape=(m_s,), name="sims/s_sqrt")
    u_sims = Dense(basis_sims)
    s_sqrt_sims = Diagonal(scales_sims)

    # Spatial factor of the mixing matrix: truncated SVD of a covariance.
    length_scales = vs.bnd(init=scales_init, name="space/scales")
    space_kernel = Mat52().stretch(length_scales)
    vecs, vals, _ = B.svd(B.dense(space_kernel(loc)))
    u_space = Dense(vecs[:, :m_r])
    s_sqrt_space = Diagonal(B.sqrt(vals[:m_r]))

    # Combine both factors.
    s_sqrt_mix = Kronecker(s_sqrt_sims, s_sqrt_space)
    u_mix = Kronecker(u_sims, u_space)

    return OILMM(kernels, u_mix, s_sqrt_mix, noise, latent_noises)
def pinv(a: Kronecker):
    """Compute the pseudo-inverse of a Kronecker product factor by factor.

    Uses the identity ``pinv(A ⊗ B) = pinv(A) ⊗ pinv(B)``, so only the two
    factors are pseudo-inverted and the full product is never formed.
    """
    left_inv = B.pinv(a.left)
    right_inv = B.pinv(a.right)
    return Kronecker(left_inv, right_inv)
def test_kronecker_attributes():
    """Check that a Kronecker product keeps references to its own factors."""
    factors = (Diagonal(B.ones(2)), Diagonal(B.ones(3)))
    product = Kronecker(*factors)
    assert product.left is factors[0]
    assert product.right is factors[1]
def kron_mixed(request):
    """Build a Kronecker product from a pair of matrix codes.

    ``request.param`` is unpacked into two codes, and each code is passed to
    ``generate`` to build one factor. This is presumably a parametrised pytest
    fixture whose decorator is outside this view.
    """
    left_code, right_code = request.param
    left, right = generate(left_code), generate(right_code)
    return Kronecker(left, right)