def create_fns(input, in_signs, Ds):
    """Compile the forward pass and the per-region affine maps of an MLP.

    The hidden layers multiply each pre-activation by a mask that is 1 where
    the pre-activation is positive and 0.1 elsewhere; the last layer is
    affine.  Weights and biases are drawn with ``sj.initializers.he``.

    Args:
        input: symbolic network input, used as ``T.matmul(w, input)``
            (presumably a vector of length ``Ds[0]`` -- confirm with caller).
        in_signs: symbolic sign code of the hidden units; entries ``> 0``
            select slope 1 and the others slope 0.1.
        Ds: sequence of layer widths, input first and output last.

    Returns:
        ``(f, g, h, all_g)`` where
        ``f(input)`` gives ``[output, A, B, inequalities, signs]``,
        ``g(in_signs)`` gives ``[A, B]`` for the region coded by the signs,
        ``h(input)`` gives the network output only, and
        ``all_g(in_signs)`` gives the sign-scaled inequalities of that code.
    """
    leak = 0.1
    n_in = Ds[0]
    # offsets delimiting, inside the concatenated mask, the units feeding
    # each layer (the first chunk stands for the input itself)
    offsets = np.concatenate([[0], np.cumsum(Ds[:-1])])

    Ws = [sj.initializers.he((j, i)) for j, i in zip(Ds[1:], Ds[:-1])]
    bs = [sj.initializers.he((j,)) for j in Ds[1:]]

    # running affine parameters: *_w follow the input's own activation
    # pattern, *_q follow the pattern dictated by ``in_signs``
    A_w, B_w = [T.eye(n_in)], [T.zeros(n_in)]
    A_q, B_q = [T.eye(n_in)], [T.zeros(n_in)]

    masks = [T.ones(n_in)]
    in_masks = T.where(T.concatenate([T.ones(n_in), in_signs]) > 0, 1., leak)

    # forward pass through the hidden layers
    hidden = input
    sign_chunks = []
    for w, b in zip(Ws[:-1], bs[:-1]):
        pre = T.matmul(w, hidden) + b
        sign_chunks.append(T.sign(pre))
        mask = T.where(pre > 0, 1., leak)
        masks.append(mask)
        hidden = pre * mask
    net_output = T.matmul(Ws[-1], hidden) + bs[-1]

    # compute per region A and B
    for start, end, w, b, m in zip(offsets[:-1], offsets[1:], Ws, bs, masks):
        w_input = w * m
        A_w.append(T.matmul(w_input, A_w[-1]))
        B_w.append(T.matmul(w_input, B_w[-1]) + b)
        w_code = w * in_masks[start:end]
        A_q.append(T.matmul(w_code, A_q[-1]))
        B_q.append(T.matmul(w_code, B_q[-1]) + b)

    signs = T.concatenate(sign_chunks)

    # hidden-layer hyperplanes [b | A], oriented by the signs and divided
    # by the row-wise 2-norm of A
    ineq_b = T.concatenate(B_w[1:-1])
    ineq_A = T.vstack(A_w[1:-1])
    inequalities = T.hstack([ineq_b[:, None], ineq_A])
    row_norms = T.linalg.norm(ineq_A, 2, 1, keepdims=True)
    inequalities = inequalities * signs[:, None] / row_norms

    code_b = T.concatenate(B_q[1:-1])
    code_A = T.vstack(A_q[1:-1])
    inequalities_code = T.hstack([code_b[:, None], code_A])
    inequalities_code = inequalities_code * in_signs[:, None]

    f = sj.function(input, outputs=[net_output, A_w[-1], B_w[-1],
                                    inequalities, signs])
    g = sj.function(in_signs, outputs=[A_q[-1], B_q[-1]])
    all_g = sj.function(in_signs, outputs=inequalities_code)
    h = sj.function(input, outputs=net_output)
    return f, g, h, all_g
def get_Abs(Ws, vs, Qs):
    """Accumulate the per-region affine parameters of every pre-activation.

    Starting from ``Ws[0]`` and ``vs[0]`` replicated ``n`` times (``n`` is the
    leading dimension of ``Qs[-1]``), each following layer contracts its
    weight with the mask ``Qs[layer - 1]`` and the previous slope / offset.

    Returns:
        ``(As, bs)``: two lists with ``len(Qs) + 1`` entries each, holding the
        slopes (indexed ``n, d, s``) and offsets (indexed ``n, d``).
    """
    n_regions = Qs[-1].shape[0]
    slopes = [Ws[0] * T.ones((n_regions, 1, 1))]
    offsets = [vs[0] * T.ones((n_regions, 1))]
    for layer in range(1, len(Qs) + 1):
        weight = Ws[layer]
        mask = Qs[layer - 1]
        slopes.append(T.einsum('db,nb,nbs->nds', weight, mask, slopes[-1]))
        masked_offset = T.einsum('db,nb,nb->nd', weight, mask, offsets[-1])
        offsets.append(masked_offset + vs[layer])
    return slopes, offsets
def test_base():
    """Smoke test: read a summed ones-vector eagerly, then via a function."""
    total = T.ones((10, )).sum()
    # two consecutive eager reads of the same tensor
    print(total.get())
    print(total.get())
    f = symjax.function(outputs=total)
    # call the compiled function repeatedly; the results are discarded
    for _ in range(100):
        f()
def test_stack():
    """Stacking ``u``, ``2u`` and ``3u`` must give rows of 1s, 2s and 3s."""
    u = tt.Variable(tt.ones((2, )))
    stacked = tt.stack([u, 2 * u, 3 * u])
    f = symjax.function(outputs=stacked)
    # row k of the expected array is filled with k + 1
    row_values = np.arange(3)[:, None] + 1
    expected = row_values * np.ones((3, 2))
    assert np.allclose(f(), expected)
    print(f())
    print(f())
