Example no. 1
    def forward(self, input, crop_shape, deterministic, padding=0, seed=None):

        self.crop_shape = crop_shape
        # if given only a scalar
        if not hasattr(padding, "__len__"):
            self.pad_shape = [(padding, padding)] * (len(input.shape) - 1)
        # else
        else:
            self.pad_shape = [(pad,
                               pad) if not hasattr(pad, "__len__") else pad
                              for pad in padding]

        assert len(self.pad_shape) == len(self.crop_shape)
        assert len(self.pad_shape) == (len(input.shape) - 1)

        self.start_indices = list()
        self.fixed_indices = list()
        for i, (pad, dim, crop) in enumerate(
                zip(self.pad_shape, input.shape[1:], self.crop_shape)):
            maxval = pad[0] + pad[1] + dim - crop
            assert maxval >= 0
            self.start_indices.append(
                T.random.randint(
                    minval=0,
                    maxval=maxval,
                    shape=(input.shape[0], 1),
                    dtype="int32",
                    seed=seed + i if seed is not None else seed,
                ))

            self.fixed_indices.append(
                T.ones((input.shape[0], 1), "int32") * (maxval // 2))
        self.start_indices = T.concatenate(self.start_indices, 1)
        self.fixed_indices = T.concatenate(self.fixed_indices, 1)

        dirac = T.cast(deterministic, "float32")

        # pad the input
        pinput = T.pad(input, [(0, 0)] + self.pad_shape)

        routput = T.stack(
            [
                T.dynamic_slice(pinput[n], self.start_indices[n],
                                self.crop_shape) for n in range(input.shape[0])
            ],
            0,
        )
        doutput = T.stack(
            [
                T.dynamic_slice(pinput[n], self.fixed_indices[n],
                                self.crop_shape) for n in range(input.shape[0])
            ],
            0,
        )

        return doutput * dirac + (1 - dirac) * routput
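For reference, a minimal NumPy sketch of the same padded random-crop logic as above (names and the scalar-padding assumption are illustrative, not from symjax):

import numpy as np

def random_crop(batch, crop_shape, padding=0, deterministic=False, rng=np.random):
    # pad every spatial dimension symmetrically, keep the batch dimension as is
    padded = np.pad(batch, [(0, 0)] + [(padding, padding)] * (batch.ndim - 1))
    out = np.empty((batch.shape[0],) + tuple(crop_shape), batch.dtype)
    for n in range(batch.shape[0]):
        starts = []
        for dim, crop in zip(padded.shape[1:], crop_shape):
            maxval = dim - crop
            # deterministic -> centered crop, otherwise uniform random start index
            starts.append(maxval // 2 if deterministic else rng.randint(0, maxval + 1))
        out[n] = padded[n][tuple(slice(s, s + c) for s, c in zip(starts, crop_shape))]
    return out

# e.g. random_crop(np.ones((4, 32, 32)), (28, 28), padding=2).shape == (4, 28, 28)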
Example no. 2
 def gate(
     carry,
     x,
     Wf,
     Uf,
     bf,
     Wi,
     Ui,
     bi,
     Wo,
     Uo,
     bo,
     Wc,
     Uc,
     bc,
     sigma_g,
     sigma_c,
     sigma_h,
 ):
     h, c = carry[0], carry[1]
     f = sigma_g(T.dot(x, Wf) + bf + T.dot(h, Uf))
     i = sigma_g(T.dot(x, Wi) + bi + T.dot(h, Ui))
     o = sigma_g(T.dot(x, Wo) + bo + T.dot(h, Uo))
     ctilde = sigma_c(T.dot(x, Wc) + bc + T.dot(h, Uc))
     cnew = f * c + i * ctilde
     hnew = o * sigma_h(cnew)
     return T.stack([hnew, cnew]), h
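The gate above is the standard LSTM cell; written out, with sigma_g, sigma_c and sigma_h the gate, candidate and output nonlinearities:

.. math::
    f_t &= \sigma_g(x_t W_f + h_{t-1} U_f + b_f)\\
    i_t &= \sigma_g(x_t W_i + h_{t-1} U_i + b_i)\\
    o_t &= \sigma_g(x_t W_o + h_{t-1} U_o + b_o)\\
    \tilde{c}_t &= \sigma_c(x_t W_c + h_{t-1} U_c + b_c)\\
    c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t\\
    h_t &= o_t \odot \sigma_h(c_t)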
Example no. 3
def test_stack():
    u = tt.Variable(tt.ones((2, )))
    output = tt.stack([u, 2 * u, 3 * u])
    f = symjax.function(outputs=output)
    assert np.allclose(f(), (np.arange(3)[:, None] + 1) * np.ones((3, 2)))
    print(f())
    print(f())
Example no. 4
    def RNTK_middle(self, previous,x): # line 7, alg 1
        X = x * x[:, None] # <x, x^t>
        rntk_old = previous[0]
        gp_old = previous[1]
        S_old, D_old = self.VT(gp_old[0])  # K(1, t-1)
        gp_new = T.expand_dims(self.sw ** 2 * S_old + (self.su ** 2) * X + self.sb ** 2,axis = 0) # line 8, alg 1
        if self.Lf == 0:  # if none of the layers are fixed, use the standard recursion
            rntk_new = T.expand_dims(gp_new[0] + self.sw**2*rntk_old[0]*D_old,axis = 0)
        else:
            rntk_new = T.expand_dims(gp_new[0],axis = 0) 
        
        print("gp_new 3", gp_new)
        print("rntk_new 3", rntk_new)

        for l in range(self.L-1): #line 10
            l = l+1
            S_new,D_new = self.VT(gp_new[l-1]) # l-1, t
            S_old,D_old = self.VT(gp_old[l]) # t-1, l
            gp_new = T.concatenate( [gp_new, T.expand_dims( self.sw ** 2 * S_old + self.su ** 2 * S_new +  self.sb ** 2, axis = 0)]) #line 10
            rntk_new = T.concatenate( [ rntk_new,  T.expand_dims( gp_new[l] +(self.Lf <= l)*self.sw**2*rntk_old[l]*D_old +(self.Lf <= (l-1))* self.su**2*rntk_new[l-1]*D_new  ,axis = 0)   ]  )
        S_old,D_old = self.VT(gp_new[self.L-1])
        gp_new = T.concatenate([gp_new,T.expand_dims(self.sv**2*S_old,axis = 0)]) # line 11
        rntk_new = T.concatenate([rntk_new,T.expand_dims(rntk_old[self.L]+ gp_new[self.L] + (self.Lf != self.L)*self.sv**2*rntk_new[self.L-1]*D_old,axis = 0)])
        print("gp_new 4", gp_new)
        print("rntk_new 4", rntk_new)
        return T.stack([rntk_new,gp_new]),x
Example no. 5
    def forward(self, input, deterministic=None):

        if deterministic is None:
            deterministic = self.deterministic
        dirac = T.cast(deterministic, 'float32')

        # pad the input
        pinput = T.pad(input, [(0, 0)] + self.pad_shape)

        routput = T.stack([
            T.dynamic_slice(pinput[n], self.start_indices[n], self.crop_shape)
            for n in range(self.input.shape[0])
        ], 0)
        doutput = T.stack([
            T.dynamic_slice(pinput[n], self.fixed_indices[n], self.crop_shape)
            for n in range(self.input.shape[0])
        ], 0)

        return doutput * dirac + (1 - dirac) * routput
Example no. 6
def PiecewiseConstant(init, steps_and_values):

    with Scope("PiecewiseConstant"):

        all_steps = T.stack([0] + list(steps_and_values.keys()))
        all_values = T.stack([init] + list(steps_and_values.values()))

        step = T.Variable(
            T.zeros(1),
            trainable=False,
            name="step",
            dtype="float32",
        )

        value = all_values[(step >= all_steps).argmin() - 1]  # last step already reached

        current_graph().add({step: step + 1})

    return value
Example no. 7
def leaky_swish(x, beta, negative_slope=1e-2):
    r"""Swish activation function associated to leaky relu.

    Computes the element-wise function:

    .. math::
      \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-\beta * x}}
    """

    feature = T.stack([negative_slope * x, x], -1)
    return (feature * softmax(feature * beta)).sum(-1)
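A quick NumPy check of the closed form and its limits (a sketch; this reimplements the formula with NumPy rather than calling the symjax version):

import numpy as np

def leaky_swish_np(x, beta, negative_slope=1e-2):
    # softmax over the pair (negative_slope * x, x), then the weighted sum
    feature = np.stack([negative_slope * x, x], -1)
    w = np.exp(beta * feature - (beta * feature).max(-1, keepdims=True))
    w /= w.sum(-1, keepdims=True)
    return (feature * w).sum(-1)

x = np.linspace(-3.0, 3.0, 7)
print(leaky_swish_np(x, beta=50.0))  # ~ leaky relu for large beta
print(leaky_swish_np(x, beta=0.0))   # ~ (1 + negative_slope) * x / 2 as beta -> 0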
Example no. 8
    def build_net(self, Q):
        # ------------------ all inputs ------------------------
        state = T.Placeholder([self.batch_size, self.n_states],
                              "float32",
                              name="s")
        next_state = T.Placeholder([self.batch_size, self.n_states],
                                   "float32",
                                   name="s_")
        reward = T.Placeholder(
            [
                self.batch_size,
            ],
            "float32",
            name="r",
        )  # input reward
        action = T.Placeholder(
            [
                self.batch_size,
            ],
            "int32",
            name="a",
        )  # input Action

        with symjax.Scope("eval_net"):
            q_eval = Q(state, self.n_actions)
        with symjax.Scope("test_set"):
            q_next = Q(next_state, self.n_actions)

        q_target = reward + self.reward_decay * q_next.max(1)
        q_target = T.stop_gradient(q_target)

        a_indices = T.stack([T.range(self.batch_size), action], axis=1)
        q_eval_wrt_a = T.take_along_axis(q_eval, action.reshape((-1, 1)),
                                         1).squeeze(1)
        loss = T.mean((q_target - q_eval_wrt_a)**2)
        nn.optimizers.Adam(loss, self.lr)

        self.train = symjax.function(state,
                                     action,
                                     reward,
                                     next_state,
                                     updates=symjax.get_updates())
        self.q_eval = symjax.function(state, outputs=q_eval)
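A possible usage sketch for the two compiled functions (hypothetical: `agent` stands for an instance of the class owning `build_net`, and the batch contents are random placeholders):

import numpy as np

# fake batch of transitions with the expected shapes and dtypes
s      = np.random.randn(agent.batch_size, agent.n_states).astype("float32")
s_next = np.random.randn(agent.batch_size, agent.n_states).astype("float32")
a      = np.random.randint(0, agent.n_actions, agent.batch_size).astype("int32")
r      = np.random.randn(agent.batch_size).astype("float32")

agent.train(s, a, r, s_next)  # one Adam step on the TD loss
q = agent.q_eval(s)           # Q-values, shape (batch_size, n_actions)
greedy = q.argmax(1)          # greedy action per state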
Example no. 9
def generate_gaussian_filterbank(N, M, J, f0, f1, modes=1):

    # gaussian parameters
    freqs = get_scaled_freqs(f0, f1, J)
    freqs *= (J - 1) * 10

    if modes > 1:
        other_modes = np.random.randint(0, J, J * (modes - 1))
        freqs = np.concatenate([freqs, freqs[other_modes]])

    # create the mean vectors
    mu_init = np.stack([freqs, 0.1 * np.random.randn(J * modes)], 1)
    mu = T.Variable(mu_init.astype('float32'), name='mu')

    # create the covariance matrix
    cor = T.Variable(0.01 * np.random.randn(J * modes).astype('float32'),
                     name='cor')

    sigma_init = np.stack([freqs / 6, 1. + 0.01 * np.random.randn(J * modes)],
                          1)
    sigma = T.Variable(sigma_init.astype('float32'), name='sigma')

    # create the mixing coefficients
    mixing = T.Variable(np.ones((modes, 1, 1)).astype('float32'))

    # now apply our parametrization
    coeff = T.stop_gradient(T.sqrt((T.abs(sigma) + 0.1).prod(1))) * 0.95
    Id = T.eye(2)
    cov = Id * T.expand_dims((T.abs(sigma)+0.1),1) +\
            T.flip(Id, 0) * (T.tanh(cor) * coeff).reshape((-1, 1, 1))
    cov_inv = T.linalg.inv(cov)

    # get the gaussian filters
    time = T.linspace(-5, 5, M)
    freq = T.linspace(0, J * 10, N)
    x, y = T.meshgrid(time, freq)
    grid = T.stack([y.flatten(), x.flatten()], 1)
    centered = grid - T.expand_dims(mu, 1)
    # evaluate the 2D Gaussians on the time-frequency grid
    gaussian = T.exp(-(T.matmul(centered, cov_inv)**2).sum(-1))
    norm = T.linalg.norm(gaussian, 2, 1, keepdims=True)
    gaussian_2d = T.abs(mixing) * T.reshape(gaussian / norm, (J, modes, N, M))
    return gaussian_2d.sum(1, keepdims=True), mu, cor, sigma, mixing
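A usage sketch (the parameter values are illustrative guesses, not taken from the source):

filters, mu, cor, sigma, mixing = generate_gaussian_filterbank(
    N=128, M=64, J=16, f0=0.0, f1=0.5, modes=2
)
# filters has shape (J, 1, N, M): one (frequency x time) Gaussian per scale,
# parametrized by the trainable variables mu, cor, sigma and mixing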
Example no. 10
def RNTK_middle(previous,x,sw,su,sb,L,Lf,sv):
    X = x * x[:, None]
    rntk_old = previous[0]
    gp_old = previous[1]
    S_old,D_old = VT(gp_old[0])
    gp_new = T.expand_dims(sw ** 2 * S_old + (su ** 2) * X + sb ** 2,axis = 0)
    if Lf == 0:
        rntk_new = T.expand_dims(gp_new[0] + sw**2*rntk_old[0]*D_old,axis = 0)
    else:
        rntk_new = T.expand_dims(gp_new[0],axis = 0)
    for l in range(L-1):
        l = l+1
        S_new,D_new = VT(gp_new[l-1])
        S_old,D_old = VT(gp_old[l])
        gp_new = T.concatenate( [gp_new, T.expand_dims( sw ** 2 * S_old + su ** 2 * S_new +  sb ** 2, axis = 0)])
        rntk_new = T.concatenate( [ rntk_new,  T.expand_dims( gp_new[l] +(Lf <= l)*sw**2*rntk_old[l]*D_old +(Lf <= (l-1))* su**2*rntk_new[l-1]*D_new  ,axis = 0)   ]  )
    S_old,D_old = VT(gp_new[L-1])
    gp_new = T.concatenate([gp_new,T.expand_dims(sv**2*S_old,axis = 0)])
    rntk_new = T.concatenate([rntk_new,T.expand_dims(rntk_old[L]+ gp_new[L] + (Lf != L)*sv**2*rntk_new[L-1]*D_old,axis = 0)])
    return T.stack([rntk_new,gp_new]),x
Example no. 11
 def create_network(self, state):
     state = T.stack([T.arccos(state[:, 0]), state[:, 2]], 1)
     input = nn.relu(nn.layers.Dense(state, 32))
     input = nn.relu(nn.layers.Dense(input, 32))
     input = nn.layers.Dense(input, 1)
     return input
Example no. 12
os.environ["DATASET_PATH"] = "/home/vrael/DATASETS/"

symjax.current_graph().reset()


mnist = symjax.data.mnist()
# 2d image
images = mnist["train_set/images"][mnist["train_set/labels"] == 2][:2, 0]
images /= images.max()

np.random.seed(0)

coordinates = T.meshgrid(T.range(28), T.range(28))
coordinates = T.Variable(
    T.stack([coordinates[1].flatten(), coordinates[0].flatten()]).astype("float32")
)
interp = T.interpolation.map_coordinates(images[0], coordinates, order=1).reshape(
    (28, 28)
)

loss = ((interp - images[1]) ** 2).mean()

lr = symjax.nn.schedules.PiecewiseConstant(0.05, {5000: 0.01, 8000: 0.005})
symjax.nn.optimizers.Adam(loss, lr)

train = symjax.function(outputs=loss, updates=symjax.get_updates())

rec = symjax.function(outputs=interp)

losses = list()
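The script stops right before the optimization loop; it presumably continues along these lines (a sketch reusing `train`, `rec` and `losses` from above; the iteration count is a guess):

for i in range(10000):
    losses.append(train())
    if i % 1000 == 0:
        print(i, losses[-1])

reconstruction = rec()  # the (28, 28) interpolated image after optimization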
Example no. 13
    def __init__(
        self,
        sequence,
        init_h,
        units,
        Wf=initializers.glorot_uniform,
        Uf=initializers.orthogonal,
        bf=T.zeros,
        Wi=initializers.glorot_uniform,
        Ui=initializers.orthogonal,
        bi=T.zeros,
        Wo=initializers.glorot_uniform,
        Uo=initializers.orthogonal,
        bo=T.zeros,
        Wc=initializers.glorot_uniform,
        Uc=initializers.orthogonal,
        bc=T.zeros,
        trainable_Wf=True,
        trainable_Uf=True,
        trainable_bf=True,
        trainable_Wi=True,
        trainable_Ui=True,
        trainable_bi=True,
        trainable_Wo=True,
        trainable_Uo=True,
        trainable_bo=True,
        trainable_Wc=True,
        trainable_Uc=True,
        trainable_bc=True,
        activation_g=nn.sigmoid,
        activation_c=T.tanh,
        activation_h=T.tanh,
        only_last=False,
        gate="minimal",
    ):

        self.create_variable("Wf",
                             Wf, (sequence.shape[2], units),
                             trainable=trainable_Wf)
        self.create_variable("Uf", Uf, (units, units), trainable=trainable_Uf)
        self.create_variable("bf", bf, (units, ), trainable=trainable_bf)

        self.create_variable("Wi",
                             Wi, (sequence.shape[2], units),
                             trainable=trainable_Wi)
        self.create_variable("Ui", Ui, (units, units), trainable=trainable_Ui)
        self.create_variable("bi", bi, (units, ), trainable=trainable_bi)

        self.create_variable("Wo",
                             Wo, (sequence.shape[2], units),
                             trainable=trainable_Wo)
        self.create_variable("Uo", Uo, (units, units), trainable=trainable_Uo)
        self.create_variable("bo", bo, (units, ), trainable=trainable_bo)

        self.create_variable("Wc",
                             Wc, (sequence.shape[2], units),
                             trainable=trainable_Wc)
        self.create_variable("Uc", Uc, (units, units), trainable=trainable_Uc)
        self.create_variable("bc", bc, (units, ), trainable=trainable_bc)

        def fn(*args):
            return self.gate(*args, activation_g, activation_c, activation_h)

        init = T.stack((init_h, T.zeros(init_h.shape, init_h.dtype)))
        last, output = T.scan(
            fn,
            init=init,
            sequences=[sequence.transpose((1, 0, 2))],
            non_sequences=[
                self.Wf,
                self.Uf,
                self.bf,
                self.Wi,
                self.Ui,
                self.bi,
                self.Wo,
                self.Uo,
                self.bo,
                self.Wc,
                self.Uc,
                self.bc,
            ],
        )

        if only_last:
            return last
        else:
            return output.transpose((1, 0, 2))
Example no. 14
import sys

sys.path.insert(0, "../")

import symjax as sj
import symjax.tensor as T
import matplotlib.pyplot as plt
import matplotlib

matplotlib.use('Agg')

###### 2D GAUSSIAN EXAMPLE

t = T.linspace(-5, 5, 5)
x, y = T.meshgrid(t, t)
X = T.stack([x.flatten(), y.flatten()], 1)
p = T.pdfs.multivariate_normal.pdf(X, T.zeros(2), T.eye(2))
p = p.reshape((5, 5)).round(2)

print(p)
# Tensor(Op=round_, shape=(5, 5), dtype=float32)

# lazy evaluation (neither compiled nor optimized)
print(p.get())
# [[0.   0.   0.   0.   0.  ]
#  [0.   0.   0.01 0.   0.  ]
#  [0.   0.01 0.16 0.01 0.  ]
#  [0.   0.   0.01 0.   0.  ]
#  [0.   0.   0.   0.   0.  ]]

# create the function which internally compiles and optimizes
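The snippet is cut off here; presumably it continues along these lines (a sketch, not part of the original):

f = sj.function(outputs=p)
print(f())  # the first call compiles the graph, later calls reuse it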
Example no. 15
def create_fns(batch_size, R, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1,
               var_x=1):

    alpha = T.Placeholder((1,), 'float32')
    x = T.Placeholder((Ds[0],), 'float32')
    X = T.Placeholder((batch_size, Ds[-1]), 'float32')

    signs = T.Placeholder((np.sum(Ds[1:-1]),), 'float32')
    SIGNS = T.Placeholder((R, np.sum(Ds[1:-1])), 'float32')
    
    m0 = T.Placeholder((batch_size, R), 'float32')
    m1 = T.Placeholder((batch_size, R, Ds[0]), 'float32')
    m2 = T.Placeholder((batch_size, R, Ds[0], Ds[0]), 'float32')

    Ws, vs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w, name='W' + str(l)) for l, w in enumerate(Ws)]
    vs = [T.Variable(v, name='v' + str(l)) for l, v in enumerate(vs)]

    var_x = T.Variable(T.ones(Ds[-1]) * var_x)
    var_z = T.Variable(T.ones(Ds[0]))

    # create the placeholders
    Ws_ph = [T.Placeholder(w.shape, w.dtype) for w in Ws]
    vs_ph = [T.Placeholder(v.shape, v.dtype) for v in vs]
    var_x_ph = T.Placeholder(var_x.shape, var_x.dtype)

    ############################################################################
    # Compute the output of g(x)
    ############################################################################

    maps = [x]
    xsigns = []
    masks = []
    
    for w, v in zip(Ws[:-1], vs[:-1]):
        
        pre_activation = T.matmul(w, maps[-1]) + v
        xsigns.append(T.sign(pre_activation))
        masks.append(relu_mask(pre_activation, leakiness))
        maps.append(pre_activation * masks[-1])

    xsigns = T.concatenate(xsigns)
    maps.append(T.matmul(Ws[-1], maps[-1]) + vs[-1])

    ############################################################################
    # compute the masks and then the per layer affine mappings
    ############################################################################

    cumulative_units = np.cumsum([0] + Ds[1:])
    xqs = relu_mask([xsigns[None, cumulative_units[i]:cumulative_units[i + 1]]
                    for i in range(len(Ds) - 2)], leakiness)
    qs = relu_mask([signs[None, cumulative_units[i]:cumulative_units[i + 1]]
                    for i in range(len(Ds) - 2)], leakiness)
    Qs = relu_mask([SIGNS[:, cumulative_units[i]:cumulative_units[i + 1]]
                    for i in range(len(Ds) - 2)], leakiness)

    Axs, bxs = get_Abs(Ws, vs, xqs)
    Aqs, bqs = get_Abs(Ws, vs, qs)
    AQs, bQs = get_Abs(Ws, vs, Qs)

    all_bxs = T.hstack(bxs[:-1]).transpose()
    all_Axs = T.hstack(Axs[:-1])[0]

    all_bqs = T.hstack(bqs[:-1]).transpose()
    all_Aqs = T.hstack(Aqs[:-1])[0]

    x_inequalities = T.hstack([all_Axs, all_bxs]) * xsigns[:, None]
    q_inequalities = T.hstack([all_Aqs, all_bqs]) * signs[:, None]

    ############################################################################
    # loss (E-step NLL)
    ############################################################################

    Bm0 = T.einsum('nd,Nn->Nd', bQs[-1], m0)
    B2m0 = T.einsum('nd,Nn->Nd', bQs[-1] ** 2, m0)
    Am1 = T.einsum('nds,Nns->Nd', AQs[-1], m1)
    ABm1 = T.einsum('nds,nd,Nns->Nd', AQs[-1], bQs[-1], m1)
    Am2ATdiag = T.diagonal(T.einsum('nds,Nnsc,npc->Ndp', AQs[-1], m2, AQs[-1]),
                        axis1=1, axis2=2)
    xAm1Bm0 = X * (Am1 + Bm0)

    M2diag = T.diagonal(m2.sum(1), axis1=1, axis2=2)
    
    prior = sum([T.mean(w**2) for w in Ws], 0.) / cov_W\
            + sum([T.mean(v**2) for v in vs[:-1]], 0.) / cov_b
    loss = - 0.5 * (T.log(var_x).sum() + T.log(var_z).sum()\
            + (M2diag / var_z).sum(1).mean() + ((X ** 2 - 2 * xAm1Bm0 + B2m0\
            + Am2ATdiag + 2 * ABm1) / var_x).sum(1).mean())

    mean_loss = - (loss + 0.5 * prior)
    adam = sj.optimizers.SGD(mean_loss, 0.001, params=Ws + vs)

    ############################################################################
    # update of var_x
    ############################################################################

    update_varx = (X ** 2 - 2 * xAm1Bm0 + B2m0 + Am2ATdiag + 2 * ABm1).mean()\
                    * T.ones(Ds[-1])
    update_varz = M2diag.mean() * T.ones(Ds[0])

    ############################################################################
    # update for biases IT IS DONE FOR ISOTROPIC COVARIANCE MATRIX
    ############################################################################

    FQ = get_forward(Ws, Qs)
    update_vs = {}
    for i in range(len(vs)):
        
        if i < len(vs) - 1:
            # now we forward each bias to the x-space except the ith
            separated_bs = bQs[-1] - T.einsum('nds,s->nd', FQ[i], vs[i])
            # compute the residual and apply sigma
            residual = (X[:, None, :] - separated_bs) * m0[:, :, None]\
                                         - T.einsum('nds,Nns->Nnd', AQs[-1], m1)
            back_error = T.einsum('nds,nd->s', FQ[i], residual.mean(0))
            whiten = T.einsum('ndc,nds,n->cs', FQ[i] , FQ[i], m0.mean(0))\
                        + T.eye(back_error.shape[0]) / (Ds[i] * cov_b)
            update_vs[vs[i]] = T.linalg.solve(whiten, back_error)
        else:
            back_error = (X - (Am1 + Bm0) + vs[-1])
            update_vs[vs[i]] = back_error.mean(0)

    ############################################################################
    # update for slopes IT IS DONE FOR ISOTROPIC COVARIANCE MATRIX 
    ############################################################################

    update_Ws = {}
    for i in range(len(Ws)):
        
        U = T.einsum('nds,ndc->nsc', FQ[i], FQ[i])
        if i == 0:
            V = m2.mean(0)
        else:
            V1 = T.einsum('nd,nq,Nn->ndq', bQs[i-1], bQs[i-1], m0)
            V2 = T.einsum('nds,nqc,Nnsc->ndq', AQs[i-1], AQs[i-1], m2)
            V3 = T.einsum('nds,nq,Nns->ndq', AQs[i-1], bQs[i-1], m1)
            Q = T.einsum('nd,nq->ndq', Qs[i - 1], Qs[i - 1])
            V = Q * (V1 + V2 + V3 + V3.transpose((0, 2, 1))) / batch_size

        whiten = T.stack([T.kron(U[n], V[n]) for n in range(V.shape[0])]).sum(0)
        whiten = whiten + T.eye(whiten.shape[-1]) / (Ds[i]*Ds[i+1]*cov_W)
        # compute the residual (bottom up)
        if i == len(Ws) - 1:
            bottom_up = (X[:, None, :] - vs[-1])
        else:
            if i == 0:
                residual = (X[:, None, :] - bQs[-1])
            else:
                residual = (X[:, None, :] - bQs[-1]\
                            + T.einsum('nds,ns->nd', FQ[i - 1], bQs[i-1]))
            bottom_up = T.einsum('ndc,Nnd->Nnc', FQ[i], residual)

        # compute the top down vector
        if i == 0:
            top_down = m1
        else:
            top_down = Qs[i - 1] * (T.einsum('nds,Nns->Nnd', AQs[i - 1], m1) +\
                               T.einsum('nd,Nn->Nnd', bQs[i - 1], m0))

        vector = T.einsum('Nnc,Nns->cs', bottom_up, top_down) / batch_size
        condition = T.diagonal(whiten)
        update_W = T.linalg.solve(whiten, vector.reshape(-1)).reshape(Ws[i].shape)
        update_Ws[Ws[i]] = update_W

    ############################################################################
    # create the io functions
    ############################################################################

    params = sj.function(outputs = Ws + vs + [var_x])
    ll = T.Placeholder((), 'int32')
    selector = T.one_hot(ll, len(vs))
    for i in range(len(vs)):
        update_vs[vs[i]] = ((1 - alpha) * vs[i] + alpha * update_vs[vs[i]])\
                            * selector[i] + vs[i] * (1 - selector[i])
    for i in range(len(Ws)):
        update_Ws[Ws[i]] = ((1 - alpha) * Ws[i] + alpha * update_Ws[Ws[i]])\
                            * selector[i] + Ws[i] * (1 - selector[i])

    output = {'train':sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                  updates=adam.updates),
              'update_var':sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                        updates = {var_x: update_varx}),
              'update_vs':sj.function(alpha, ll, SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                      updates = update_vs),
              'loss':sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss),
              'update_Ws':sj.function(alpha, ll, SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                      updates = update_Ws),
              'signs2Ab': sj.function(signs, outputs=[Aqs[-1][0], bqs[-1][0]]),
              'signs2ineq': sj.function(signs, outputs=q_inequalities),
              'g': sj.function(x, outputs=maps[-1]),
              'input2all': sj.function(x, outputs=[maps[-1], Axs[-1][0],
                                       bxs[-1][0], x_inequalities, xsigns]),
              'get_nll': sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss),
              'assign': sj.function(*Ws_ph, *vs_ph, var_x_ph,
                                    updates=dict(zip(Ws + vs + [var_x],
                                             Ws_ph + vs_ph + [var_x_ph]))),
              'varx': sj.function(outputs=var_x),
              'prior': sj.function(outputs=prior),
              'varz': sj.function(outputs=var_z),
              'params': params,
#              'probed' : sj.function(SIGNS, X, m0, m1, m2, outputs=probed),
              'input2signs': sj.function(x, outputs=xsigns),
              'S' : Ds[0], 'D':  Ds[-1], 'R': R, 'model': 'EM', 'L':len(Ds)-1,
              'kwargs': {'batch_size': batch_size, 'Ds':Ds, 'seed':seed,
                    'leakiness':leakiness, 'lr':lr, 'scaler':scaler}}
 
    def sample(n):
        samples = []
        for i in range(n):
            samples.append(output['g'](np.random.randn(Ds[0])))
        return np.array(samples)
    
    output['sample'] = sample

    return output
Example no. 16
def RNTK_function(N,length,param):
    DATA = T.Placeholder((N, length), 'float32')
    RNTK,GP = RNTK_first(DATA[:,0], param['sigmaw'],param['sigmau'],param['sigmab'],param['sigmah'],param['L'], param['Lf'],param['sigmav'])
    v, _ = T.scan(lambda a,b:RNTK_middle(a,b,param['sigmaw'],param['sigmau'],param['sigmab'],param['L'], param['Lf'],param['sigmav']
                                         ),sequences=[ T.transpose(DATA[:, 1:]) ], init=T.stack([RNTK,GP]))
    RNTK_last,RNTK_avg = RNTK_output(v, param['sigmav'],param['L'],param['Lf'],length)
    f = symjax.function(DATA, outputs= [RNTK_last,RNTK_avg])
    return f
Example no. 17
def PiecewiseConstant(init, steps_and_values):
    """piecewise constant variable updating automatically

    This method allows to obtain a variable with an internal counter
    that will be updated based on the function updates, whenver this
    counter reaches one of the step given in the function input
    then the actual value of the variable becomes the one given for the
    associated step

    Args
    ----

    init: float-like
        the initial value of the variable that will remain as is until a step
        and update is reached

    steps_and_values: dict
        the dictionnary mapping steps-> values, that is, when the number of
        steps reached one of the given one, the value of the variable becomes
        the given one associated to the reached step

    Returns
    -------

    variable: float-like

    Example
    -------

    .. doctest::
    >>> import symjax
    >>> symjax.current_graph().reset()
    >>> var = symjax.nn.schedules.PiecewiseConstant(0.1, {4:1, 8:2})
    >>> # in the background, symjax automatically records that every time
    >>> # a function uses this variable an underlying update should occur
    >>> print(symjax.get_updates())
    {Variable(name=step, shape=(), dtype=int32, trainable=False, scope=/PiecewiseConstant/): Op(name=add, fn=add, shape=(), dtype=int32, scope=/PiecewiseConstant/)}
    >>> # it is up to the user to apply it or not; if not used, the internal counter
    >>> # is never updated and thus the variable never changes.
    >>> # example of use:
    >>> f = symjax.function(outputs=var, updates=symjax.get_updates())
    >>> for i in range(10):
    ...     print(i, f())
    0 0.1
    1 0.1
    2 0.1
    3 0.1
    4 1.0
    5 1.0
    6 1.0
    7 1.0
    8 2.0
    9 2.0

    """

    with Scope("PiecewiseConstant"):

        all_steps = T.stack([0] + list(steps_and_values.keys()) + [np.inf])
        all_values = T.stack([init] + list(steps_and_values.values()) + [0])

        step = T.Variable(
            0,
            trainable=False,
            name="step",
            dtype="int32",
        )

        value = all_values[(step >= all_steps).argmin() - 1]

    return value, step
Example no. 18
 def RNTK_function(self):
     print(f"N, {self.N}, length, {self.length}")
     DATA = T.Placeholder((self.N, self.length), 'float32')
     RNTK,GP = self.RNTK_first(DATA[:,0])
     v, _ = T.scan(lambda a,b:self.RNTK_middle(a,b),sequences=[ T.transpose(DATA[:, 1:]) ], init=T.stack([RNTK,GP]))
     RNTK_last,RNTK_avg = self.RNTK_output(v)
     f = symjax.function(DATA, outputs= [RNTK_last,RNTK_avg])
     return RNTK_last,RNTK_avg
Example no. 19
losses = list()
values = list()
for i in range(10):
    a, b = train()
    losses.append(a)
    values.append(b)

plt.figure()
plt.subplot(121)
plt.plot(losses)
plt.subplot(122)
plt.plot(values, np.zeros_like(values), "kx")

####### jacobians

x, y = T.ones(()), T.ones(())
print(x, y)
ZZ = T.stack([x, y])
f = T.stack([3 * ZZ[0] + 2 * ZZ[1]], axis=0)
j = symjax.jacobians(f, [ZZ])[0]
g_j = symjax.function(outputs=j)

R = T.random.randn()
f = ZZ * 10 * R
j = symjax.jacobians(f, [ZZ])[0]
g_j = symjax.function(outputs=[j])
for i in range(5):
    print(g_j())

# plt.show()