Code Example #1
def create_fns(input, in_signs, Ds):

    cumulative_units = np.concatenate([[0], np.cumsum(Ds[:-1])])
    
    Ws = [sj.initializers.he((j, i)) for j, i in zip(Ds[1:], Ds[:-1])]
    bs = [sj.initializers.he((j,)) for j in Ds[1:]]

    A_w = [T.eye(Ds[0])]
    B_w = [T.zeros(Ds[0])]
    
    A_q = [T.eye(Ds[0])]
    B_q = [T.zeros(Ds[0])]
    
    maps = [input]
    signs = []
    masks = [T.ones(Ds[0])]
    # 0.1 here is the leaky-ReLU slope applied outside the active (positive) region
    in_masks = T.where(T.concatenate([T.ones(Ds[0]), in_signs]) > 0,
                       1., 0.1)

    for w, b in zip(Ws[:-1], bs[:-1]):
        
        pre_activation = T.matmul(w, maps[-1]) + b
        signs.append(T.sign(pre_activation))
        masks.append(T.where(pre_activation > 0, 1., 0.1))

        maps.append(pre_activation * masks[-1])

    maps.append(T.matmul(Ws[-1], maps[-1]) + bs[-1])

    # compute per region A and B
    for start, end, w, b, m in zip(cumulative_units[:-1],
                                   cumulative_units[1:], Ws, bs, masks):

        A_w.append(T.matmul(w * m, A_w[-1]))
        B_w.append(T.matmul(w * m, B_w[-1]) + b)

        A_q.append(T.matmul(w * in_masks[start:end], A_q[-1]))
        B_q.append(T.matmul(w * in_masks[start:end], B_q[-1]) + b)

    signs = T.concatenate(signs)
    ineq_b = T.concatenate(B_w[1:-1])
    ineq_A = T.vstack(A_w[1:-1])

    inequalities = T.hstack([ineq_b[:, None], ineq_A])
    # normalize each half-space by the row-wise L2 norm of its slope
    inequalities = inequalities * signs[:, None] \
                   / T.linalg.norm(ineq_A, 2, 1, keepdims=True)

    inequalities_code = T.hstack([T.concatenate(B_q[1:-1])[:, None],
                                  T.vstack(A_q[1:-1])])
    inequalities_code = inequalities_code * in_signs[:, None]

    f = sj.function(input, outputs=[maps[-1], A_w[-1], B_w[-1],
                                    inequalities, signs])
    g = sj.function(in_signs, outputs=[A_q[-1], B_q[-1]])
    all_g = sj.function(in_signs, outputs=inequalities_code)
    h = sj.function(input, outputs=maps[-1])

    return f, g, h, all_g
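A minimal usage sketch for create_fns (hedged: sj, T and np are assumed to be the usual symjax, symjax.tensor and numpy imports, and the widths in Ds are purely illustrative):

import numpy as np
import symjax as sj
import symjax.tensor as T

Ds = [2, 6, 6, 1]                        # input width, two hidden layers, output width
input = T.Placeholder((Ds[0],), 'float32')
in_signs = T.Placeholder((int(np.sum(Ds[1:-1])),), 'float32')

f, g, h, all_g = create_fns(input, in_signs, Ds)

point = np.random.randn(Ds[0]).astype('float32')
output, A, B, ineqs, signs = f(point)    # affine map (A, B) of the region containing point
A2, B2 = g(signs)                        # the same map, recovered from the sign pattern alone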
Code Example #2
def get_Abs(Ws, vs, Qs):
    
    """produces the pre activation feature maps of layer ell"""

    n = Qs[-1].shape[0]
    As = [Ws[0] * T.ones((n, 1, 1))]
    bs = [vs[0] * T.ones((n, 1))]
    for i in range(len(Qs)):
        As.append(T.einsum('db,nb,nbs->nds', Ws[i + 1], Qs[i], As[-1]))
        bs.append(T.einsum('db,nb,nb->nd', Ws[i + 1], Qs[i], bs[-1])\
                  + vs[i + 1])

    return As, bs
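A shape-level sketch of get_Abs (hedged: toy dimensions, and Qs would normally hold per-region leaky-ReLU slopes rather than ones):

import symjax.tensor as T

n, D0, D1, D2 = 4, 3, 5, 2               # number of regions and layer widths
Ws = [T.ones((D1, D0)), T.ones((D2, D1))]
vs = [T.zeros(D1), T.zeros(D2)]
Qs = [T.ones((n, D1))]                   # one slope pattern per hidden layer

As, bs = get_Abs(Ws, vs, Qs)
# As[-1] has shape (n, D2, D0): per-region slopes of the last pre-activation
# bs[-1] has shape (n, D2): the matching per-region offsets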
Code Example #3
def test_base():
    a = T.ones((10, ))
    b = a.sum()
    print(b.get())
    print(b.get())
    f = symjax.function(outputs=b)
    [f() for i in range(100)]
Code Example #4
File: base_test.py, Project: SymJAX/SymJAX
def test_stack():
    u = tt.Variable(tt.ones((2, )))
    output = tt.stack([u, 2 * u, 3 * u])
    f = symjax.function(outputs=output)
    assert np.allclose(f(), (np.arange(3)[:, None] + 1) * np.ones((3, 2)))
    print(f())
    print(f())
Code Example #5
def create_glo(batch_size, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1,
               GLO=False):

    x = T.Placeholder([batch_size, Ds[-1]], 'float32')
    z = T.Variable(T.random.randn((batch_size, Ds[0])))
    logvar_x = T.Variable(T.ones(1))

    # DECODER
    Ws, bs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w) for w in Ws]
    bs = [T.Variable(b) for b in bs]
    h = [z]
    for w, b in zip(Ws[:-1], bs[:-1]):
        h.append(T.matmul(h[-1], w.transpose()) + b)
        h.append(h[-1] * relu_mask(h[-1], leakiness))
    h.append(T.matmul(h[-1], Ws[-1].transpose()) + bs[-1])

    # LOSS
    prior = sum([T.sum(w**2) for w in Ws], 0.) / cov_W \
            + sum([T.sum(v**2) for v in bs[:-1]], 0.) / cov_b
    if GLO:
        loss = T.sum((x - h[-1])**2) / batch_size + prior
        variables = Ws + bs
    else:
        loss = Ds[-1] * logvar_x.sum() \
               + T.sum((x - h[-1])**2 / T.exp(logvar_x)) / batch_size \
               + (z**2).sum() / batch_size + prior
        variables = Ws + bs

    prior = sum([(b**2).sum() for b in bs], 0.) / cov_b\
            + sum([(w**2).sum() for w in Ws], 0.) / cov_W
 
    opti = sj.optimizers.Adam(loss + prior, lr, params=variables)
    infer = sj.optimizers.Adam(loss, lr, params=[z])

    estimate = sj.function(x, outputs=z, updates=infer.updates)
    train = sj.function(x, outputs=loss, updates=opti.updates)
    lossf = sj.function(x, outputs=loss)
    params = sj.function(outputs=Ws + bs + [T.ones(Ds[-1]) * T.exp(logvar_x)])

    output = {'train': train, 'estimate':estimate, 'params':params}
    output['reset'] = lambda v: z.assign(v)
    if GLO:
        output['model'] = 'GLO'
    else:
        output['model'] = 'HARD'
    output['loss'] = lossf
    output['kwargs'] = {'batch_size': batch_size, 'Ds':Ds, 'seed':seed,
                    'leakiness':leakiness, 'lr':lr, 'scaler':scaler}
    return output
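A hypothetical training loop for create_glo (hedged: init_weights, relu_mask, cov_W and cov_b are module-level helpers this snippet assumes to be in scope; all dimensions are illustrative):

import numpy as np

model = create_glo(batch_size=8, Ds=[2, 16, 784], seed=0, GLO=True)
batch = np.random.randn(8, 784).astype('float32')

for step in range(100):
    model['estimate'](batch)             # one inference step on the latent codes z
    loss = model['train'](batch)         # one gradient step on the decoder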
Code Example #6
def create_vae(batch_size, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1):

    x = T.Placeholder([batch_size, Ds[-1]], 'float32')

    # ENCODER
    enc = encoder(x, Ds[0])
    mu = enc[-1][:, :Ds[0]]
    logvar = enc[-1][:, Ds[0]:]
    var = T.exp(logvar)
 
    z = mu + T.exp(0.5 * logvar) * T.random.randn((batch_size, Ds[0]))
    z_ph = T.Placeholder((batch_size, Ds[0]), 'float32')

    # DECODER
    Ws, bs = init_weights(Ds, seed, scaler)

    Ws = [T.Variable(w) for w in Ws]
    bs = [T.Variable(b) for b in bs]
    logvar_x = T.Variable(T.zeros(1), name='logvar_x') 
    var_x = T.exp(logvar_x)

    h, h_ph = [z], [z_ph]
    for w, b in zip(Ws[:-1], bs[:-1]):
        h.append(T.matmul(h[-1], w.transpose()) + b)
        h.append(h[-1] * relu_mask(h[-1], leakiness))
        h_ph.append(T.matmul(h_ph[-1], w.transpose()) + b)
        h_ph.append(h_ph[-1] * relu_mask(h_ph[-1], leakiness))

    h.append(T.matmul(h[-1], Ws[-1].transpose()) + bs[-1])
    h_ph.append(T.matmul(h_ph[-1], Ws[-1].transpose()) + bs[-1])

    prior = sum([T.mean(w**2) for w in Ws], 0.) / cov_W\
            + sum([T.mean(v**2) for v in bs[:-1]], 0.) / cov_b 
    kl = 0.5 * (1 + logvar - var - mu ** 2).sum(1)
    px = - 0.5 * (logvar_x + ((x - h[-1])**2 / var_x)).sum(1)
    loss = - (px + kl).mean() + prior

    variables = Ws + bs + sj.layers.get_variables(enc) + [logvar_x]
    opti = sj.optimizers.Adam(loss, lr, params=variables)

    train = sj.function(x, outputs=loss, updates=opti.updates)
    g = sj.function(z_ph, outputs=h_ph[-1])
    params = sj.function(outputs=Ws + bs + [T.exp(logvar_x) * T.ones(Ds[-1])])
    get_varx = sj.function(outputs=var_x)


    output = {'train': train, 'g':g, 'params':params}
    output['model'] = 'VAE'
    output['varx'] = get_varx
    output['kwargs'] = {'batch_size': batch_size, 'Ds':Ds, 'seed':seed,
                    'leakiness':leakiness, 'lr':lr, 'scaler':scaler,
                    'prior': sj.function(outputs=prior)}
    def sample(n):
        samples = []
        for i in range(n // batch_size):
            samples.append(g(np.random.randn(batch_size, Ds[0])))
        return np.concatenate(samples)
    output['sample'] = sample
    return output
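And similarly for create_vae (hedged: encoder, init_weights, relu_mask, cov_W and cov_b are assumed module-level helpers; dimensions are illustrative):

import numpy as np

vae = create_vae(batch_size=8, Ds=[2, 16, 784], seed=0)
batch = np.random.randn(8, 784).astype('float32')

for step in range(100):
    vae['train'](batch)                  # one gradient step on the negative ELBO
samples = vae['sample'](64)              # decode 64 prior draws, shape (64, 784)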
Code Example #7
    def forward(self, input, crop_shape, deterministic, padding=0, seed=None):

        self.crop_shape = crop_shape
        # if given only a scalar
        if not hasattr(padding, "__len__"):
            self.pad_shape = [(padding, padding)] * (input.ndim - 1)
        # else
        else:
            self.pad_shape = [(pad,
                               pad) if not hasattr(pad, "__len__") else pad
                              for pad in padding]

        assert len(self.pad_shape) == len(self.crop_shape)
        assert len(self.pad_shape) == (len(input.shape) - 1)

        self.start_indices = list()
        self.fixed_indices = list()
        for i, (pad, dim, crop) in enumerate(
                zip(self.pad_shape, input.shape[1:], self.crop_shape)):
            maxval = pad[0] + pad[1] + dim - crop
            assert maxval >= 0
            self.start_indices.append(
                T.random.randint(
                    minval=0,
                    maxval=maxval,
                    shape=(input.shape[0], 1),
                    dtype="int32",
                    seed=seed + i if seed is not None else seed,
                ))

            self.fixed_indices.append(
                T.ones((input.shape[0], 1), "int32") * (maxval // 2))
        self.start_indices = T.concatenate(self.start_indices, 1)
        self.fixed_indices = T.concatenate(self.fixed_indices, 1)

        dirac = T.cast(deterministic, "float32")

        # pad the input
        pinput = T.pad(input, [(0, 0)] + self.pad_shape)

        routput = T.stack(
            [
                T.dynamic_slice(pinput[n], self.start_indices[n],
                                self.crop_shape) for n in range(input.shape[0])
            ],
            0,
        )
        doutput = T.stack(
            [
                T.dynamic_slice(pinput[n], self.fixed_indices[n],
                                self.crop_shape) for n in range(input.shape[0])
            ],
            0,
        )

        return doutput * dirac + (1 - dirac) * routput
Code Example #8
def test_cond3():
    sj.current_graph().reset()
    v = T.ones((10, 10)) * 3
    u = T.Placeholder((), "int32")
    out = T.cond(
        u > 0,
        lambda a, u: a * u,
        lambda a, u: a + u,
        true_inputs=(
            2 * T.ones((10, 10)),
            v,
        ),
        false_inputs=(
            2 * T.ones((10, 10)),
            v,
        ),
    )
    f = sj.function(u, outputs=out)
    assert np.array_equal(f(1), 6 * np.ones((10, 10)))
    assert np.array_equal(f(0), 5 * np.ones((10, 10)))
Code Example #9
def test_cond2():
    sj.current_graph().reset()
    v = T.ones((10, 10))
    u = T.Placeholder((), "int32")
    out = T.cond(
        u > 0,
        lambda u: 4 * u,
        lambda u: u,
        true_inputs=(v, ),
        false_inputs=(2 * v, ),
    )
    f = sj.function(u, outputs=out)
    assert np.array_equal(f(1), 4 * np.ones((10, 10)))
    assert np.array_equal(f(0), 2 * np.ones((10, 10)))
Code Example #10
File: layers.py, Project: SymJAX/SymJAX
    def forward(self, input, crop_shape, deterministic, padding=0, seed=None):

        # if given only a scalar
        if not hasattr(padding, "__len__"):
            pad_shape = [(padding, padding)] * (input.ndim - 1)
        # else
        else:
            pad_shape = [(pad, pad) if not hasattr(pad, "__len__") else pad
                         for pad in padding]

        assert len(pad_shape) == len(crop_shape)
        assert len(pad_shape) == input.ndim - 1

        start_indices = list()
        fixed_indices = list()
        for i, (pad, dim,
                crop) in enumerate(zip(pad_shape, input.shape[1:],
                                       crop_shape)):
            maxval = pad[0] + pad[1] + dim - crop
            start_indices.append(
                T.random.randint(
                    minval=0,
                    maxval=maxval,
                    shape=(input.shape[0], 1),
                    dtype="int32",
                    seed=seed + i if seed is not None else seed,
                ))

            fixed_indices.append(
                T.ones((input.shape[0], 1), "int32") * (maxval // 2))
        start_indices = T.concatenate(start_indices, 1)
        fixed_indices = T.concatenate(fixed_indices, 1)

        dirac = T.cast(deterministic, "float32")

        # pad the input
        pinput = T.pad(input, [(0, 0)] + pad_shape)

        routput = T.map(
            lambda x, indices: T.dynamic_slice(x, indices, crop_shape),
            sequences=[pinput, start_indices],
        )
        doutput = T.map(
            lambda x, indices: T.dynamic_slice(x, indices, crop_shape),
            sequences=[pinput, fixed_indices],
        )

        return doutput * dirac + (1 - dirac) * routput
Code Example #11
def get_forward(Ws, Qs):
    """this function gives the slope matrix that forwards any pre activation
    of layer l to the output layer which is ::
        W^{L}Q^{L}_{\omega}W^{\ell}Q^{\ell}
    for the \ell element of the returned list. For the first one, is returns
    the entire A_{\omega} and for the last one it is the identity matrix
    """
    N = Qs[-1].shape[0]
    L = len(Qs)

    forward = [T.identity(Ws[-1].shape[0]) * T.ones((N, 1, 1))]
    for i in range(L):
        forward.append(T.einsum('ndb,bs,ns->nds', forward[-1], Ws[- 1 - i],
                                Qs[- 1 - i]))

    return forward[::-1]
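A matching shape check for get_forward (same hedged toy setup as the get_Abs sketch above):

import symjax.tensor as T

N, D0, D1, D2 = 4, 3, 5, 2
Ws = [T.ones((D1, D0)), T.ones((D2, D1))]
Qs = [T.ones((N, D1))]

fwd = get_forward(Ws, Qs)
# fwd[0] has shape (N, D2, D1): forwards layer-1 pre-activations to the output
# fwd[-1] has shape (N, D2, D2): the identity, as stated in the docstring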
Code Example #12
File: layers.py, Project: hardmaru/SymJAX
    def __init__(self,
                 input_or_shape,
                 crop_shape,
                 deterministic,
                 padding=0,
                 seed=None):

        self.init_input(input_or_shape)
        self.crop_shape = crop_shape
        # if given only a scalar
        if not hasattr(padding, '__len__'):
            self.pad_shape = [(padding, padding)] * (len(self.input.shape) - 1)
        # else
        else:
            self.pad_shape = [(pad,
                               pad) if not hasattr(pad, '__len__') else pad
                              for pad in padding]

        assert len(self.pad_shape) == len(self.crop_shape)
        assert len(self.pad_shape) == (len(self.input.shape) - 1)

        self.deterministic = deterministic

        self.start_indices = list()
        self.fixed_indices = list()
        for i, (pad, dim, crop) in enumerate(
                zip(self.pad_shape, self.input.shape[1:], self.crop_shape)):
            maxval = pad[0] + pad[1] + dim - crop
            assert maxval >= 0
            self.start_indices.append(
                T.random.randint(minval=0,
                                 maxval=maxval,
                                 shape=(self.input.shape[0], 1),
                                 dtype='int32',
                                 seed=seed + i if seed is not None else seed))

            self.fixed_indices.append(
                T.ones((self.input.shape[0], 1), 'int32') * (maxval // 2))
        self.start_indices = T.concatenate(self.start_indices, 1)
        self.fixed_indices = T.concatenate(self.fixed_indices, 1)

        super().__init__(self.forward(self.input))
Code Example #13
def test_cond5():
    sj.current_graph().reset()
    v = T.ones((10, 10)) * 3
    W = T.Variable(1)
    u = T.Placeholder((), "int32")
    out = T.cond(
        u > 0,
        lambda a, u: a * u[0],
        lambda a, u: a + u[1],
        true_inputs=(
            W,
            v,
        ),
        false_inputs=(
            W,
            v,
        ),
    )
    f = sj.function(u, outputs=out, updates={W: W + 1})
    assert np.array_equal(f(1), 3 * np.ones(10))
    assert np.array_equal(f(0), 5 * np.ones(10))
    assert np.array_equal(f(1), 9 * np.ones(10))
Code Example #14
import symjax as sj
import symjax.tensor as T
import tensorflow_probability as tfp

tfp = tfp.experimental.substrates.jax

nn = tfp.distributions.Normal(1.0, 5.0)

print(nn.cdf(1))

mean = T.Placeholder((1, ), "float32", name="mean")


def inst(self, instance):
    print("instancecheck", self, instance)
    return True


upgrade_class(mean, T.Tensor, T.Variable)

normal_dist = T.wrap_class(tfp.distributions.Normal)
a = normal_dist(mean, 5.0)

print(a._loc is mean)
x = T.Variable(T.ones(1) - 5)
output = a.cdf(x)

get_f = sj.function(mean, outputs=output)

for i in range(-5, 5):
    print("mean:", 1, " x:", T.get(x), " cdf:", get_f(i))
Code Example #15
def create_fns(batch_size, R, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1,
               var_x=1):

    alpha = T.Placeholder((1,), 'float32')
    x = T.Placeholder((Ds[0],), 'float32')
    X = T.Placeholder((batch_size, Ds[-1]), 'float32')

    signs = T.Placeholder((np.sum(Ds[1:-1]),), 'float32')
    SIGNS = T.Placeholder((R, np.sum(Ds[1:-1])), 'float32')
    
    m0 = T.Placeholder((batch_size, R), 'float32')
    m1 = T.Placeholder((batch_size, R, Ds[0]), 'float32')
    m2 = T.Placeholder((batch_size, R, Ds[0], Ds[0]), 'float32')

    Ws, vs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w, name='W' + str(l)) for l, w in enumerate(Ws)]
    vs = [T.Variable(v, name='v' + str(l)) for l, v in enumerate(vs)]

    var_x = T.Variable(T.ones(Ds[-1]) * var_x)
    var_z = T.Variable(T.ones(Ds[0]))

    # create the placeholders
    Ws_ph = [T.Placeholder(w.shape, w.dtype) for w in Ws]
    vs_ph = [T.Placeholder(v.shape, v.dtype) for v in vs]
    var_x_ph = T.Placeholder(var_x.shape, var_x.dtype)

    ############################################################################
    # Compute the output of g(x)
    ############################################################################

    maps = [x]
    xsigns = []
    masks = []
    
    for w, v in zip(Ws[:-1], vs[:-1]):
        
        pre_activation = T.matmul(w, maps[-1]) + v
        xsigns.append(T.sign(pre_activation))
        masks.append(relu_mask(pre_activation, leakiness))
        maps.append(pre_activation * masks[-1])

    xsigns = T.concatenate(xsigns)
    maps.append(T.matmul(Ws[-1], maps[-1]) + vs[-1])

    ############################################################################
    # compute the masks and then the per layer affine mappings
    ############################################################################

    cumulative_units = np.cumsum([0] + Ds[1:])
    xqs = relu_mask([xsigns[None, cumulative_units[i]:cumulative_units[i + 1]]
                    for i in range(len(Ds) - 2)], leakiness)
    qs = relu_mask([signs[None, cumulative_units[i]:cumulative_units[i + 1]]
                    for i in range(len(Ds) - 2)], leakiness)
    Qs = relu_mask([SIGNS[:, cumulative_units[i]:cumulative_units[i + 1]]
                    for i in range(len(Ds) - 2)], leakiness)

    Axs, bxs = get_Abs(Ws, vs, xqs)
    Aqs, bqs = get_Abs(Ws, vs, qs)
    AQs, bQs = get_Abs(Ws, vs, Qs)

    all_bxs = T.hstack(bxs[:-1]).transpose()
    all_Axs = T.hstack(Axs[:-1])[0]

    all_bqs = T.hstack(bqs[:-1]).transpose()
    all_Aqs = T.hstack(Aqs[:-1])[0]

    x_inequalities = T.hstack([all_Axs, all_bxs]) * xsigns[:, None]
    q_inequalities = T.hstack([all_Aqs, all_bqs]) * signs[:, None]

    ############################################################################
    # loss (E-step NLL)
    ############################################################################

    Bm0 = T.einsum('nd,Nn->Nd', bQs[-1], m0)
    B2m0 = T.einsum('nd,Nn->Nd', bQs[-1] ** 2, m0)
    Am1 = T.einsum('nds,Nns->Nd', AQs[-1], m1)
    ABm1 = T.einsum('nds,nd,Nns->Nd', AQs[-1], bQs[-1], m1)
    Am2ATdiag = T.diagonal(T.einsum('nds,Nnsc,npc->Ndp', AQs[-1], m2, AQs[-1]),
                        axis1=1, axis2=2)
    xAm1Bm0 = X * (Am1 + Bm0)

    M2diag = T.diagonal(m2.sum(1), axis1=1, axis2=2)
    
    prior = sum([T.mean(w**2) for w in Ws], 0.) / cov_W\
            + sum([T.mean(v**2) for v in vs[:-1]], 0.) / cov_b
    loss = - 0.5 * (T.log(var_x).sum() + T.log(var_z).sum()\
            + (M2diag / var_z).sum(1).mean() + ((X ** 2 - 2 * xAm1Bm0 + B2m0\
            + Am2ATdiag + 2 * ABm1) / var_x).sum(1).mean())

    mean_loss = - (loss + 0.5 * prior)
    opt = sj.optimizers.SGD(mean_loss, 0.001, params=Ws + vs)

    ############################################################################
    # update of var_x
    ############################################################################

    update_varx = (X ** 2 - 2 * xAm1Bm0 + B2m0 + Am2ATdiag + 2 * ABm1).mean()\
                    * T.ones(Ds[-1])
    update_varz = M2diag.mean() * T.ones(Ds[0])

    ############################################################################
    # update for biases IT IS DONE FOR ISOTROPIC COVARIANCE MATRIX
    ############################################################################

    FQ = get_forward(Ws, Qs)
    update_vs = {}
    for i in range(len(vs)):
        
        if i < len(vs) - 1:
            # now we forward each bias to the x-space except the ith
            separated_bs = bQs[-1] - T.einsum('nds,s->nd', FQ[i], vs[i])
            # compute the residual and apply sigma
            residual = (X[:, None, :] - separated_bs) * m0[:, :, None]\
                                         - T.einsum('nds,Nns->Nnd', AQs[-1], m1)
            back_error = T.einsum('nds,nd->s', FQ[i], residual.mean(0))
            whiten = T.einsum('ndc,nds,n->cs', FQ[i] , FQ[i], m0.mean(0))\
                        + T.eye(back_error.shape[0]) / (Ds[i] * cov_b)
            update_vs[vs[i]] = T.linalg.solve(whiten, back_error)
        else:
            back_error = (X - (Am1 + Bm0) + vs[-1])
            update_vs[vs[i]] = back_error.mean(0)

    ############################################################################
    # update for slopes IT IS DONE FOR ISOTROPIC COVARIANCE MATRIX 
    ############################################################################

    update_Ws = {}
    for i in range(len(Ws)):
        
        U = T.einsum('nds,ndc->nsc', FQ[i], FQ[i])
        if i == 0:
            V = m2.mean(0)
        else:
            V1 = T.einsum('nd,nq,Nn->ndq', bQs[i-1], bQs[i-1], m0)
            V2 = T.einsum('nds,nqc,Nnsc->ndq', AQs[i-1], AQs[i-1], m2)
            V3 = T.einsum('nds,nq,Nns->ndq', AQs[i-1], bQs[i-1], m1)
            Q = T.einsum('nd,nq->ndq', Qs[i - 1], Qs[i - 1])
            V = Q * (V1 + V2 + V3 + V3.transpose((0, 2, 1))) / batch_size

        whiten = T.stack([T.kron(U[n], V[n]) for n in range(V.shape[0])]).sum(0)
        whiten = whiten + T.eye(whiten.shape[-1]) / (Ds[i]*Ds[i+1]*cov_W)
        # compute the residual (bottom up)
        if i == len(Ws) - 1:
            bottom_up = (X[:, None, :] - vs[-1])
        else:
            if i == 0:
                residual = (X[:, None, :] - bQs[-1])
            else:
                residual = (X[:, None, :] - bQs[-1]\
                            + T.einsum('nds,ns->nd', FQ[i - 1], bQs[i-1]))
            bottom_up = T.einsum('ndc,Nnd->Nnc', FQ[i], residual)

        # compute the top down vector
        if i == 0:
            top_down = m1
        else:
            top_down = Qs[i - 1] * (T.einsum('nds,Nns->Nnd', AQs[i - 1], m1) +\
                               T.einsum('nd,Nn->Nnd', bQs[i - 1], m0))

        vector = T.einsum('Nnc,Nns->cs', bottom_up, top_down) / batch_size
        condition = T.diagonal(whiten)
        update_W = T.linalg.solve(whiten, vector.reshape(-1)).reshape(Ws[i].shape)
        update_Ws[Ws[i]] = update_W

    ############################################################################
    # create the io functions
    ############################################################################

    params = sj.function(outputs=Ws + vs + [var_x])
    ll = T.Placeholder((), 'int32')
    selector = T.one_hot(ll, len(vs))
    for i in range(len(vs)):
        update_vs[vs[i]] = ((1 - alpha) * vs[i] + alpha * update_vs[vs[i]])\
                            * selector[i] + vs[i] * (1 - selector[i])
    for i in range(len(Ws)):
        update_Ws[Ws[i]] = ((1 - alpha) * Ws[i] + alpha * update_Ws[Ws[i]])\
                            * selector[i] + Ws[i] * (1 - selector[i])

    output = {'train': sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                   updates=opt.updates),
              'update_var':sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                        updates = {var_x: update_varx}),
              'update_vs':sj.function(alpha, ll, SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                      updates = update_vs),
              'loss':sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss),
              'update_Ws':sj.function(alpha, ll, SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                      updates = update_Ws),
              'signs2Ab': sj.function(signs, outputs=[Aqs[-1][0], bqs[-1][0]]),
              'signs2ineq': sj.function(signs, outputs=q_inequalities),
              'g': sj.function(x, outputs=maps[-1]),
              'input2all': sj.function(x, outputs=[maps[-1], Axs[-1][0],
                                       bxs[-1][0], x_inequalities, xsigns]),
              'get_nll': sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss),
              'assign': sj.function(*Ws_ph, *vs_ph, var_x_ph,
                                    updates=dict(zip(Ws + vs + [var_x],
                                             Ws_ph + vs_ph + [var_x_ph]))),
              'varx': sj.function(outputs=var_x),
              'prior': sj.function(outputs=prior),
              'varz': sj.function(outputs=var_z),
              'params': params,
#              'probed' : sj.function(SIGNS, X, m0, m1, m2, outputs=probed),
              'input2signs': sj.function(x, outputs=xsigns),
              'S' : Ds[0], 'D':  Ds[-1], 'R': R, 'model': 'EM', 'L':len(Ds)-1,
              'kwargs': {'batch_size': batch_size, 'Ds':Ds, 'seed':seed,
                    'leakiness':leakiness, 'lr':lr, 'scaler':scaler}}
 
    def sample(n):
        samples = []
        for i in range(n):
            samples.append(output['g'](np.random.randn(Ds[0])))
        return np.array(samples)
    
    output['sample'] = sample

    return output
Code Example #16
#!/usr/bin/env python
# -*- coding: utf-8 -*-

__author__ = "Randall Balestriero"

import symjax.tensor as T
import matplotlib.pyplot as plt
from symjax.viz import compute_graph

x = T.random.randn((10, ), name="x")
y = T.random.randn((10, ), name="y")
z = T.random.randn((10, ), name="z")

w = T.Variable(T.ones(1), name="w")
out = (x + y).sum() * w + z.sum()

graph = compute_graph(out)
graph.draw("file.png", prog="dot")

import matplotlib.image as mpimg

img = mpimg.imread("file.png")
plt.figure(figsize=(15, 5))
imgplot = plt.imshow(img)
plt.xticks([])
plt.yticks([])
plt.tight_layout()
Code Example #17
File: function.py, Project: brandonwillard/SymJAX
import symjax
import symjax.tensor as T

value = T.Variable(T.ones(()))
randn = T.random.randn(())
rand = T.random.rand(())

out1 = randn * value
out2 = out1.clone({randn: rand})

f = symjax.function(rand, outputs=out2, updates={value: 2 + value})

for i in range(3):
    print(f(i))
# 0.
# 3.
# 10.

# we create a simple computational graph
var = T.Variable(T.random.randn((16, 8), seed=10))
loss = ((var - T.ones_like(var))**2).sum()
g = symjax.gradients(loss, [var])
opt = symjax.optimizers.SGD(loss, 0.01, params=var)

f = symjax.function(outputs=loss, updates=opt.updates)

for i in range(10):
    print(f())
# 240.96829
# 231.42595
# 222.26149
Code Example #18
File: wrap_class.py, Project: brandonwillard/SymJAX
__author__ = "Randall Balestriero"

import jax.numpy as jnp
import symjax as sj
import symjax.tensor as T


class product:
    def __init__(self, W, V=1):
        self.W = jnp.square(V * W * (W > 0).astype("float32"))
        self.ndim = self.compute_ndim()

    def feed(self, x):
        return jnp.dot(self.W, x)

    def compute_ndim(self):
        return self.W.shape[0] * self.W.shape[1]


wrapped = T.wrap_class(product, method_exceptions=["compute_ndim"])

a = wrapped(T.zeros((10, 10)), V=T.ones((10, 10)))
x = T.random.randn((10, 100))

print(a.W)
# (Tensor: name=function[0], shape=(10, 10), dtype=float32)

print(a.feed(x))
# Op(name=feed, shape=(10, 100), dtype=float32, scope=/)

f = sj.function(outputs=a.feed(x))

f()
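# -> a (10, 100) array of zeros here: W is initialized with T.zeros, so the
#    (W > 0) mask inside product.__init__ zeroes out the weight matrix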
Code Example #19
File: scan.py, Project: ml-lab/SymJAX
import sys
sys.path.insert(0, "../")
import symjax as sj
import symjax.tensor as T
import numpy as np

__author__ = "Randall Balestriero"

# example of cumulative sum


def func(carry, x):
    return carry + 1, 0


output, _ = T.scan(func, T.zeros(1), T.ones(10), length=10)

f = sj.function(outputs=output)
print(f())
# [10.]

# example of simple RNN

w = T.Placeholder((3, 10), 'float32')
h = T.random.randn((3, 3))
b = T.random.randn((3, ))
t_steps = 100
X = T.random.randn((t_steps, 10))


def rnn_cell(carry, x, w):
    # hypothetical completion of the truncated snippet: one vanilla RNN step
    # using the h and b defined above
    new_carry = T.tanh(T.matmul(h, carry) + T.matmul(w, x) + b)
    return new_carry, new_carry
Code Example #20
File: sgd.py, Project: brandonwillard/SymJAX
losses = list()
values = list()
for i in range(10):
    a, b = train()
    losses.append(a)
    values.append(b)

plt.figure()
plt.subplot(121)
plt.plot(losses)
plt.subplot(122)
plt.plot(values, np.zeros_like(values), "kx")

####### jacobians

x, y = T.ones(()), T.ones(())
print(x, y)
ZZ = T.stack([x, y])
f = T.stack([3 * ZZ[0] + 2 * ZZ[1]], axis=0)
j = symjax.jacobians(f, [ZZ])[0]
g_j = symjax.function(outputs=j)

R = T.random.randn()
f = ZZ * 10 * R
j = symjax.jacobians(f, [ZZ])[0]
g_j = symjax.function(outputs=[j])
for i in range(5):
    print(g_j())

# plt.show()
Code Example #21
File: control_flow.py, Project: ml-lab/SymJAX
import jax
import numpy as np
import sys
sys.path.insert(0, "../")

import symjax
import symjax.tensor as T

# map
xx = T.ones(10)
a = T.map(lambda a: a * 2, xx)
g = symjax.gradients(a.sum(), xx)[0]
f = symjax.function(outputs=[a, g])

# scan
xx = T.ones(10) * 2
a = T.scan(lambda c, x: (c * x, c * x), T.ones(1), xx)
g = symjax.gradients(a[1][-1], xx)[0]
f = symjax.function(outputs=[a, g])

# scan with updates
xx = T.range(5)
uu = T.ones((10, 2))
vvar = T.Variable(T.zeros((10, 2)))
vv = T.index_add(vvar, 1, 1)
a = T.scan(lambda c, x, p: (T.index_update(c, x, p[x]), 1), vv, xx, [vv])
#a = T.scan(lambda c, x: (c*x,c*x), T.ones(1), xx)
#a = T.scan(lambda c, x: (T.square(c),c[0]), uu, xx)
#g = symjax.gradients(a[1][-1],xx)
f = symjax.function(outputs=a[0], updates={vvar: vvar + 1})
print(f(), f(), f())
Code Example #22
import jax
import networkx as nx
import matplotlib.pyplot as plt
import symjax.tensor as T

bin_rules = [
    [
        lambda *args: args[0] == 0 or args[1] == 0,
        jax.numpy.add,
        lambda *args: args[0] if args[1] == 0 else args[1],
    ],
    [lambda *args: len(args) == 1, jax.numpy.add, lambda *args: 2 * args[0]],
    [lambda *args: args[1] == 1, jax.numpy.true_divide, lambda *args: args[0]],
    [
        lambda *args: len(args) == 1,
        jax.numpy.true_divide,
        lambda *args: T.ones(args[0].shape, args[0].dtype),
    ],
    [
        lambda *args: args[0] == 1 or args[1] == 1,
        jax.numpy.multiply,
        lambda *args: args[0] if args[1] == 1 else args[1],
    ],
]


def simplify_add(graph):
    to_search = list(graph.nodes.keys())
    while len(to_search):
        j = to_search[-1]
        if type(j) == T.Op:
Code Example #23
File: graph.py, Project: brandonwillard/SymJAX
import symjax
import symjax.tensor as T

g = symjax.Graph("model1")
with g:
    learning_rate = T.Variable(T.ones((1, )))
    with symjax.Graph("layer1"):
        W1 = T.Variable(T.zeros((1, )), name="W")
        b1 = T.Variable(T.zeros((1, )), name="b")
    with symjax.Graph("layer2"):
        W2 = T.Variable(T.zeros((1, )), name="W")
        b2 = T.Variable(T.zeros((1, )), name="b")

# define an irrelevant loss function involving the parameters
loss = (W1 + b1 + W2 + b2) * learning_rate

# and a train/update function
train = symjax.function(outputs=loss,
                        updates={
                            W1: W1 + 1,
                            b1: b1 + 2,
                            W2: W2 + 2,
                            b2: b2 + 3
                        })

# pretend we train for a while
for i in range(4):
    print(train())

# [0.]
# [8.]
# [16.]
# [24.]
Code Example #24
def create_fns(input,
               in_signs,
               Ds,
               x,
               m0,
               m1,
               m2,
               batch_in_signs,
               alpha=0.1,
               sigma=1,
               sigma_x=1,
               lr=0.0002):

    cumulative_units = np.concatenate([[0], np.cumsum(Ds[:-1])])
    BS = batch_in_signs.shape[0]
    Ws = [
        T.Variable(sj.initializers.glorot((j, i)) * sigma)
        for j, i in zip(Ds[1:], Ds[:-1])
    ]
    bs = [T.Variable(sj.initializers.he((j,)) * sigma) for j in Ds[1:-1]]\
                + [T.Variable(T.zeros((Ds[-1],)))]

    A_w = [T.eye(Ds[0])]
    B_w = [T.zeros(Ds[0])]

    A_q = [T.eye(Ds[0])]
    B_q = [T.zeros(Ds[0])]

    batch_A_q = [T.eye(Ds[0]) * T.ones((BS, 1, 1))]
    batch_B_q = [T.zeros((BS, Ds[0]))]

    maps = [input]
    signs = []
    masks = [T.ones(Ds[0])]

    in_masks = T.where(T.concatenate([T.ones(Ds[0]), in_signs]) > 0, 1., alpha)
    batch_in_masks = T.where(
        T.concatenate([T.ones((BS, Ds[0])), batch_in_signs], 1) > 0, 1., alpha)

    for w, b in zip(Ws[:-1], bs[:-1]):

        pre_activation = T.matmul(w, maps[-1]) + b
        signs.append(T.sign(pre_activation))
        masks.append(T.where(pre_activation > 0, 1., alpha))

        maps.append(pre_activation * masks[-1])

    maps.append(T.matmul(Ws[-1], maps[-1]) + bs[-1])

    # compute per region A and B
    for start, end, w, b, m in zip(cumulative_units[:-1], cumulative_units[1:],
                                   Ws, bs, masks):

        A_w.append(T.matmul(w * m, A_w[-1]))
        B_w.append(T.matmul(w * m, B_w[-1]) + b)

        A_q.append(T.matmul(w * in_masks[start:end], A_q[-1]))
        B_q.append(T.matmul(w * in_masks[start:end], B_q[-1]) + b)

        batch_A_q.append(
            T.matmul(w * batch_in_masks[:, None, start:end], batch_A_q[-1]))
        batch_B_q.append((w * batch_in_masks[:, None, start:end]\
                            * batch_B_q[-1][:, None, :]).sum(2) + b)

    batch_B_q = batch_B_q[-1]
    batch_A_q = batch_A_q[-1]

    signs = T.concatenate(signs)

    inequalities = T.hstack(
        [T.concatenate(B_w[1:-1])[:, None],
         T.vstack(A_w[1:-1])]) * signs[:, None]

    inequalities_code = T.hstack(
        [T.concatenate(B_q[1:-1])[:, None],
         T.vstack(A_q[1:-1])]) * in_signs[:, None]

    #### loss
    log_sigma2 = T.Variable(sigma_x)
    sigma2 = T.exp(log_sigma2)

    Am1 = T.einsum('qds,nqs->nqd', batch_A_q, m1)
    Bm0 = T.einsum('qd,nq->nd', batch_B_q, m0)
    B2m0 = T.einsum('nq,qd->n', m0, batch_B_q**2)
    AAm2 = T.einsum('qds,qdu,nqup->nsp', batch_A_q, batch_A_q, m2)

    inner = -(x * (Am1.sum(1) + Bm0)).sum(1) + (Am1 * batch_B_q).sum((1, 2))

    loss_2 = (x**2).sum(1) + B2m0 + T.trace(AAm2, axis1=1, axis2=2).squeeze()

    loss_z = T.trace(m2.sum(1), axis1=1, axis2=2).squeeze()

    cst = 0.5 * (Ds[0] + Ds[-1]) * T.log(2 * np.pi)

    loss = cst + 0.5 * Ds[-1] * log_sigma2 + inner / sigma2\
            + 0.5 * loss_2 / sigma2 + 0.5 * loss_z

    mean_loss = loss.mean()
    opt = sj.optimizers.NesterovMomentum(mean_loss, Ws + bs, lr, 0.9)

    train_f = sj.function(batch_in_signs,
                          x,
                          m0,
                          m1,
                          m2,
                          outputs=mean_loss,
                          updates=opt.updates)
    f = sj.function(input,
                    outputs=[maps[-1], A_w[-1], B_w[-1], inequalities, signs])
    g = sj.function(in_signs, outputs=[A_q[-1], B_q[-1]])
    all_g = sj.function(in_signs, outputs=inequalities_code)
    h = sj.function(input, outputs=maps[-1])
    return f, g, h, all_g, train_f, sigma2
Code Example #25
import symjax
import symjax.tensor as T

# scope/graph naming and accessing

value1 = T.Variable(T.ones((1,)))
value2 = T.Variable(T.zeros((1,)))

g = symjax.Graph("special")
with g:
    value3 = T.Variable(T.zeros((1,)))
    value4 = T.Variable(T.zeros((1,)))
    result = value3 + value4

    h = symjax.Graph("inversion")
    with h:
        value5 = T.Variable(T.zeros((1,)))
        value6 = T.Variable(T.zeros((1,)))
        value7 = T.Variable(T.zeros((1,)), name="w")


print(g.variables)
# {'unnamed_variable': Variable(name=unnamed_variable, shape=(1,), dtype=float32, trainable=True, scope=/special/),
#  'unnamed_variable_1': Variable(name=unnamed_variable_1, shape=(1,), dtype=float32, trainable=True, scope=/special/)}

print(h.variables)
# {'unnamed_variable': Variable(name=unnamed_variable, shape=(1,), dtype=float32, trainable=True, scope=/special/inversion/),
# 'unnamed_variable_1': Variable(name=unnamed_variable_1, shape=(1,), dtype=float32, trainable=True, scope=/special/inversion/),
# 'w': Variable(name=w, shape=(1,), dtype=float32, trainable=True, scope=/special/inversion/)}

print(h.variable("w"))
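# Variable(name=w, shape=(1,), dtype=float32, trainable=True, scope=/special/inversion/)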