示例#1
0
    def test_einsum_kpmurphy_example(self):
        """Compare a three-operand einsum with an explicit nested-loop sum.

        The contraction 'ntk,kd,dc->nc' is evaluated along a precomputed
        'optimal' path and checked against a brute-force reference.
        """
        # code from an email with @murphyk
        n_batch, n_out, n_hidden, n_feat, n_steps = 2, 3, 4, 5, 6
        rand = rng()
        S = rand.randn(n_batch, n_steps, n_feat)
        W = rand.randn(n_feat, n_hidden)
        V = rand.randn(n_hidden, n_out)

        # Brute-force reference; terms are accumulated in d, k, t order.
        expected = onp.zeros((n_batch, n_out))
        for n in range(n_batch):
            for c in range(n_out):
                expected[n, c] = sum(
                    S[n, t, k] * W[k, d] * V[d, c]
                    for d in range(n_hidden)
                    for k in range(n_feat)
                    for t in range(n_steps))

        spec = 'ntk,kd,dc->nc'
        contraction_path = np.einsum_path(spec, S, W, V, optimize='optimal')[0]
        actual = np.einsum(spec, S, W, V, optimize=contraction_path)
        self.assertAllClose(expected, actual, check_dtypes=False)
示例#2
0
    def test_einsum_kpmurphy_example(self):
        """Compare jnp.einsum on 'ntk,kd,dc->nc' with a nested-loop reference.

        The einsum is run along a precomputed 'optimal' contraction path.
        A relative tolerance of 1e-2 is used when the device under test is
        a TPU; otherwise the default tolerance applies.
        """
        # code from an email with @murphyk
        n_batch, n_out, n_hidden, n_feat, n_steps = 2, 3, 4, 5, 6
        rand = self.rng()
        S = rand.randn(n_batch, n_steps, n_feat)
        W = rand.randn(n_feat, n_hidden)
        V = rand.randn(n_hidden, n_out)

        # Brute-force reference; terms are accumulated in d, k, t order.
        expected = np.zeros((n_batch, n_out))
        for n in range(n_batch):
            for c in range(n_out):
                expected[n, c] = sum(
                    S[n, t, k] * W[k, d] * V[d, c]
                    for d in range(n_hidden)
                    for k in range(n_feat)
                    for t in range(n_steps))

        spec = 'ntk,kd,dc->nc'
        contraction_path = jnp.einsum_path(spec, S, W, V, optimize='optimal')[0]
        if jtu.device_under_test() == "tpu":
            rtol = 1e-2
        else:
            rtol = None
        actual = jnp.einsum(spec, S, W, V, optimize=contraction_path)
        self.assertAllClose(expected, actual, check_dtypes=False, rtol=rtol)
示例#3
0
def compute_elim_order(dag, params):
    """Compute an elimination (contraction) order assuming no node is observed.

    The order is found with NumPy's 'greedy' path search, so it is not
    guaranteed to be the optimal one; ``optimize='optimal'`` is the
    exhaustive alternative.

    Args:
        dag: Graph description forwarded to ``make_einsum_string`` and
            ``make_list_of_factors``.
        params: Mapping from node name to its CPT array. The size of each
            CPT's first axis is used as that node's cardinality.

    Returns:
        The first element of the ``np.einsum_path`` result, i.e. the
        contraction path.
    """
    evidence = {}  # empty: nothing is observed
    cardinality = {name: np.shape(cpt)[0] for name, cpt in params.items()}
    evectors = make_evidence_vectors(cardinality, evidence)
    # Named ``einsum_str`` rather than ``str`` to avoid shadowing the builtin.
    einsum_str = make_einsum_string(dag)
    factors = make_list_of_factors(dag, params, evectors)
    path_and_report = np.einsum_path(einsum_str, *factors, optimize='greedy')
    return path_and_report[0]
示例#4
0
# Brute-force reference for the contraction 'ntk,kd,dc->nc'.
# S, W, V and the sizes N, C, D, K, T are expected to be defined earlier
# in the script (not shown in this fragment).
L = np.zeros((N, C))
for n in range(N):
    for c in range(C):
        s = 0
        for d in range(D):
            for k in range(K):
                for t in range(T):
                    s += S[n, t, k] * W[k, d] * V[d, c]
        L[n, c] = s
# Default einsum evaluation must agree with the reference.
assert np.allclose(L, np.einsum('ntk,kd,dc->nc', S, W, V))

# Same contraction, evaluated along an explicitly precomputed 'optimal' path.
path = np.einsum_path('ntk,kd,dc->nc', S, W, V, optimize='optimal')[0]
assert np.allclose(L, np.einsum('ntk,kd,dc->nc', S, W, V, optimize=path))

# Repeat the path-based check with the JAX implementation.
import jax.numpy as jnp
path = jnp.einsum_path('ntk,kd,dc->nc', S, W, V, optimize='optimal')[0]
assert np.allclose(L, jnp.einsum('ntk,kd,dc->nc', S, W, V, optimize=path))

# Use full student network from Koller and Friedman
# NOTE(review): ``str`` rebinds the builtin for the rest of the module; the
# name is kept because later code may read it -- rename once all uses are known.
str = 'c,dc,gdi,si,lg,jls,hgj->'
# K is rebound here and used as the size of every axis of the tables below.
K = 5
# The tables are filled with random normal values; only their shapes matter
# for the path search.
cptC = np.random.randn(K)
cptD = np.random.randn(K, K)
cptG = np.random.randn(K, K, K)
cptS = np.random.randn(K, K)
cptL = np.random.randn(K, K)
cptJ = np.random.randn(K, K, K)
cptH = np.random.randn(K, K, K)
cpts = [cptC, cptD, cptG, cptS, cptL, cptJ, cptH]
path_info = np.einsum_path(str, *cpts, optimize='optimal')
# path_info[0] is the contraction path (the call was missing its closing
# parenthesis).
print(path_info[0])
示例#5
0
def einsum_path(subscripts, *operands, optimize='greedy'):
  """Evaluate ``jnp.einsum_path`` on operands that may be ``JaxArray`` wrappers.

  Args:
    subscripts: Einsum subscripts string, forwarded unchanged.
    *operands: Operands of the contraction. Any ``JaxArray`` is replaced by
      its ``.value``; other operands are passed through as they are.
    optimize: Path-finding strategy forwarded to ``jnp.einsum_path``.

  Returns:
    The result of ``jnp.einsum_path`` wrapped in ``JaxArray``.
  """
  unwrapped = tuple((a.value if isinstance(a, JaxArray) else a) for a in operands)
  # ``optimize`` is keyword-only in ``jnp.einsum_path``; passed positionally it
  # would be consumed as one more operand instead of selecting the strategy.
  # NOTE(review): the einsum_path result is a (path, description) pair rather
  # than an array -- confirm that wrapping it in JaxArray is intended.
  return JaxArray(jnp.einsum_path(subscripts, *unwrapped, optimize=optimize))