def create_glo(batch_size, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1, GLO=False):
    """Build a decoder over a trainable latent batch and compile its functions.

    A latent variable ``z`` of shape ``(batch_size, Ds[0])`` is pushed through
    the layers sized by ``Ds`` and compared with the placeholder ``x`` of
    shape ``(batch_size, Ds[-1])``.

    Args:
        batch_size: number of samples per batch (also the number of latents).
        Ds: layer widths, latent dimension first and data dimension last.
        seed, scaler: forwarded to ``init_weights``.
        leakiness: forwarded to ``relu_mask``.
        lr: learning rate of both Adam optimizers.
        GLO: selects the plain squared-error loss instead of the loss that
            also involves ``logvar_x`` and the squared norm of ``z``.

    Returns:
        dict with keys ``'train'``, ``'estimate'``, ``'params'``, ``'reset'``,
        ``'model'``, ``'loss'`` and ``'kwargs'``.

    NOTE(review): ``init_weights``, ``relu_mask``, ``cov_W`` and ``cov_b`` are
    not defined in this function; they must exist at module level.
    """
    x = T.Placeholder([batch_size, Ds[-1]], 'float32')
    # the latent codes are variables: they are optimised by ``infer`` below
    z = T.Variable(T.random.randn((batch_size, Ds[0])))
    logvar_x = T.Variable(T.ones(1))
    # DECODER
    Ws, bs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w) for w in Ws]
    bs = [T.Variable(b) for b in bs]
    # h alternates pre-activations and masked activations; only h[-1] is read
    h = [z]
    for w, b in zip(Ws[:-1], bs[:-1]):
        h.append(T.matmul(h[-1], w.transpose()) + b)
        h.append(h[-1] * relu_mask(h[-1], leakiness))
    # last layer is affine
    h.append(T.matmul(h[-1], Ws[-1].transpose()) + bs[-1])
    # LOSS
    # squared-norm penalty on all weights and on every bias but the last
    prior = sum([T.sum(w**2) for w in Ws], 0.) / cov_W + sum([T.sum(v**2) for v in bs[:-1]], 0.) / cov_b
    if GLO:
        loss = T.sum((x - h[-1])**2) / batch_size + prior
        variables = Ws + bs
    else:
        loss = Ds[-1] * logvar_x.sum() + T.sum((x - h[-1])**2 / T.exp(logvar_x)) / batch_size + (z**2).sum() / batch_size + prior
        # logvar_x is not part of ``variables``, so ``opti`` never updates it
        variables = Ws + bs
    # NOTE(review): the original indentation was lost; this second ``prior``
    # (which now includes the last bias) is assumed to sit after the if/else.
    # ``loss`` already contains the first ``prior``, so ``opti`` minimises
    # ``loss`` plus this second penalty -- confirm this is intended.
    prior = sum([(b**2).sum() for b in bs], 0.) / cov_b\
        + sum([(w**2).sum() for w in Ws], 0.) / cov_W
    # opti updates the decoder parameters, infer updates only the latents
    opti = sj.optimizers.Adam(loss + prior, lr, params=variables)
    infer = sj.optimizers.Adam(loss, lr, params=[z])
    estimate = sj.function(x, outputs=z, updates=infer.updates)
    train = sj.function(x, outputs=loss, updates=opti.updates)
    lossf = sj.function(x, outputs=loss)
    # the last returned entry is exp(logvar_x) broadcast to the data dimension
    params = sj.function(outputs = Ws + bs + [T.ones(Ds[-1]) * T.exp(logvar_x)])
    output = {'train': train, 'estimate':estimate, 'params':params}
    # 'reset' overwrites the latent codes with the given value
    output['reset'] = lambda v: z.assign(v)
    if GLO:
        output['model'] = 'GLO'
    else:
        output['model'] = 'HARD'
    output['loss'] = lossf
    # keep the construction arguments (GLO itself is not recorded)
    output['kwargs'] = {'batch_size': batch_size, 'Ds':Ds, 'seed':seed, 'leakiness':leakiness, 'lr':lr, 'scaler':scaler}
    return output
