def forward(self, input, crop_shape, deterministic, padding=0, seed=None):

    self.crop_shape = crop_shape

    # if given only a scalar
    if not hasattr(padding, "__len__"):
        self.pad_shape = [(padding, padding)] * (len(input.shape) - 1)
    # else
    else:
        self.pad_shape = [
            (pad, pad) if not hasattr(pad, "__len__") else pad for pad in padding
        ]

    assert len(self.pad_shape) == len(self.crop_shape)
    assert len(self.pad_shape) == (len(input.shape) - 1)

    self.start_indices = list()
    self.fixed_indices = list()
    for i, (pad, dim, crop) in enumerate(
            zip(self.pad_shape, input.shape[1:], self.crop_shape)):
        maxval = pad[0] + pad[1] + dim - crop
        assert maxval >= 0
        self.start_indices.append(
            T.random.randint(
                minval=0,
                maxval=maxval,
                shape=(input.shape[0], 1),
                dtype="int32",
                seed=seed + i if seed is not None else seed,
            ))
        self.fixed_indices.append(
            T.ones((input.shape[0], 1), "int32") * (maxval // 2))

    self.start_indices = T.concatenate(self.start_indices, 1)
    self.fixed_indices = T.concatenate(self.fixed_indices, 1)

    dirac = T.cast(deterministic, "float32")

    # pad the input
    pinput = T.pad(input, [(0, 0)] + self.pad_shape)

    routput = T.stack(
        [
            T.dynamic_slice(pinput[n], self.start_indices[n], self.crop_shape)
            for n in range(input.shape[0])
        ],
        0,
    )
    doutput = T.stack(
        [
            T.dynamic_slice(pinput[n], self.fixed_indices[n], self.crop_shape)
            for n in range(input.shape[0])
        ],
        0,
    )

    return doutput * dirac + (1 - dirac) * routput
def gate(
    carry,
    x,
    Wf,
    Uf,
    bf,
    Wi,
    Ui,
    bi,
    Wo,
    Uo,
    bo,
    Wc,
    Uc,
    bc,
    sigma_g,
    sigma_c,
    sigma_h,
):
    h, c = carry[0], carry[1]
    f = sigma_g(T.dot(x, Wf) + bf + T.dot(h, Uf))
    i = sigma_g(T.dot(x, Wi) + bi + T.dot(h, Ui))
    o = sigma_g(T.dot(x, Wo) + bo + T.dot(h, Uo))
    ctilde = sigma_c(T.dot(x, Wc) + bc + T.dot(h, Uc))
    cnew = f * c + i * ctilde
    hnew = o * sigma_h(cnew)
    return T.stack([hnew, cnew]), h
def test_stack():
    u = tt.Variable(tt.ones((2,)))
    output = tt.stack([u, 2 * u, 3 * u])
    f = symjax.function(outputs=output)
    assert np.allclose(f(), (np.arange(3)[:, None] + 1) * np.ones((3, 2)))
    print(f())
    print(f())
def RNTK_middle(self, previous, x):
    # line 7, alg 1
    X = x * x[:, None]  # <x, x^t>
    rntk_old = previous[0]
    gp_old = previous[1]
    S_old, D_old = self.VT(gp_old[0])  # K(1, t-1)
    gp_new = T.expand_dims(
        self.sw ** 2 * S_old + (self.su ** 2) * X + self.sb ** 2,
        axis=0)  # line 8, alg 1
    if self.Lf == 0:
        # if none of the layers are fixed, use the standard recursion
        rntk_new = T.expand_dims(
            gp_new[0] + self.sw ** 2 * rntk_old[0] * D_old, axis=0)
    else:
        rntk_new = T.expand_dims(gp_new[0], axis=0)
    print("gp_new 3", gp_new)
    print("rntk_new 3", rntk_new)
    for l in range(self.L - 1):  # line 10
        l = l + 1
        S_new, D_new = self.VT(gp_new[l - 1])  # l-1, t
        S_old, D_old = self.VT(gp_old[l])  # t-1, l
        gp_new = T.concatenate(
            [gp_new,
             T.expand_dims(
                 self.sw ** 2 * S_old + self.su ** 2 * S_new + self.sb ** 2,
                 axis=0)])  # line 10
        rntk_new = T.concatenate(
            [rntk_new,
             T.expand_dims(
                 gp_new[l]
                 + (self.Lf <= l) * self.sw ** 2 * rntk_old[l] * D_old
                 + (self.Lf <= (l - 1)) * self.su ** 2 * rntk_new[l - 1] * D_new,
                 axis=0)])
    S_old, D_old = self.VT(gp_new[self.L - 1])
    gp_new = T.concatenate(
        [gp_new, T.expand_dims(self.sv ** 2 * S_old, axis=0)])  # line 11
    rntk_new = T.concatenate(
        [rntk_new,
         T.expand_dims(
             rntk_old[self.L] + gp_new[self.L]
             + (self.Lf != self.L) * self.sv ** 2 * rntk_new[self.L - 1] * D_old,
             axis=0)])
    print("gp_new 4", gp_new)
    print("rntk_new 4", rntk_new)
    return T.stack([rntk_new, gp_new]), x
def forward(self, input, deterministic=None):

    if deterministic is None:
        deterministic = self.deterministic

    dirac = T.cast(deterministic, 'float32')

    # pad the input
    pinput = T.pad(input, [(0, 0)] + self.pad_shape)

    routput = T.stack([
        T.dynamic_slice(pinput[n], self.start_indices[n], self.crop_shape)
        for n in range(self.input.shape[0])
    ], 0)
    doutput = T.stack([
        T.dynamic_slice(pinput[n], self.fixed_indices[n], self.crop_shape)
        for n in range(self.input.shape[0])
    ], 0)

    return doutput * dirac + (1 - dirac) * routput
def PiecewiseConstant(init, steps_and_values):
    with Scope("PiecewiseConstant"):
        all_steps = T.stack([0] + list(steps_and_values.keys()))
        all_values = T.stack([init] + list(steps_and_values.values()))
        step = T.Variable(
            T.zeros(1),
            trainable=False,
            name="step",
            dtype="float32",
        )
        value = all_values[(step < all_steps).argmin() - 1]
        current_graph().add({step: step + 1})
    return value
def leaky_swish(x, beta, negative_slope=1e-2):
    r"""Swish activation function associated to the leaky relu.

    Computes the element-wise function:

    .. math::
        \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(\beta x)
        = \frac{x}{1 + e^{-\beta x}}

    """
    feature = T.stack([negative_slope * x, x], -1)
    return (feature * softmax(feature * beta)).sum(-1)
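# A minimal numpy sketch of the same formula (hypothetical helper, not part of
# the library): the softmax over the two stacked slopes concentrates on the
# larger of negative_slope * x and x, so for large beta the function behaves
# like a smooth leaky relu.
import numpy as np


def leaky_swish_np(x, beta, negative_slope=1e-2):
    # stack the two candidate branches along a trailing axis
    feature = np.stack([negative_slope * x, x], -1)
    # numerically stable softmax over that axis weighs the two branches
    e = np.exp(beta * feature - (beta * feature).max(-1, keepdims=True))
    weights = e / e.sum(-1, keepdims=True)
    return (feature * weights).sum(-1)


# leaky_swish_np(np.array([-10.0, 0.0, 10.0]), beta=5.0)
# ~ [-0.1, 0.0, 10.0]: close to a leaky relu with slope 1e-2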
def build_net(self, Q):
    # ------------------ all inputs ------------------------
    state = T.Placeholder([self.batch_size, self.n_states], "float32", name="s")
    next_state = T.Placeholder(
        [self.batch_size, self.n_states], "float32", name="s_")
    reward = T.Placeholder(
        [
            self.batch_size,
        ],
        "float32",
        name="r",
    )  # input reward
    action = T.Placeholder(
        [
            self.batch_size,
        ],
        "int32",
        name="a",
    )  # input Action

    with symjax.Scope("eval_net"):
        q_eval = Q(state, self.n_actions)
    with symjax.Scope("test_set"):
        q_next = Q(next_state, self.n_actions)

    q_target = reward + self.reward_decay * q_next.max(1)
    q_target = T.stop_gradient(q_target)

    a_indices = T.stack([T.range(self.batch_size), action], axis=1)
    q_eval_wrt_a = T.take_along_axis(q_eval, action.reshape((-1, 1)), 1).squeeze(1)
    loss = T.mean((q_target - q_eval_wrt_a) ** 2)
    nn.optimizers.Adam(loss, self.lr)

    self.train = symjax.function(
        state, action, reward, next_state, updates=symjax.get_updates())
    self.q_eval = symjax.function(state, outputs=q_eval)
def generate_gaussian_filterbank(N, M, J, f0, f1, modes=1):

    # gaussian parameters
    freqs = get_scaled_freqs(f0, f1, J)
    freqs *= (J - 1) * 10

    if modes > 1:
        other_modes = np.random.randint(0, J, J * (modes - 1))
        freqs = np.concatenate([freqs, freqs[other_modes]])

    # create the average vectors
    mu_init = np.stack([freqs, 0.1 * np.random.randn(J * modes)], 1)
    mu = T.Variable(mu_init.astype('float32'), name='mu')

    # create the covariance matrix
    cor = T.Variable(0.01 * np.random.randn(J * modes).astype('float32'),
                     name='cor')
    sigma_init = np.stack([freqs / 6, 1. + 0.01 * np.random.randn(J * modes)], 1)
    sigma = T.Variable(sigma_init.astype('float32'), name='sigma')

    # create the mixing coefficients
    mixing = T.Variable(np.ones((modes, 1, 1)).astype('float32'))

    # now apply our parametrization
    coeff = T.stop_gradient(T.sqrt((T.abs(sigma) + 0.1).prod(1))) * 0.95
    Id = T.eye(2)
    cov = Id * T.expand_dims((T.abs(sigma) + 0.1), 1) + \
        T.flip(Id, 0) * (T.tanh(cor) * coeff).reshape((-1, 1, 1))
    cov_inv = T.linalg.inv(cov)

    # get the gaussian filters
    time = T.linspace(-5, 5, M)
    freq = T.linspace(0, J * 10, N)
    x, y = T.meshgrid(time, freq)
    grid = T.stack([y.flatten(), x.flatten()], 1)
    centered = grid - T.expand_dims(mu, 1)
    gaussian = T.exp(-(T.matmul(centered, cov_inv) ** 2).sum(-1))
    norm = T.linalg.norm(gaussian, 2, 1, keepdims=True)
    gaussian_2d = T.abs(mixing) * T.reshape(gaussian / norm, (J, modes, N, M))
    return gaussian_2d.sum(1, keepdims=True), mu, cor, sigma, mixing
def RNTK_middle(previous, x, sw, su, sb, L, Lf, sv):
    X = x * x[:, None]
    rntk_old = previous[0]
    gp_old = previous[1]
    S_old, D_old = VT(gp_old[0])
    gp_new = T.expand_dims(sw ** 2 * S_old + (su ** 2) * X + sb ** 2, axis=0)
    if Lf == 0:
        rntk_new = T.expand_dims(gp_new[0] + sw ** 2 * rntk_old[0] * D_old,
                                 axis=0)
    else:
        rntk_new = T.expand_dims(gp_new[0], axis=0)
    for l in range(L - 1):
        l = l + 1
        S_new, D_new = VT(gp_new[l - 1])
        S_old, D_old = VT(gp_old[l])
        gp_new = T.concatenate(
            [gp_new,
             T.expand_dims(sw ** 2 * S_old + su ** 2 * S_new + sb ** 2, axis=0)])
        rntk_new = T.concatenate(
            [rntk_new,
             T.expand_dims(
                 gp_new[l]
                 + (Lf <= l) * sw ** 2 * rntk_old[l] * D_old
                 + (Lf <= (l - 1)) * su ** 2 * rntk_new[l - 1] * D_new,
                 axis=0)])
    S_old, D_old = VT(gp_new[L - 1])
    gp_new = T.concatenate([gp_new, T.expand_dims(sv ** 2 * S_old, axis=0)])
    rntk_new = T.concatenate(
        [rntk_new,
         T.expand_dims(
             rntk_old[L] + gp_new[L]
             + (Lf != L) * sv ** 2 * rntk_new[L - 1] * D_old,
             axis=0)])
    return T.stack([rntk_new, gp_new]), x
def create_network(self, state):
    state = T.stack([T.arccos(state[:, 0]), state[:, 2]], 1)
    input = nn.relu(nn.layers.Dense(state, 32))
    input = nn.relu(nn.layers.Dense(input, 32))
    input = nn.layers.Dense(input, 1)
    return input
os.environ["DATASET_PATH"] = "/home/vrael/DATASETS/" symjax.current_graph().reset() mnist = symjax.data.mnist() # 2d image images = mnist["train_set/images"][mnist["train_set/labels"] == 2][:2, 0] images /= images.max() np.random.seed(0) coordinates = T.meshgrid(T.range(28), T.range(28)) coordinates = T.Variable( T.stack([coordinates[1].flatten(), coordinates[0].flatten()]).astype("float32") ) interp = T.interpolation.map_coordinates(images[0], coordinates, order=1).reshape( (28, 28) ) loss = ((interp - images[1]) ** 2).mean() lr = symjax.nn.schedules.PiecewiseConstant(0.05, {5000: 0.01, 8000: 0.005}) symjax.nn.optimizers.Adam(loss, lr) train = symjax.function(outputs=loss, updates=symjax.get_updates()) rec = symjax.function(outputs=interp) losses = list()
def __init__(
    self,
    sequence,
    init_h,
    units,
    Wf=initializers.glorot_uniform,
    Uf=initializers.orthogonal,
    bf=T.zeros,
    Wi=initializers.glorot_uniform,
    Ui=initializers.orthogonal,
    bi=T.zeros,
    Wo=initializers.glorot_uniform,
    Uo=initializers.orthogonal,
    bo=T.zeros,
    Wc=initializers.glorot_uniform,
    Uc=initializers.orthogonal,
    bc=T.zeros,
    trainable_Wf=True,
    trainable_Uf=True,
    trainable_bf=True,
    trainable_Wi=True,
    trainable_Ui=True,
    trainable_bi=True,
    trainable_Wo=True,
    trainable_Uo=True,
    trainable_bo=True,
    trainable_Wc=True,
    trainable_Uc=True,
    trainable_bc=True,
    activation_g=nn.sigmoid,
    activation_c=T.tanh,
    activation_h=T.tanh,
    only_last=False,
    gate="minimal",
):
    self.create_variable(
        "Wf", Wf, (sequence.shape[2], units), trainable=trainable_Wf)
    self.create_variable("Uf", Uf, (units, units), trainable=trainable_Uf)
    self.create_variable("bf", bf, (units,), trainable=trainable_bf)

    self.create_variable(
        "Wi", Wi, (sequence.shape[2], units), trainable=trainable_Wi)
    self.create_variable("Ui", Ui, (units, units), trainable=trainable_Ui)
    self.create_variable("bi", bi, (units,), trainable=trainable_bi)

    self.create_variable(
        "Wo", Wo, (sequence.shape[2], units), trainable=trainable_Wo)
    self.create_variable("Uo", Uo, (units, units), trainable=trainable_Uo)
    self.create_variable("bo", bo, (units,), trainable=trainable_bo)

    self.create_variable(
        "Wc", Wc, (sequence.shape[2], units), trainable=trainable_Wc)
    self.create_variable("Uc", Uc, (units, units), trainable=trainable_Uc)
    self.create_variable("bc", bc, (units,), trainable=trainable_bc)

    def fn(*args):
        return self.gate(*args, activation_g, activation_c, activation_h)

    init = T.stack((init_h, T.zeros(init_h.shape, init_h.dtype)))
    last, output = T.scan(
        fn,
        init=init,
        sequences=[sequence.transpose((1, 0, 2))],
        non_sequences=[
            self.Wf,
            self.Uf,
            self.bf,
            self.Wi,
            self.Ui,
            self.bi,
            self.Wo,
            self.Uo,
            self.bo,
            self.Wc,
            self.Uc,
            self.bc,
        ],
    )
    if only_last:
        return last
    else:
        return output.transpose((1, 0, 2))
import sys

sys.path.insert(0, "../")

import symjax as sj
import symjax.tensor as T
import matplotlib.pyplot as plt
import matplotlib

matplotlib.use('Agg')

###### 2D GAUSSIAN EXAMPLE

t = T.linspace(-5, 5, 5)
x, y = T.meshgrid(t, t)
X = T.stack([x.flatten(), y.flatten()], 1)
p = T.pdfs.multivariate_normal.pdf(X, T.zeros(2), T.eye(2))
p = p.reshape((5, 5)).round(2)

print(p)
# Tensor(Op=round_, shape=(5, 5), dtype=float32)

# lazy evaluation (not compiled nor optimized)
print(p.get())
# [[0.   0.   0.   0.   0.  ]
#  [0.   0.   0.01 0.   0.  ]
#  [0.   0.01 0.16 0.01 0.  ]
#  [0.   0.   0.01 0.   0.  ]
#  [0.   0.   0.   0.   0.  ]]

# create the function which internally compiles and optimizes
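# A possible continuation of the comment above (sketch, assuming the same `p`
# as defined in the snippet): compile the graph into a callable and evaluate it.
f = sj.function(outputs=p)
print(f())
# the compiled output should match the lazily evaluated p.get() shown above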
def create_fns(batch_size, R, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1,
               var_x=1):

    alpha = T.Placeholder((1,), 'float32')
    x = T.Placeholder((Ds[0],), 'float32')
    X = T.Placeholder((batch_size, Ds[-1]), 'float32')

    signs = T.Placeholder((np.sum(Ds[1:-1]),), 'float32')
    SIGNS = T.Placeholder((R, np.sum(Ds[1:-1])), 'float32')

    m0 = T.Placeholder((batch_size, R), 'float32')
    m1 = T.Placeholder((batch_size, R, Ds[0]), 'float32')
    m2 = T.Placeholder((batch_size, R, Ds[0], Ds[0]), 'float32')

    Ws, vs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w, name='W' + str(l)) for l, w in enumerate(Ws)]
    vs = [T.Variable(v, name='v' + str(l)) for l, v in enumerate(vs)]

    var_x = T.Variable(T.ones(Ds[-1]) * var_x)
    var_z = T.Variable(T.ones(Ds[0]))

    # create the placeholders
    Ws_ph = [T.Placeholder(w.shape, w.dtype) for w in Ws]
    vs_ph = [T.Placeholder(v.shape, v.dtype) for v in vs]
    var_x_ph = T.Placeholder(var_x.shape, var_x.dtype)

    ############################################################################
    # Compute the output of g(x)
    ############################################################################

    maps = [x]
    xsigns = []
    masks = []

    for w, v in zip(Ws[:-1], vs[:-1]):
        pre_activation = T.matmul(w, maps[-1]) + v
        xsigns.append(T.sign(pre_activation))
        masks.append(relu_mask(pre_activation, leakiness))
        maps.append(pre_activation * masks[-1])

    xsigns = T.concatenate(xsigns)
    maps.append(T.matmul(Ws[-1], maps[-1]) + vs[-1])

    ############################################################################
    # compute the masks and then the per layer affine mappings
    ############################################################################

    cumulative_units = np.cumsum([0] + Ds[1:])
    xqs = relu_mask([xsigns[None, cumulative_units[i]:cumulative_units[i + 1]]
                     for i in range(len(Ds) - 2)], leakiness)
    qs = relu_mask([signs[None, cumulative_units[i]:cumulative_units[i + 1]]
                    for i in range(len(Ds) - 2)], leakiness)
    Qs = relu_mask([SIGNS[:, cumulative_units[i]:cumulative_units[i + 1]]
                    for i in range(len(Ds) - 2)], leakiness)

    Axs, bxs = get_Abs(Ws, vs, xqs)
    Aqs, bqs = get_Abs(Ws, vs, qs)
    AQs, bQs = get_Abs(Ws, vs, Qs)

    all_bxs = T.hstack(bxs[:-1]).transpose()
    all_Axs = T.hstack(Axs[:-1])[0]
    all_bqs = T.hstack(bqs[:-1]).transpose()
    all_Aqs = T.hstack(Aqs[:-1])[0]

    x_inequalities = T.hstack([all_Axs, all_bxs]) * xsigns[:, None]
    q_inequalities = T.hstack([all_Aqs, all_bqs]) * signs[:, None]

    ############################################################################
    # loss (E-step NLL)
    ############################################################################

    Bm0 = T.einsum('nd,Nn->Nd', bQs[-1], m0)
    B2m0 = T.einsum('nd,Nn->Nd', bQs[-1] ** 2, m0)
    Am1 = T.einsum('nds,Nns->Nd', AQs[-1], m1)
    ABm1 = T.einsum('nds,nd,Nns->Nd', AQs[-1], bQs[-1], m1)
    Am2ATdiag = T.diagonal(T.einsum('nds,Nnsc,npc->Ndp', AQs[-1], m2, AQs[-1]),
                           axis1=1, axis2=2)
    xAm1Bm0 = X * (Am1 + Bm0)

    M2diag = T.diagonal(m2.sum(1), axis1=1, axis2=2)

    prior = sum([T.mean(w ** 2) for w in Ws], 0.) / cov_W \
        + sum([T.mean(v ** 2) for v in vs[:-1]], 0.) \
        / cov_b

    loss = - 0.5 * (T.log(var_x).sum() + T.log(var_z).sum()
                    + (M2diag / var_z).sum(1).mean()
                    + ((X ** 2 - 2 * xAm1Bm0 + B2m0
                        + Am2ATdiag + 2 * ABm1) / var_x).sum(1).mean())
    mean_loss = - (loss + 0.5 * prior)
    adam = sj.optimizers.SGD(mean_loss, 0.001, params=Ws + vs)

    ############################################################################
    # update of var_x
    ############################################################################

    update_varx = (X ** 2 - 2 * xAm1Bm0 + B2m0 + Am2ATdiag + 2 * ABm1).mean() \
        * T.ones(Ds[-1])
    update_varz = M2diag.mean() * T.ones(Ds[0])

    ############################################################################
    # update for biases IT IS DONE FOR ISOTROPIC COVARIANCE MATRIX
    ############################################################################

    FQ = get_forward(Ws, Qs)
    update_vs = {}
    for i in range(len(vs)):

        if i < len(vs) - 1:
            # now we forward each bias to the x-space except the ith
            separated_bs = bQs[-1] - T.einsum('nds,s->nd', FQ[i], vs[i])
            # compute the residual and apply sigma
            residual = (X[:, None, :] - separated_bs) * m0[:, :, None] \
                - T.einsum('nds,Nns->Nnd', AQs[-1], m1)
            back_error = T.einsum('nds,nd->s', FQ[i], residual.mean(0))
            whiten = T.einsum('ndc,nds,n->cs', FQ[i], FQ[i], m0.mean(0)) \
                + T.eye(back_error.shape[0]) / (Ds[i] * cov_b)
            update_vs[vs[i]] = T.linalg.solve(whiten, back_error)
        else:
            back_error = (X - (Am1 + Bm0) + vs[-1])
            update_vs[vs[i]] = back_error.mean(0)

    ############################################################################
    # update for slopes IT IS DONE FOR ISOTROPIC COVARIANCE MATRIX
    ############################################################################

    update_Ws = {}
    for i in range(len(Ws)):

        U = T.einsum('nds,ndc->nsc', FQ[i], FQ[i])
        if i == 0:
            V = m2.mean(0)
        else:
            V1 = T.einsum('nd,nq,Nn->ndq', bQs[i - 1], bQs[i - 1], m0)
            V2 = T.einsum('nds,nqc,Nnsc->ndq', AQs[i - 1], AQs[i - 1], m2)
            V3 = T.einsum('nds,nq,Nns->ndq', AQs[i - 1], bQs[i - 1], m1)
            Q = T.einsum('nd,nq->ndq', Qs[i - 1], Qs[i - 1])
            V = Q * (V1 + V2 + V3 + V3.transpose((0, 2, 1))) / batch_size

        whiten = T.stack([T.kron(U[n], V[n]) for n in range(V.shape[0])]).sum(0)
        whiten = whiten + T.eye(whiten.shape[-1]) / (Ds[i] * Ds[i + 1] * cov_W)

        # compute the residual (bottom up)
        if i == len(Ws) - 1:
            bottom_up = (X[:, None, :] - vs[-1])
        else:
            if i == 0:
                residual = (X[:, None, :] - bQs[-1])
            else:
                residual = (X[:, None, :] - bQs[-1]
                            + T.einsum('nds,ns->nd', FQ[i - 1], bQs[i - 1]))
            bottom_up = T.einsum('ndc,Nnd->Nnc', FQ[i], residual)

        # compute the top down vector
        if i == 0:
            top_down = m1
        else:
            top_down = Qs[i - 1] * (T.einsum('nds,Nns->Nnd', AQs[i - 1], m1)
                                    + T.einsum('nd,Nn->Nnd', bQs[i - 1], m0))

        vector = T.einsum('Nnc,Nns->cs', bottom_up, top_down) / batch_size
        condition = T.diagonal(whiten)
        update_W = T.linalg.solve(whiten, vector.reshape(-1)).reshape(Ws[i].shape)
        update_Ws[Ws[i]] = update_W

    ############################################################################
    # create the io functions
    ############################################################################

    params = sj.function(outputs=Ws + vs + [var_x])
    ll = T.Placeholder((), 'int32')
    selector = T.one_hot(ll, len(vs))
    for i in range(len(vs)):
        update_vs[vs[i]] = ((1 - alpha) * vs[i] + alpha * update_vs[vs[i]]) \
            * selector[i] + vs[i] * (1 - selector[i])
    for i in range(len(Ws)):
        update_Ws[Ws[i]] = ((1 - alpha) * Ws[i] + alpha * update_Ws[Ws[i]]) \
            * selector[i] + Ws[i] * (1 - selector[i])

    output = {'train': sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                   updates=adam.updates),
              'update_var': sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                        updates={var_x: update_varx}),
              'update_vs': sj.function(alpha, ll, SIGNS, X, m0, m1, m2,
                                       outputs=mean_loss, updates=update_vs),
              'loss': sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss),
              'update_Ws': sj.function(alpha, ll, SIGNS, X, m0, m1, m2,
                                       outputs=mean_loss, updates=update_Ws),
              'signs2Ab': sj.function(signs, outputs=[Aqs[-1][0], bqs[-1][0]]),
              'signs2ineq': sj.function(signs, outputs=q_inequalities),
              'g': sj.function(x, outputs=maps[-1]),
              'input2all': sj.function(x, outputs=[maps[-1], Axs[-1][0],
                                                   bxs[-1][0], x_inequalities,
                                                   xsigns]),
              'get_nll': sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss),
              'assign': sj.function(*Ws_ph, *vs_ph, var_x_ph,
                                    updates=dict(zip(Ws + vs + [var_x],
                                                     Ws_ph + vs_ph + [var_x_ph]))),
              'varx': sj.function(outputs=var_x),
              'prior': sj.function(outputs=prior),
              'varz': sj.function(outputs=var_z),
              'params': params,
              # 'probed': sj.function(SIGNS, X, m0, m1, m2, outputs=probed),
              'input2signs': sj.function(x, outputs=xsigns),
              'S': Ds[0],
              'D': Ds[-1],
              'R': R,
              'model': 'EM',
              'L': len(Ds) - 1,
              'kwargs': {'batch_size': batch_size, 'Ds': Ds, 'seed': seed,
                         'leakiness': leakiness, 'lr': lr, 'scaler': scaler}}

    def sample(n):
        samples = []
        for i in range(n):
            samples.append(output['g'](np.random.randn(Ds[0])))
        return np.array(samples)

    output['sample'] = sample

    return output
def RNTK_function(N, length, param):
    DATA = T.Placeholder((N, length), 'float32')
    RNTK, GP = RNTK_first(DATA[:, 0], param['sigmaw'], param['sigmau'],
                          param['sigmab'], param['sigmah'], param['L'],
                          param['Lf'], param['sigmav'])
    v, _ = T.scan(
        lambda a, b: RNTK_middle(a, b, param['sigmaw'], param['sigmau'],
                                 param['sigmab'], param['L'], param['Lf'],
                                 param['sigmav']),
        sequences=[T.transpose(DATA[:, 1:])],
        init=T.stack([RNTK, GP]))
    RNTK_last, RNTK_avg = RNTK_output(v, param['sigmav'], param['L'],
                                      param['Lf'], length)
    f = symjax.function(DATA, outputs=[RNTK_last, RNTK_avg])
    return f
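# A possible usage sketch of RNTK_function (the hyper-parameter values below
# are assumptions; only the dictionary keys are taken from the function above,
# and RNTK_first / RNTK_middle / RNTK_output must be defined in the module).
param = {'sigmaw': 1.0, 'sigmau': 1.0, 'sigmab': 0.1, 'sigmah': 1.0,
         'sigmav': 1.0, 'L': 2, 'Lf': 0}
f = RNTK_function(N=4, length=10, param=param)
rntk_last, rntk_avg = f(np.random.randn(4, 10).astype('float32'))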
def PiecewiseConstant(init, steps_and_values):
    """piecewise constant variable updated automatically

    This method returns a variable with an internal counter that is updated
    through the function updates. Whenever this counter reaches one of the
    steps given in the function input, the actual value of the variable
    becomes the one given for the associated step.

    Args
    ----

    init: float-like
        the initial value of the variable that remains as is until a step
        and update is reached

    steps_and_values: dict
        the dictionary mapping steps -> values, that is, when the number of
        steps reaches one of the given keys, the value of the variable
        becomes the value associated to that step

    Returns
    -------

    variable: float-like

    Example
    -------

    .. doctest::

        >>> import symjax
        >>> symjax.current_graph().reset()
        >>> var = symjax.nn.schedules.PiecewiseConstant(0.1, {4: 1, 8: 2})
        >>> # in the background, symjax automatically records that every time
        >>> # a function uses this variable an underlying update should occur
        >>> print(symjax.get_updates())
        {Variable(name=step, shape=(), dtype=int32, trainable=False, scope=/PiecewiseConstant/): Op(name=add, fn=add, shape=(), dtype=int32, scope=/PiecewiseConstant/)}
        >>> # it is up to the user to use it or not; if not used, the internal
        >>> # counter is never updated and thus the variable never changes.
        >>> # example of use:
        >>> f = symjax.function(outputs=var, updates=symjax.get_updates())
        >>> for i in range(10):
        ...     print(i, f())
        0 0.1
        1 0.1
        2 0.1
        3 0.1
        4 1.0
        5 1.0
        6 1.0
        7 1.0
        8 2.0
        9 2.0

    """
    with Scope("PiecewiseConstant"):

        all_steps = T.stack([0] + list(steps_and_values.keys()) + [np.inf])
        all_values = T.stack([init] + list(steps_and_values.values()) + [0])

        step = T.Variable(
            0,
            trainable=False,
            name="step",
            dtype="int32",
        )

        value = all_values[(step >= all_steps).argmin() - 1]

    return value, step
def RNTK_function(self):
    print(f"N, {self.N}, length, {self.length}")
    DATA = T.Placeholder((self.N, self.length), 'float32')
    RNTK, GP = self.RNTK_first(DATA[:, 0])
    v, _ = T.scan(
        lambda a, b: self.RNTK_middle(a, b),
        sequences=[T.transpose(DATA[:, 1:])],
        init=T.stack([RNTK, GP]))
    RNTK_last, RNTK_avg = self.RNTK_output(v)
    f = symjax.function(DATA, outputs=[RNTK_last, RNTK_avg])
    return RNTK_last, RNTK_avg
losses = list()
values = list()
for i in range(10):
    a, b = train()
    losses.append(a)
    values.append(b)

plt.figure()
plt.subplot(121)
plt.plot(losses)
plt.subplot(122)
plt.plot(values, np.zeros_like(values), "kx")

####### jacobians

x, y = T.ones(()), T.ones(())
print(x, y)
ZZ = T.stack([x, y])
f = T.stack([3 * ZZ[0] + 2 * ZZ[1]], axis=0)
j = symjax.jacobians(f, [ZZ])[0]
g_j = symjax.function(outputs=j)

R = T.random.randn()
f = ZZ * 10 * R
j = symjax.jacobians(f, [ZZ])[0]
g_j = symjax.function(outputs=[j])

for i in range(5):
    print(g_j())

# plt.show()