def create_vae(batch_size, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1):
    """Assemble a VAE (external encoder + MLP decoder) and compile its functions.

    Args:
        batch_size: number of samples per batch.
        Ds: layer widths of the decoder, latent size first, data size last.
        seed, scaler: forwarded to ``init_weights``.
        leakiness: forwarded to ``relu_mask``.
        lr: Adam learning rate.

    Returns:
        dict with keys ``'train'``, ``'g'``, ``'params'``, ``'model'``,
        ``'varx'``, ``'kwargs'`` and ``'sample'``.  ``sample(n)`` decodes
        ``n // batch_size`` standard-normal batches and concatenates them.

    NOTE(review): ``encoder``, ``init_weights``, ``relu_mask``, ``cov_W`` and
    ``cov_b`` come from the enclosing module.
    """
    latent_dim = Ds[0]
    data_dim = Ds[-1]
    x = T.Placeholder([batch_size, data_dim], 'float32')

    # ENCODER: the last encoder output holds the mean, then the log-variance
    enc = encoder(x, latent_dim)
    mu = enc[-1][:, :latent_dim]
    logvar = enc[-1][:, latent_dim:]
    var = T.exp(logvar)
    # reparameterised latent sample, plus a placeholder to decode given codes
    z = mu + T.exp(0.5 * logvar) * T.random.randn((batch_size, latent_dim))
    z_ph = T.Placeholder((batch_size, latent_dim), 'float32')

    # DECODER
    raw_Ws, raw_bs = init_weights(Ds, seed, scaler)
    weights = [T.Variable(w) for w in raw_Ws]
    biases = [T.Variable(b) for b in raw_bs]
    logvar_x = T.Variable(T.zeros(1), name='logvar_x')
    var_x = T.exp(logvar_x)

    # two parallel passes: one on the sampled z, one on the placeholder z_ph
    hidden, hidden_ph = z, z_ph
    for w, b in zip(weights[:-1], biases[:-1]):
        pre = T.matmul(hidden, w.transpose()) + b
        hidden = pre * relu_mask(pre, leakiness)
        pre_ph = T.matmul(hidden_ph, w.transpose()) + b
        hidden_ph = pre_ph * relu_mask(pre_ph, leakiness)
    decoded = T.matmul(hidden, weights[-1].transpose()) + biases[-1]
    decoded_ph = T.matmul(hidden_ph, weights[-1].transpose()) + biases[-1]

    # mean-square penalty on the weights and on all biases but the last
    prior = sum([T.mean(w**2) for w in weights], 0.) / cov_W\
        + sum([T.mean(v**2) for v in biases[:-1]], 0.) / cov_b

    # both terms enter the loss with a minus sign
    kl = 0.5 * (1 + logvar - var - mu ** 2).sum(1)
    px = - 0.5 * (logvar_x + ((x - decoded)**2 / var_x)).sum(1)
    loss = - (px + kl).mean() + prior

    variables = weights + biases + sj.layers.get_variables(enc) + [logvar_x]
    opti = sj.optimizers.Adam(loss, lr, params=variables)
    train = sj.function(x, outputs=loss, updates=opti.updates)
    g = sj.function(z_ph, outputs=decoded_ph)
    params = sj.function(
        outputs=weights + biases + [T.exp(logvar_x) * T.ones(data_dim)])
    get_varx = sj.function(outputs=var_x)

    def sample(n):
        batches = [g(np.random.randn(batch_size, Ds[0]))
                   for _ in range(n // batch_size)]
        return np.concatenate(batches)

    return {
        'train': train,
        'g': g,
        'params': params,
        'model': 'VAE',
        'varx': get_varx,
        'kwargs': {'batch_size': batch_size, 'Ds': Ds, 'seed': seed,
                   'leakiness': leakiness, 'lr': lr, 'scaler': scaler,
                   'prior': sj.function(outputs=prior)},
        'sample': sample,
    }
def forward(self, input, crop_shape, deterministic, padding=0, seed=None):
    """Pad ``input`` and crop every sample, randomly or at a fixed offset.

    Args:
        input: tensor whose first axis indexes the samples; ``input.shape[0]``
            must be a Python int because the samples are sliced one by one.
        crop_shape: size of the crop along every non-batch axis.
        deterministic: flag cast to float32; 1 selects the fixed crop and 0
            the random one (the two are blended as ``d * fixed + (1 - d) * random``).
        padding: a scalar applied before and after every non-batch axis, or a
            sequence with one entry per non-batch axis, each entry being a
            scalar or a ``(before, after)`` pair.
        seed: optional base seed; axis ``i`` uses ``seed + i``.

    Returns:
        The cropped tensor.

    Raises:
        AssertionError: if the padding and crop specifications do not have one
            entry per non-batch axis, or if a crop exceeds the padded size.
    """
    self.crop_shape = crop_shape
    # if given only a scalar
    if not hasattr(padding, "__len__"):
        # one pair per non-batch axis.  The previous ``input.shape - 1``
        # subtracted from the shape itself instead of from its length.
        self.pad_shape = [(padding, padding)] * (len(input.shape) - 1)
    # else
    else:
        self.pad_shape = [(pad, pad) if not hasattr(pad, "__len__") else pad
                          for pad in padding]

    assert len(self.pad_shape) == len(self.crop_shape)
    assert len(self.pad_shape) == (len(input.shape) - 1)

    self.start_indices = list()
    self.fixed_indices = list()
    for i, (pad, dim, crop) in enumerate(
            zip(self.pad_shape, input.shape[1:], self.crop_shape)):
        # largest admissible start offset along this axis after padding
        maxval = pad[0] + pad[1] + dim - crop
        assert maxval >= 0
        self.start_indices.append(
            T.random.randint(
                minval=0,
                maxval=maxval,
                shape=(input.shape[0], 1),
                dtype="int32",
                seed=seed + i if seed is not None else seed,
            ))
        # the deterministic crop starts halfway through the admissible range
        self.fixed_indices.append(
            T.ones((input.shape[0], 1), "int32") * (maxval // 2))
    self.start_indices = T.concatenate(self.start_indices, 1)
    self.fixed_indices = T.concatenate(self.fixed_indices, 1)

    dirac = T.cast(deterministic, "float32")

    # pad the input (the batch axis is left untouched)
    pinput = T.pad(input, [(0, 0)] + self.pad_shape)

    routput = T.stack(
        [
            T.dynamic_slice(pinput[n], self.start_indices[n], self.crop_shape)
            for n in range(input.shape[0])
        ],
        0,
    )
    doutput = T.stack(
        [
            T.dynamic_slice(pinput[n], self.fixed_indices[n], self.crop_shape)
            for n in range(input.shape[0])
        ],
        0,
    )

    return doutput * dirac + (1 - dirac) * routput
def test_cond3():
    """``T.cond`` with two inputs per branch: multiply if ``u > 0``, else add."""
    sj.current_graph().reset()
    threes = T.ones((10, 10)) * 3
    u = T.Placeholder((), "int32")
    # each branch receives its own (2s, 3s) pair
    on_true = (2 * T.ones((10, 10)), threes)
    on_false = (2 * T.ones((10, 10)), threes)
    out = T.cond(
        u > 0,
        lambda left, right: left * right,
        lambda left, right: left + right,
        true_inputs=on_true,
        false_inputs=on_false,
    )
    f = sj.function(u, outputs=out)
    # 2 * 3 on the true branch, 2 + 3 on the false branch
    for flag, value in ((1, 6), (0, 5)):
        assert np.array_equal(f(flag), value * np.ones((10, 10)))
def test_cond2():
    """``T.cond`` with one input per branch and different inputs per branch."""
    sj.current_graph().reset()
    ones = T.ones((10, 10))
    u = T.Placeholder((), "int32")
    out = T.cond(
        u > 0,
        lambda operand: 4 * operand,
        lambda operand: operand,
        true_inputs=(ones, ),
        false_inputs=(2 * ones, ),
    )
    f = sj.function(u, outputs=out)
    # true branch scales the ones by 4, false branch returns the doubled ones
    for flag, value in ((1, 4), (0, 2)):
        assert np.array_equal(f(flag), value * np.ones((10, 10)))
def __init__(self, input, crop_shape, deterministic, padding=0, seed=None):
    """Pad ``input`` and crop each sample, randomly or at a fixed offset.

    ``padding`` is either a scalar used before and after every non-batch axis
    or a sequence with one scalar / ``(before, after)`` pair per axis.  The
    result blends the fixed and the random crop with ``deterministic`` cast
    to float32.

    NOTE(review): this ``__init__`` returns the cropped tensor; presumably the
    enclosing class calls it explicitly rather than through normal instance
    construction -- confirm against the base class.
    """
    if hasattr(padding, "__len__"):
        pad_shape = [pad if hasattr(pad, "__len__") else (pad, pad)
                     for pad in padding]
    else:
        # a single scalar applies to every non-batch axis
        pad_shape = [(padding, padding)] * (input.ndim - 1)

    assert len(pad_shape) == len(crop_shape)
    assert len(pad_shape) == input.ndim - 1

    n_samples = input.shape[0]
    random_starts = []
    centered_starts = []
    for axis, (pad, dim, crop) in enumerate(
            zip(pad_shape, input.shape[1:], crop_shape)):
        # largest start offset that keeps the crop inside the padded axis
        span = pad[0] + pad[1] + dim - crop
        axis_seed = seed if seed is None else seed + axis
        random_starts.append(
            T.random.randint(
                minval=0,
                maxval=span,
                shape=(n_samples, 1),
                dtype="int32",
                seed=axis_seed,
            ))
        centered_starts.append(T.ones((n_samples, 1), "int32") * (span // 2))

    start_indices = T.concatenate(random_starts, 1)
    fixed_indices = T.concatenate(centered_starts, 1)

    dirac = T.cast(deterministic, "float32")

    # pad the input, leaving the batch axis untouched
    pinput = T.pad(input, [(0, 0)] + pad_shape)

    def crop_sample(x, indices):
        return T.dynamic_slice(x, indices, crop_shape)

    routput = T.map(crop_sample, sequences=[pinput, start_indices])
    doutput = T.map(crop_sample, sequences=[pinput, fixed_indices])

    return doutput * dirac + (1 - dirac) * routput
def get_forward(Ws, Qs):
    r"""Build the slope matrices that carry each layer's pre-activation to the output.

    The matrices are accumulated from the last layer backwards, starting from
    the identity: step ``i`` right-multiplies the running product by
    ``Ws[-1 - i]`` and scales its columns by ``Qs[-1 - i]``, so that the
    :math:`\ell`-th returned element has the form ::

        W^{L} Q^{L}_{\omega} \cdots W^{\ell} Q^{\ell}_{\omega}

    The first returned element is therefore the full product over all the
    masks in ``Qs`` and the last one is the identity matrix.

    Args:
        Ws: list of layer weight matrices; only the last ``len(Qs)`` are used.
        Qs: list of masks, one per hidden layer; the leading dimension of
            ``Qs[-1]`` gives the number ``N`` of regions.

    Returns:
        List of ``len(Qs) + 1`` tensors indexed ``(n, d, s)``, ordered from
        the first layer to the output.
    """
    # The docstring is raw: ``\omega`` / ``\ell`` in a normal string literal
    # are invalid escape sequences and emit a warning at compile time.
    N = Qs[-1].shape[0]
    L = len(Qs)
    forward = [T.identity(Ws[-1].shape[0]) * T.ones((N, 1, 1))]
    for i in range(L):
        forward.append(T.einsum('ndb,bs,ns->nds', forward[-1], Ws[- 1 - i],
                                Qs[- 1 - i]))
    # built output-first, returned input-first
    return forward[::-1]
def __init__(self, input_or_shape, crop_shape, deterministic, padding=0,
             seed=None):
    """Set up the padding and the random / fixed crop offsets of the layer.

    Args:
        input_or_shape: forwarded to ``self.init_input``, which is expected
            to set ``self.input``.
        crop_shape: size of the crop along every non-batch axis.
        deterministic: stored on the instance for later use.
        padding: a scalar applied before and after every non-batch axis, or a
            sequence with one scalar / ``(before, after)`` pair per axis.
        seed: optional base seed; axis ``i`` uses ``seed + i``.

    Raises:
        AssertionError: if the padding and crop specifications do not have one
            entry per non-batch axis, or if a crop exceeds the padded size.
    """
    self.init_input(input_or_shape)
    self.crop_shape = crop_shape
    # if given only a scalar
    if not hasattr(padding, '__len__'):
        # one pair per non-batch axis.  The previous ``self.input.shape - 1``
        # subtracted from the shape itself instead of from its length.
        self.pad_shape = [(padding, padding)] * (len(self.input.shape) - 1)
    # else
    else:
        self.pad_shape = [(pad, pad) if not hasattr(pad, '__len__') else pad
                          for pad in padding]

    assert len(self.pad_shape) == len(self.crop_shape)
    assert len(self.pad_shape) == (len(self.input.shape) - 1)

    self.deterministic = deterministic

    self.start_indices = list()
    self.fixed_indices = list()
    for i, (pad, dim, crop) in enumerate(
            zip(self.pad_shape, self.input.shape[1:], self.crop_shape)):
        # largest admissible start offset along this axis after padding
        maxval = pad[0] + pad[1] + dim - crop
        assert maxval >= 0
        self.start_indices.append(
            T.random.randint(minval=0, maxval=maxval,
                             shape=(self.input.shape[0], 1),
                             dtype='int32',
                             seed=seed + i if seed is not None else seed))
        # the deterministic crop starts halfway through the admissible range
        self.fixed_indices.append(
            T.ones((self.input.shape[0], 1), 'int32') * (maxval // 2))
    self.start_indices = T.concatenate(self.start_indices, 1)
    self.fixed_indices = T.concatenate(self.fixed_indices, 1)

    super().__init__(self.forward(self.input))
def test_cond5():
    """``T.cond`` whose branches read a variable updated after every call."""
    sj.current_graph().reset()
    threes = T.ones((10, 10)) * 3
    W = T.Variable(1)
    u = T.Placeholder((), "int32")
    operands = (W, threes)
    out = T.cond(
        u > 0,
        lambda scale, rows: scale * rows[0],
        lambda shift, rows: shift + rows[1],
        true_inputs=operands,
        false_inputs=operands,
    )
    f = sj.function(u, outputs=out, updates={W: W + 1})
    # W is 1, 2, 3 on the successive calls: 1 * 3, 2 + 3, 3 * 3
    for flag, value in ((1, 3), (0, 5), (1, 9)):
        assert np.array_equal(f(flag), value * np.ones(10))
# Scratch script: wrap a tensorflow_probability (JAX substrate) Normal
# distribution so that it accepts a symjax placeholder as its mean.
# NOTE(review): ``T``, ``sj``, ``tf`` and ``upgrade_class`` are not imported
# or defined in the visible part of this script.
import tensorflow_probability as tfp

# rebind ``tfp`` to its JAX substrate
tfp = tfp.experimental.substrates.jax
nn = tfp.distributions.Normal(1.0, 5.0)
print(nn.cdf(1))
mean = T.Placeholder((1, ), "float32", name="mean")


def inst(self, instance):
    """Log the arguments and accept every instance (not used below)."""
    print("instancecheck", self, instance)
    return True


upgrade_class(mean, T.Tensor, tf.Variable)
normal_dist = T.wrap_class(tfp.distributions.Normal)
a = normal_dist(mean, 5.0)
print(a._loc is mean)
# NOTE(review): bare undefined name -- evaluating it raises NameError, so the
# statements below do not run; presumably a deliberate debugging stop.
asdf
x = T.Variable(T.ones(1) - 5)
output = a.cdf(x)
get_f = sj.function(mean, outputs=output)
for i in range(-5, 5):
    # NOTE(review): prints the literal 1 after "mean:" although ``i`` is the
    # value passed to ``get_f`` -- confirm which one is intended.
    print("mean:", 1, " x:", T.get(x), " cdf:", get_f(i))
def create_fns(batch_size, R, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1,
               var_x=1):
    """Build the symbolic EM machinery of a piecewise-affine generator.

    The function creates the network parameters, the per-region affine maps
    for ``R`` sign codes, the loss built from the moment placeholders
    ``m0`` / ``m1`` / ``m2``, closed-form updates of the biases and weights,
    and compiles everything into a dictionary of symjax functions.

    Args:
        batch_size: number of observations in ``X``.
        R: number of regions (rows of ``SIGNS``).
        Ds: list of layer widths, latent size first and data size last
            (must be a list: it is concatenated with ``[0]``).
        seed, scaler: forwarded to ``init_weights``.
        leakiness: forwarded to ``relu_mask``.
        lr: only recorded in ``output['kwargs']``; the gradient step below
            uses a hard-coded rate.
        var_x: initial value of the observation variance (broadcast to
            ``Ds[-1]`` entries and turned into a variable).

    Returns:
        dict of compiled functions and metadata; see the ``output`` literal
        near the end, plus ``output['sample']``.

    NOTE(review): ``init_weights``, ``relu_mask``, ``get_Abs``,
    ``get_forward``, ``cov_W`` and ``cov_b`` are defined outside this
    function.
    """

    # mixing coefficient between the old and the closed-form new parameter
    alpha = T.Placeholder((1,), 'float32')
    # single latent input / batch of observations
    x = T.Placeholder((Ds[0],), 'float32')
    X = T.Placeholder((batch_size, Ds[-1]), 'float32')

    # sign code of one region / of R regions over all hidden units
    signs = T.Placeholder((np.sum(Ds[1:-1]),), 'float32')
    SIGNS = T.Placeholder((R, np.sum(Ds[1:-1])), 'float32')

    # per-sample, per-region moments of order 0, 1 and 2
    m0 = T.Placeholder((batch_size, R), 'float32')
    m1 = T.Placeholder((batch_size, R, Ds[0]), 'float32')
    m2 = T.Placeholder((batch_size, R, Ds[0], Ds[0]), 'float32')

    Ws, vs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w, name='W' + str(l)) for l, w in enumerate(Ws)]
    vs = [T.Variable(v, name='v' + str(l)) for l, v in enumerate(vs)]

    # from here on ``var_x`` is the variable, not the scalar argument
    var_x = T.Variable(T.ones(Ds[-1]) * var_x)
    var_z = T.Variable(T.ones(Ds[0]))

    # create the placeholders
    Ws_ph = [T.Placeholder(w.shape, w.dtype) for w in Ws]
    vs_ph = [T.Placeholder(v.shape, v.dtype) for v in vs]
    var_x_ph = T.Placeholder(var_x.shape, var_x.dtype)

    ############################################################################
    # Compute the output of g(x)
    ############################################################################

    maps = [x]
    xsigns = []
    masks = []

    for w, v in zip(Ws[:-1], vs[:-1]):

        pre_activation = T.matmul(w, maps[-1]) + v
        xsigns.append(T.sign(pre_activation))
        masks.append(relu_mask(pre_activation, leakiness))
        maps.append(pre_activation * masks[-1])

    xsigns = T.concatenate(xsigns)
    # last layer is affine
    maps.append(T.matmul(Ws[-1], maps[-1]) + vs[-1])

    ############################################################################
    # compute the masks and then the per layer affine mappings
    ############################################################################

    # offsets of each hidden layer inside the concatenated sign vectors
    cumulative_units = np.cumsum([0] + Ds[1:])
    xqs = relu_mask([xsigns[None, cumulative_units[i]:cumulative_units[i + 1]]
                     for i in range(len(Ds) - 2)], leakiness)
    qs = relu_mask([signs[None, cumulative_units[i]:cumulative_units[i + 1]]
                    for i in range(len(Ds) - 2)], leakiness)
    Qs = relu_mask([SIGNS[:, cumulative_units[i]:cumulative_units[i + 1]]
                    for i in range(len(Ds) - 2)], leakiness)

    # affine maps for the input's own region, a given region, and R regions
    Axs, bxs = get_Abs(Ws, vs, xqs)
    Aqs, bqs = get_Abs(Ws, vs, qs)
    AQs, bQs = get_Abs(Ws, vs, Qs)

    all_bxs = T.hstack(bxs[:-1]).transpose()
    all_Axs = T.hstack(Axs[:-1])[0]

    all_bqs = T.hstack(bqs[:-1]).transpose()
    all_Aqs = T.hstack(Aqs[:-1])[0]

    # hidden-layer hyperplanes [A | b], oriented by the corresponding signs
    x_inequalities = T.hstack([all_Axs, all_bxs]) * xsigns[:, None]
    q_inequalities = T.hstack([all_Aqs, all_bqs]) * signs[:, None]

    ############################################################################
    # loss (E-step NLL)
    ############################################################################

    # contractions of the output-layer slope / offset with the moments
    Bm0 = T.einsum('nd,Nn->Nd', bQs[-1], m0)
    B2m0 = T.einsum('nd,Nn->Nd', bQs[-1] ** 2, m0)
    Am1 = T.einsum('nds,Nns->Nd', AQs[-1], m1)
    ABm1 = T.einsum('nds,nd,Nns->Nd', AQs[-1], bQs[-1], m1)
    Am2ATdiag = T.diagonal(T.einsum('nds,Nnsc,npc->Ndp', AQs[-1], m2, AQs[-1]),
                           axis1=1, axis2=2)
    xAm1Bm0 = X * (Am1 + Bm0)

    M2diag = T.diagonal(m2.sum(1), axis1=1, axis2=2)

    # mean-square penalty on the weights and on all biases but the last
    prior = sum([T.mean(w**2) for w in Ws], 0.) / cov_W\
        + sum([T.mean(v**2) for v in vs[:-1]], 0.) / cov_b

    loss = - 0.5 * (T.log(var_x).sum() + T.log(var_z).sum()\
        + (M2diag / var_z).sum(1).mean() + ((X ** 2 - 2 * xAm1Bm0 + B2m0\
        + Am2ATdiag + 2 * ABm1) / var_x).sum(1).mean())

    # quantity actually minimised and reported by the compiled functions
    mean_loss = - (loss + 0.5 * prior)
    # NOTE(review): named ``adam`` but is plain SGD with a fixed 0.001 step;
    # the ``lr`` argument is not used here.
    adam = sj.optimizers.SGD(mean_loss, 0.001, params=Ws + vs)

    ############################################################################
    # update of var_x
    ############################################################################

    update_varx = (X ** 2 - 2 * xAm1Bm0 + B2m0 + Am2ATdiag + 2 * ABm1).mean()\
        * T.ones(Ds[-1])
    # NOTE(review): ``update_varz`` is computed but not used below
    update_varz = M2diag.mean() * T.ones(Ds[0])

    ############################################################################
    # update for biases IT IS DONE FOR ISOTROPIC COVARIANCE MATRIX
    ############################################################################

    FQ = get_forward(Ws, Qs)
    update_vs = {}
    for i in range(len(vs)):

        if i < len(vs) - 1:
            # now we forward each bias to the x-space except the ith
            separated_bs = bQs[-1] - T.einsum('nds,s->nd', FQ[i], vs[i])
            # compute the residual and apply sigma
            residual = (X[:, None, :] - separated_bs) * m0[:, :, None]\
                - T.einsum('nds,Nns->Nnd', AQs[-1], m1)
            back_error = T.einsum('nds,nd->s', FQ[i], residual.mean(0))
            # regularised normal matrix of the linear system solved below
            whiten = T.einsum('ndc,nds,n->cs', FQ[i] , FQ[i], m0.mean(0))\
                + T.eye(back_error.shape[0]) / (Ds[i] * cov_b)
            update_vs[vs[i]] = T.linalg.solve(whiten, back_error)
        else:
            # last bias: mean residual once its own contribution is removed
            back_error = (X - (Am1 + Bm0) + vs[-1])
            update_vs[vs[i]] = back_error.mean(0)

    ############################################################################
    # update for slopes IT IS DONE FOR ISOTROPIC COVARIANCE MATRIX
    ############################################################################

    update_Ws = {}
    for i in range(len(Ws)):

        U = T.einsum('nds,ndc->nsc', FQ[i], FQ[i])
        if i == 0:
            V = m2.mean(0)
        else:
            V1 = T.einsum('nd,nq,Nn->ndq', bQs[i-1], bQs[i-1], m0)
            V2 = T.einsum('nds,nqc,Nnsc->ndq', AQs[i-1], AQs[i-1], m2)
            V3 = T.einsum('nds,nq,Nns->ndq', AQs[i-1], bQs[i-1], m1)
            Q = T.einsum('nd,nq->ndq', Qs[i - 1], Qs[i - 1])
            V = Q * (V1 + V2 + V3 + V3.transpose((0, 2, 1))) / batch_size

        # sum over regions of the Kronecker products, then regularise
        whiten = T.stack([T.kron(U[n], V[n]) for n in range(V.shape[0])]).sum(0)
        whiten = whiten + T.eye(whiten.shape[-1]) / (Ds[i]*Ds[i+1]*cov_W)
        # compute the residual (bottom up)
        if i == len(Ws) - 1:
            bottom_up = (X[:, None, :] - vs[-1])
        else:
            if i == 0:
                residual = (X[:, None, :] - bQs[-1])
            else:
                residual = (X[:, None, :] - bQs[-1]\
                    + T.einsum('nds,ns->nd', FQ[i - 1], bQs[i-1]))
            bottom_up = T.einsum('ndc,Nnd->Nnc', FQ[i], residual)

        # compute the top down vector
        if i == 0:
            top_down = m1
        else:
            top_down = Qs[i - 1] * (T.einsum('nds,Nns->Nnd', AQs[i - 1], m1) +\
                T.einsum('nd,Nn->Nnd', bQs[i - 1], m0))

        vector = T.einsum('Nnc,Nns->cs', bottom_up, top_down) / batch_size
        # NOTE(review): ``condition`` is computed but not used below
        condition = T.diagonal(whiten)
        # solve on the flattened weight, then restore the weight's shape
        update_W = T.linalg.solve(whiten, vector.reshape(-1)).reshape(Ws[i].shape)
        update_Ws[Ws[i]] = update_W

    ############################################################################
    # create the io functions
    ############################################################################

    params = sj.function(outputs = Ws + vs + [var_x])
    # index of the single layer to update; all others keep their value
    ll = T.Placeholder((), 'int32')
    selector = T.one_hot(ll, len(vs))
    for i in range(len(vs)):
        update_vs[vs[i]] = ((1 - alpha) * vs[i] + alpha * update_vs[vs[i]])\
            * selector[i] + vs[i] * (1 - selector[i])
    for i in range(len(Ws)):
        update_Ws[Ws[i]] = ((1 - alpha) * Ws[i] + alpha * update_Ws[Ws[i]])\
            * selector[i] + Ws[i] * (1 - selector[i])

    # 'loss' and 'get_nll' are compiled from the same inputs and output
    output = {'train':sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                  updates=adam.updates),
              'update_var':sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                       updates = {var_x: update_varx}),
              'update_vs':sj.function(alpha, ll, SIGNS, X, m0, m1, m2,
                                      outputs=mean_loss, updates = update_vs),
              'loss':sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss),
              'update_Ws':sj.function(alpha, ll, SIGNS, X, m0, m1, m2,
                                      outputs=mean_loss, updates = update_Ws),
              'signs2Ab': sj.function(signs, outputs=[Aqs[-1][0], bqs[-1][0]]),
              'signs2ineq': sj.function(signs, outputs=q_inequalities),
              'g': sj.function(x, outputs=maps[-1]),
              'input2all': sj.function(x, outputs=[maps[-1], Axs[-1][0],
                                       bxs[-1][0], x_inequalities, xsigns]),
              'get_nll': sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss),
              'assign': sj.function(*Ws_ph, *vs_ph, var_x_ph,
                                    updates=dict(zip(Ws + vs + [var_x],
                                                 Ws_ph + vs_ph + [var_x_ph]))),
              'varx': sj.function(outputs=var_x),
              'prior': sj.function(outputs=prior),
              'varz': sj.function(outputs=var_z),
              'params': params,
#              'probed' : sj.function(SIGNS, X, m0, m1, m2, outputs=probed),
              'input2signs': sj.function(x, outputs=xsigns),
              'S' : Ds[0], 'D': Ds[-1], 'R': R, 'model': 'EM', 'L':len(Ds)-1,
              'kwargs': {'batch_size': batch_size, 'Ds':Ds, 'seed':seed,
                         'leakiness':leakiness, 'lr':lr, 'scaler':scaler}}

    def sample(n):
        # decode n standard-normal latents one at a time
        samples = []
        for i in range(n):
            samples.append(output['g'](np.random.randn(Ds[0])))
        return np.array(samples)

    output['sample'] = sample

    return output
#!/usr/bin/env python # -*- coding: utf-8 -*- __author__ = "Randall Balestriero" import symjax.tensor as T import matplotlib.pyplot as plt from symjax.viz import compute_graph x = T.random.randn((10, ), name="x") y = T.random.randn((10, ), name="y") z = T.random.randn((10, ), name="z") w = T.Variable(T.ones(1), name="w") out = (x + y).sum() * w + z.sum() graph = compute_graph(out) graph.draw("file.png", prog="dot") import matplotlib.image as mpimg img = mpimg.imread("file.png") plt.figure(figsize=(15, 5)) imgplot = plt.imshow(img) plt.xticks() plt.yticks() plt.tight_layout()
# Example: (1) clone an expression while substituting one of its nodes,
# (2) minimise a simple quadratic loss with SGD.
import symjax
import symjax.tensor as T

value = T.Variable(T.ones(()))
randn = T.random.randn(())
rand = T.random.rand(())

out1 = randn * value

# same expression as out1 with ``randn`` replaced by ``rand``
out2 = out1.clone({randn: rand})

# ``rand`` is used as the function input; ``value`` grows by 2 per call
f = symjax.function(rand, outputs=out2, updates={value: 2 + value})

for i in range(3):
    print(f(i))
# value is 1, 3, 5 on the successive calls, hence:
# 0.
# 3.
# 10.

# we create a simple computational graph
var = T.Variable(T.random.randn((16, 8), seed=10))
loss = ((var - T.ones_like(var))**2).sum()

# gradient of the loss w.r.t. the variable (not used further below)
g = symjax.gradients(loss, [var])
opt = symjax.optimizers.SGD(loss, 0.01, params=var)

# each call returns the loss and applies one SGD update
f = symjax.function(outputs=loss, updates=opt.updates)

for i in range(10):
    print(f())
# 240.96829
# 231.42595
# 222.26149
# Example: wrap a plain jax.numpy-based class so it can be built from symjax
# tensors.  NOTE(review): ``jnp``, ``T`` and ``sj`` are not imported in the
# visible part of this script.
__author__ = "Randall Balestriero"


class product:
    """Toy class holding a squared, positively-masked weight matrix."""

    def __init__(self, W, V=1):
        # keep only the positive entries of W, scale them by V, then square
        self.W = jnp.square(V * W * (W > 0).astype("float32"))
        self.ndim = self.compute_ndim()

    def feed(self, x):
        """Return the product of the stored matrix with ``x``."""
        return jnp.dot(self.W, x)

    def compute_ndim(self):
        """Return the product of the first two dimensions of ``self.W``."""
        return self.W.shape[0] * self.W.shape[1]


# ``compute_ndim`` is listed in ``method_exceptions`` when wrapping
wrapped = T.wrap_class(product, method_exceptions=["compute_ndim"])


a = wrapped(T.zeros((10, 10)), V=T.ones((10, 10)))
x = T.random.randn((10, 100))

print(a.W)
# (Tensor: name=function[0], shape=(10, 10), dtype=float32)

print(a.feed(x))
# Op(name=feed, shape=(10, 100), dtype=float32, scope=/)

# compile and evaluate the wrapped method's output
f = sj.function(outputs=a.feed(x))

f()
import sys sys.path.insert(0, "../") import symjax as sj import symjax.tensor as T import numpy as np __author__ = "Randall Balestriero" # example of cumulative sum def func(carry, x): return carry + 1, 0 output, _ = T.scan(func, T.zeros(1), T.ones(10), length=10) f = sj.function(outputs=output) print(f()) # [10.] # example of simple RNN w = T.Placeholder((3, 10), 'float32') h = T.random.randn((3, 3)) b = T.random.randn((3, )) t_steps = 100 X = T.random.randn((t_steps, 10)) def rnn_cell(carry, x, w):
# Script fragment: run a training function, plot what it returned, then
# evaluate jacobians.  NOTE(review): ``train``, ``plt``, ``np``, ``T`` and
# ``symjax`` are defined outside the visible part.
losses = list()
values = list()
# ``train`` returns a pair; the two entries are collected separately
for i in range(10):
    a, b = train()
    losses.append(a)
    values.append(b)

plt.figure()

plt.subplot(121)
plt.plot(losses)

plt.subplot(122)
# second returned values shown as crosses on the horizontal axis
plt.plot(values, np.zeros_like(values), "kx")

####### jacobians

x, y = T.ones(()), T.ones(())
print(x, y)
ZZ = T.stack([x, y])
# one-element function of the stacked pair
f = T.stack([3 * ZZ[0] + 2 * ZZ[1]], axis=0)

j = symjax.jacobians(f, [ZZ])[0]
g_j = symjax.function(outputs=j)

# ``f``, ``j`` and ``g_j`` are rebound: the function now scales ZZ by a
# random factor
R = T.random.randn()
f = ZZ * 10 * R
j = symjax.jacobians(f, [ZZ])[0]
g_j = symjax.function(outputs=[j])

for i in range(5):
    print(g_j())

# plt.show()
# Example: ``T.map`` and ``T.scan`` with gradients, then a scan whose inputs
# depend on a variable that is updated after each call.
import jax
import numpy as np
import sys

# make the parent directory importable before importing symjax
sys.path.insert(0, "../")

import symjax
import symjax.tensor as T

# map
xx = T.ones(10)
a = T.map(lambda a: a * 2, xx)
g = symjax.gradients(a.sum(), xx)[0]
f = symjax.function(outputs=[a, g])

# scan
xx = T.ones(10) * 2
# carry and per-step output are both the running product
a = T.scan(lambda c, x: (c * x, c * x), T.ones(1), xx)
# gradient of the last per-step output w.r.t. the sequence
g = symjax.gradients(a[1][-1], xx)[0]
f = symjax.function(outputs=[a, g])

# scan with updates
xx = T.range(5)
uu = T.ones((10, 2))
vvar = T.Variable(T.zeros((10, 2)))
vv = T.index_add(vvar, 1, 1)
# ``vv`` is both the initial carry and the extra (non-sequence) argument
a = T.scan(lambda c, x, p: (T.index_update(c, x, p[x]), 1), vv, xx, [vv])
#a = T.scan(lambda c, x: (c*x,c*x), T.ones(1), xx)
#a = T.scan(lambda c, x: (T.square(c),c[0]), uu, xx)
#g = symjax.gradients(a[1][-1],xx)
# only this last ``f`` is called; ``vvar`` is incremented after each call
f = symjax.function(outputs=a[0], updates={vvar: vvar + 1})
print(f(), f(), f())
import jax import networkx as nx import matplotlib.pyplot as plt bin_rules = [ [ lambda *args: args[0] == 0 or args[1] == 0, jax.numpy.add, lambda *args: args[0] if args[1] == 0 else args[0], ], [lambda *args: len(args) == 1, jax.numpy.add, lambda *args: 2 * args[0]], [lambda *args: args[1] == 1, jax.numpy.true_divide, lambda *args: args[0]], [ lambda *args: len(args) == 1, jax.numpy.true_divide, lambda *args: T.ones(args[0].shape, args[0].dtype), ], [ lambda *args: args[0] == 1 or args[1] == 1, jax.numpy.multiply, lambda *args: args[0] if args[1] == 1 else args[1], ], ] def simplify_add(graph): to_search = list(graph.nodes.keys()) while len(to_search): j = to_search[-1] if type(j) == T.Op:
# Example: variables created inside nested ``symjax.Graph`` scopes, then a
# function that updates them on each call.
# NOTE(review): the original indentation was lost; the two layer graphs are
# assumed to be nested inside ``g`` and the loss / function to sit at top
# level -- confirm.
import symjax
import symjax.tensor as T

g = symjax.Graph("model1")
with g:
    learning_rate = T.Variable(T.ones((1, )))
    # both layers reuse the names "W" and "b", each inside its own graph
    with symjax.Graph("layer1"):
        W1 = T.Variable(T.zeros((1, )), name="W")
        b1 = T.Variable(T.zeros((1, )), name="b")
    with symjax.Graph("layer2"):
        W2 = T.Variable(T.zeros((1, )), name="W")
        b2 = T.Variable(T.zeros((1, )), name="b")

# define an irrelevant loss function involving the parameters
loss = (W1 + b1 + W2 + b2) * learning_rate

# and a train/update function
train = symjax.function(outputs=loss,
                        updates={
                            W1: W1 + 1,
                            b1: b1 + 2,
                            W2: W2 + 2,
                            b2: b2 + 3
                        })

# pretend we train for a while
for i in range(4):
    print(train())

# the variables sum to 0 on the first call and grow by 1 + 2 + 2 + 3 = 8
# after it:
# [0.]
# [8.]
def create_fns(input, in_signs, Ds, x, m0, m1, m2, batch_in_signs, alpha=0.1,
               sigma=1, sigma_x=1, lr=0.0002):
    """Compile the forward pass, region maps and a training step of an MLP.

    Besides the single-input and single-code affine maps, the function builds
    batched affine maps for every row of ``batch_in_signs`` and a loss that
    combines them with the moment tensors ``m0`` / ``m1`` / ``m2``.

    Args:
        input: symbolic network input, used as ``T.matmul(w, input)``.
        in_signs: symbolic sign code of the hidden units for one region.
        Ds: sequence of layer widths, input first and output last.
        x: symbolic batch of observations entering the loss.
        m0, m1, m2: symbolic moments entering the loss (their exact layout is
            fixed by the einsum subscripts below).
        batch_in_signs: symbolic batch of sign codes; its first dimension
            gives the number of regions ``BS``.
        alpha: slope applied where a pre-activation / sign is not positive.
        sigma: scale multiplying the initial weights and hidden biases.
        sigma_x: initial value of the ``log_sigma2`` variable.
        lr: learning rate of the optimizer.

    Returns:
        ``(f, g, h, all_g, train_f, sigma2)``: the four functions described
        next to their creation below, the training function, and the symbolic
        ``exp(log_sigma2)``.
    """

    # offsets delimiting, inside the concatenated mask, the units feeding
    # each layer (the first chunk stands for the input itself)
    cumulative_units = np.concatenate([[0], np.cumsum(Ds[:-1])])
    BS = batch_in_signs.shape[0]
    Ws = [
        T.Variable(sj.initializers.glorot((j, i)) * sigma)
        for j, i in zip(Ds[1:], Ds[:-1])
    ]
    # hidden biases are randomly initialised, the output bias starts at zero
    bs = [T.Variable(sj.initializers.he((j,)) * sigma) for j in Ds[1:-1]]\
        + [T.Variable(T.zeros((Ds[-1],)))]

    # running affine parameters: *_w follow the input's own activation
    # pattern, *_q the pattern of ``in_signs``, batch_*_q one pattern per row
    # of ``batch_in_signs``
    A_w = [T.eye(Ds[0])]
    B_w = [T.zeros(Ds[0])]

    A_q = [T.eye(Ds[0])]
    B_q = [T.zeros(Ds[0])]

    batch_A_q = [T.eye(Ds[0]) * T.ones((BS, 1, 1))]
    batch_B_q = [T.zeros((BS, Ds[0]))]

    maps = [input]
    signs = []
    masks = [T.ones(Ds[0])]

    in_masks = T.where(T.concatenate([T.ones(Ds[0]), in_signs]) > 0, 1., alpha)
    batch_in_masks = T.where(
        T.concatenate([T.ones((BS, Ds[0])), batch_in_signs], 1) > 0, 1., alpha)

    # forward pass through the hidden layers
    for w, b in zip(Ws[:-1], bs[:-1]):

        pre_activation = T.matmul(w, maps[-1]) + b
        signs.append(T.sign(pre_activation))
        masks.append(T.where(pre_activation > 0, 1., alpha))

        maps.append(pre_activation * masks[-1])

    # last layer is affine
    maps.append(T.matmul(Ws[-1], maps[-1]) + bs[-1])

    # compute per region A and B
    for start, end, w, b, m in zip(cumulative_units[:-1], cumulative_units[1:],
                                   Ws, bs, masks):

        A_w.append(T.matmul(w * m, A_w[-1]))
        B_w.append(T.matmul(w * m, B_w[-1]) + b)

        A_q.append(T.matmul(w * in_masks[start:end], A_q[-1]))
        B_q.append(T.matmul(w * in_masks[start:end], B_q[-1]) + b)

        batch_A_q.append(
            T.matmul(w * batch_in_masks[:, None, start:end], batch_A_q[-1]))
        # batched matrix-vector product written as broadcast-multiply-and-sum
        batch_B_q.append((w * batch_in_masks[:, None, start:end]\
            * batch_B_q[-1][:, None, :]).sum(2) + b)

    # only the output-layer batched maps are needed for the loss
    batch_B_q = batch_B_q[-1]
    batch_A_q = batch_A_q[-1]

    signs = T.concatenate(signs)

    # hidden-layer hyperplanes [b | A], oriented by the corresponding signs
    inequalities = T.hstack(
        [T.concatenate(B_w[1:-1])[:, None],
         T.vstack(A_w[1:-1])]) * signs[:, None]

    inequalities_code = T.hstack(
        [T.concatenate(B_q[1:-1])[:, None],
         T.vstack(A_q[1:-1])]) * in_signs[:, None]

    #### loss
    # NOTE(review): ``log_sigma2`` is not among the optimised parameters
    log_sigma2 = T.Variable(sigma_x)
    sigma2 = T.exp(log_sigma2)

    Am1 = T.einsum('qds,nqs->nqd', batch_A_q, m1)
    Bm0 = T.einsum('qd,nq->nd', batch_B_q, m0)
    B2m0 = T.einsum('nq,qd->n', m0, batch_B_q**2)
    AAm2 = T.einsum('qds,qdu,nqup->nsp', batch_A_q, batch_A_q, m2)

    inner = -(x * (Am1.sum(1) + Bm0)).sum(1) + (Am1 * batch_B_q).sum((1, 2))

    loss_2 = (x**2).sum(1) + B2m0 + T.trace(AAm2, axis1=1, axis2=2).squeeze()
    loss_z = T.trace(m2.sum(1), axis1=1, axis2=2).squeeze()

    cst = 0.5 * (Ds[0] + Ds[-1]) * T.log(2 * np.pi)

    loss = cst + 0.5 * Ds[-1] * log_sigma2 + inner / sigma2\
        + 0.5 * loss_2 / sigma2 + 0.5 * loss_z

    mean_loss = loss.mean()
    # NOTE(review): named ``adam`` but is a Nesterov-momentum optimizer
    adam = sj.optimizers.NesterovMomentum(mean_loss, Ws + bs, lr, 0.9)

    train_f = sj.function(batch_in_signs, x, m0, m1, m2, outputs=mean_loss,
                          updates=adam.updates)
    # f: input -> [output, A, B, inequalities, signs]
    f = sj.function(input, outputs=[maps[-1], A_w[-1], B_w[-1],
                                    inequalities, signs])
    # g: sign code -> [A, B] of the coded region
    g = sj.function(in_signs, outputs=[A_q[-1], B_q[-1]])
    # all_g: sign code -> oriented inequalities of the coded region
    all_g = sj.function(in_signs, outputs=inequalities_code)
    # h: input -> output only
    h = sj.function(input, outputs=maps[-1])
    return f, g, h, all_g, train_f, sigma2
# Example: how variables are named and scoped inside nested graphs, and how
# to list or fetch them afterwards.
import symjax
import symjax.tensor as T

# scope/graph naming and accessing

# created outside of any explicit graph
value1 = T.Variable(T.ones((1,)))
value2 = T.Variable(T.zeros((1,)))

g = symjax.Graph("special")
with g:
    value3 = T.Variable(T.zeros((1,)))
    value4 = T.Variable(T.zeros((1,)))
    result = value3 + value4

    # nested graph: the printed scopes below read /special/inversion/
    h = symjax.Graph("inversion")
    with h:
        value5 = T.Variable(T.zeros((1,)))
        value6 = T.Variable(T.zeros((1,)))
        value7 = T.Variable(T.zeros((1,)), name="w")

print(g.variables)
# {'unnamed_variable': Variable(name=unnamed_variable, shape=(1,), dtype=float32, trainable=True, scope=/special/),
#  'unnamed_variable_1': Variable(name=unnamed_variable_1, shape=(1,), dtype=float32, trainable=True, scope=/special/)}

print(h.variables)
# {'unnamed_variable': Variable(name=unnamed_variable, shape=(1,), dtype=float32, trainable=True, scope=/special/inversion/),
#  'unnamed_variable_1': Variable(name=unnamed_variable_1, shape=(1,), dtype=float32, trainable=True, scope=/special/inversion/),
#  'w': Variable(name=w, shape=(1,), dtype=float32, trainable=True, scope=/special/inversion/)}

# fetch a single variable of the nested graph by name
print(h.variable("w"))