def create_fns(input, in_signs, Ds, alpha=0.1):
    """Build compiled functions for a leaky-ReLU MLP and its per-region affine maps.

    The network is a stack of ``T.matmul(w, h) + b`` layers whose hidden
    pre-activations are multiplied by a mask equal to 1 where the
    pre-activation is positive and ``alpha`` elsewhere. Two families of
    affine parameters are accumulated layer by layer:

    * ``A_w`` / ``B_w`` use the masks produced by forwarding ``input``;
    * ``A_q`` / ``B_q`` use masks derived from the sign vector ``in_signs``.

    Args:
        input: symbolic network input (used as the first feature map and as
            the argument of the compiled functions ``f`` and ``h``).
        in_signs: symbolic sign vector; it is concatenated after ``Ds[0]``
            ones and sliced per layer, so it is presumably the concatenation
            of all hidden-unit signs -- TODO confirm against callers.
        Ds: sequence of layer widths, ``Ds[0]`` being the input width.
        alpha: mask value used where a pre-activation (or sign) is not
            positive. Defaults to 0.1, the value that was previously
            hard-coded, so existing callers are unaffected.

    Returns:
        Tuple ``(f, g, h, all_g)`` of ``sj.function`` objects:
        ``f(input)`` -> ``[output, A_w[-1], B_w[-1], inequalities, signs]``,
        ``g(in_signs)`` -> ``[A_q[-1], B_q[-1]]``,
        ``h(input)`` -> network output,
        ``all_g(in_signs)`` -> ``inequalities_code``.
    """
    # Offsets of each layer's input units inside the concatenated mask vector.
    cumulative_units = np.concatenate([[0], np.cumsum(Ds[:-1])])

    # Weights are drawn before biases (kept in this order on purpose so the
    # initializer call sequence is unchanged).
    Ws = [sj.initializers.he((j, i)) for j, i in zip(Ds[1:], Ds[:-1])]
    bs = [sj.initializers.he((j,)) for j in Ds[1:]]

    A_w = [T.eye(Ds[0])]
    B_w = [T.zeros(Ds[0])]

    A_q = [T.eye(Ds[0])]
    B_q = [T.zeros(Ds[0])]

    maps = [input]
    signs = []
    masks = [T.ones(Ds[0])]
    # The input "layer" is never masked, hence the leading block of ones.
    in_masks = T.where(
        T.concatenate([T.ones(Ds[0]), in_signs]) > 0, 1., alpha)

    # Forward pass through the hidden layers, recording signs and masks.
    for w, b in zip(Ws[:-1], bs[:-1]):
        pre_activation = T.matmul(w, maps[-1]) + b
        signs.append(T.sign(pre_activation))
        masks.append(T.where(pre_activation > 0, 1., alpha))
        maps.append(pre_activation * masks[-1])

    # Final layer is purely affine.
    maps.append(T.matmul(Ws[-1], maps[-1]) + bs[-1])

    # compute per region A and B
    for start, end, w, b, m in zip(cumulative_units[:-1],
                                   cumulative_units[1:], Ws, bs, masks):
        # Each masked weight matrix is needed for both the slope and the
        # offset update, so build it once.
        w_input = w * m
        w_code = w * in_masks[start:end]

        A_w.append(T.matmul(w_input, A_w[-1]))
        B_w.append(T.matmul(w_input, B_w[-1]) + b)

        A_q.append(T.matmul(w_code, A_q[-1]))
        B_q.append(T.matmul(w_code, B_q[-1]) + b)

    signs = T.concatenate(signs)

    # Hidden-layer rows only: index 0 is the identity/zero seed and the last
    # entry is the output layer.
    ineq_b = T.concatenate(B_w[1:-1])
    ineq_A = T.vstack(A_w[1:-1])

    inequalities = T.hstack([ineq_b[:, None], ineq_A])
    # Orient each row by its sign and divide by the row norm of the slope part.
    inequalities = inequalities * signs[:, None] / T.linalg.norm(
        ineq_A, 2, 1, keepdims=True)

    inequalities_code = T.hstack(
        [T.concatenate(B_q[1:-1])[:, None], T.vstack(A_q[1:-1])])
    inequalities_code = inequalities_code * in_signs[:, None]

    f = sj.function(
        input, outputs=[maps[-1], A_w[-1], B_w[-1], inequalities, signs])
    g = sj.function(in_signs, outputs=[A_q[-1], B_q[-1]])
    all_g = sj.function(in_signs, outputs=inequalities_code)
    h = sj.function(input, outputs=maps[-1])
    return f, g, h, all_g
def RNTK_first(x, sw, su, sb, sh, L, Lf, sv):
    """Stack the per-layer GP and RNTK kernels built from a single input ``x``.

    Layer 0 is ``sh^2 * sw^2 * I + su^2 * (x x^T) + sb^2``. Each following
    hidden layer feeds the previous GP slice through ``VT`` (defined
    elsewhere; it returns a pair used here as ``S, D``), and a final slice
    scaled by ``sv^2`` is appended, giving ``L + 1`` slices in total.

    Returns:
        ``(rntk, gp)``: the two stacks, layers along the leading axis.
    """

    def _push(stack, kernel):
        # Append ``kernel`` as one more slice along the leading (layer) axis.
        return T.concatenate([stack, T.expand_dims(kernel, axis=0)])

    X = x * x[:, None]
    n = X.shape[0]

    first_kernel = sh ** 2 * sw ** 2 * T.eye(n, n) + (su ** 2) * X + sb ** 2
    gp_stack = T.expand_dims(first_kernel, axis=0)
    rntk_stack = gp_stack

    # Hidden layers 1 .. L-1.
    for layer in range(1, L):
        S_prev, D_prev = VT(gp_stack[layer - 1])
        gp_stack = _push(
            gp_stack,
            sh ** 2 * sw ** 2 * T.eye(n, n) + su ** 2 * S_prev + sb ** 2)
        # The recursive term is switched on by the comparison with Lf.
        rntk_stack = _push(
            rntk_stack,
            gp_stack[layer]
            + (Lf <= (layer - 1)) * su ** 2 * rntk_stack[layer - 1] * D_prev)

    # Output layer, scaled by sv.
    S_last, D_last = VT(gp_stack[L - 1])
    gp_stack = _push(gp_stack, sv ** 2 * S_last)
    rntk_stack = _push(
        rntk_stack,
        gp_stack[L] + (Lf != L) * sv ** 2 * rntk_stack[L - 1] * D_last)

    return rntk_stack, gp_stack
def make_boundary_condition(self, X):
    """Return the boundary-condition kernel for ``X``.

    The result is ``sh^2 * sw^2 * I_N + su^2 * X + sb^2`` where ``sh``,
    ``sw``, ``su``, ``sb`` and ``N`` are read from ``self``. The kernel is
    returned as-is (no leading axis is added).
    """
    identity_term = self.sh**2 * self.sw**2 * T.eye(self.N, self.N)
    input_term = (self.su**2) * X
    return identity_term + input_term + self.sb**2
def create_func_for_diag(self, dim1idx, dim2idx, function=False, jmode=False):
    """Scan the lambda/phi recursion along one diagonal.

    ``self.make_inputs(dim1idx, dim2idx, jmode=jmode)`` supplies the scanned
    sequence. The carry is a 2-slice tensor: slice 0 is lambda, slice 1 is
    phi; both start from the same boundary condition.

    Returns:
        The boundary condition followed by every scan output, either as a
        symbolic tensor or, when ``function`` is truthy, wrapped in a
        ``symjax.function`` taking the diagonal input.
    """
    diag = self.make_inputs(dim1idx, dim2idx, jmode=jmode)

    # Shared starting kernel, duplicated once for lambda and once for phi.
    start_kernel = (self.sh**2 * self.sw**2 * T.eye(self.n, self.n)
                    + (self.su**2) * self.X + self.sb**2)
    start_slice = T.expand_dims(start_kernel, axis=0)
    boundary_condition = T.concatenate([start_slice, start_slice])

    def _step(prev_vals, idx, Xph):
        # prev_vals[0] is the previous lambda, prev_vals[1] the previous phi.
        S, D = self.VT(prev_vals[0])
        new_lambda = self.sw**2 * S + self.su**2 * Xph + self.sb**2
        new_phi = new_lambda + self.sw**2 * prev_vals[1] * D
        stacked = T.concatenate([
            T.expand_dims(new_lambda, axis=0),
            T.expand_dims(new_phi, axis=0),
        ])
        # The same tensor is both the next carry and the emitted output.
        return stacked, stacked

    _, scanned = T.scan(_step,
                        init=boundary_condition,
                        sequences=[diag],
                        non_sequences=[self.X])

    expanded_ema = T.concatenate(
        [T.expand_dims(boundary_condition, axis=0), scanned])
    print(expanded_ema)

    if not function:
        return expanded_ema
    return symjax.function(diag, outputs=expanded_ema)
def RNTK_first_time_step(x, param):
    """Compute the first GP and RNTK kernels (t = 1), for both relu and erf.

    ``param`` must provide the keys ``"sigmaw"``, ``"sigmau"``, ``"sigmab"``
    and ``"sigmah"``. Both returned values are the same kernel.

    NOTE(review): ``m`` is not defined inside this function; it is presumably
    a module-level value -- confirm it exists where this is called.
    """
    sw, su, sb, sh = (
        param[key] for key in ("sigmaw", "sigmau", "sigmab", "sigmah"))

    X = x * x[:, None]
    print(X)

    size = X.shape[0]
    kernel = sh**2 * sw**2 * T.eye(size, size) + (su**2 / m) * X + sb**2
    # At the first time step the RNTK coincides with the GP kernel.
    return kernel, kernel
def create_func_for_diag(self):
    """Scan the lambda/phi recursion over all diagonal positions at once.

    The boundary condition (one slice for lambda, one for phi) is stored on
    ``self.boundary_condition`` and used as the initial carry. The scan runs
    over ``jnp.arange(0, sum(self.dim_lengths) - self.dim_num)`` and its
    stacked outputs are handed to ``self.compute_kernels``.
    """
    start_kernel = (self.sh**2 * self.sw**2 * T.eye(self.n, self.n)
                    + (self.su**2) * self.X + self.sb**2)
    start_slice = T.expand_dims(start_kernel, axis=0)
    boundary_condition = T.concatenate([start_slice, start_slice])
    self.boundary_condition = boundary_condition

    def _step(prev_vals, idx, Xph):
        # prev_vals[0] is the previous lambda, prev_vals[1] the previous phi.
        S, D = self.VT(prev_vals[0])
        new_lambda = self.sw**2 * S + self.su**2 * Xph + self.sb**2
        new_phi = new_lambda + self.sw**2 * prev_vals[1] * D
        stacked = T.concatenate([
            T.expand_dims(new_lambda, axis=0),
            T.expand_dims(new_phi, axis=0),
        ])
        # At the end of a diagonal the carry is reset to the boundary
        # condition; the emitted output is the freshly computed pair in
        # both cases.
        # NOTE(review): ``idx`` is the scanned element; whether a plain
        # Python ``in`` test behaves as intended depends on how T.scan
        # invokes this function -- confirm.
        next_carry = (boundary_condition
                      if idx in self.ends_of_calced_diags else stacked)
        return next_carry, stacked

    n_steps = sum(self.dim_lengths) - self.dim_num
    _, scanned = T.scan(_step,
                        init=boundary_condition,
                        sequences=[jnp.arange(0, n_steps)],
                        non_sequences=[self.X])

    return self.compute_kernels(scanned)
def generate_gaussian_filterbank(N, M, J, f0, f1, modes=1):
    """Create a learnable bank of 2-D Gaussian filters on an ``N x M`` grid.

    Centre frequencies come from ``get_scaled_freqs(f0, f1, J)`` (defined
    elsewhere). When ``modes > 1`` extra Gaussians are added whose centre
    frequencies are drawn at random from the first ``J``.

    Returns:
        ``(filters, mu, cor, sigma, mixing)`` where ``filters`` is the
        mixture reshaped to ``(J, modes, N, M)`` and summed over axis 1
        (kept as a singleton axis), and the remaining items are the
        ``T.Variable`` parameters.
    """
    # gaussian parameters
    freqs = get_scaled_freqs(f0, f1, J)
    freqs *= (J - 1) * 10
    if modes > 1:
        extra_idx = np.random.randint(0, J, J * (modes - 1))
        freqs = np.concatenate([freqs, freqs[extra_idx]])
    n_gaussians = J * modes

    # mean vectors: (centre frequency, small random second coordinate)
    mu = T.Variable(
        np.stack([freqs, 0.1 * np.random.randn(n_gaussians)],
                 1).astype('float32'),
        name='mu')

    # covariance parameters: correlation first, then per-axis scales
    cor = T.Variable(
        0.01 * np.random.randn(n_gaussians).astype('float32'), name='cor')
    sigma = T.Variable(
        np.stack([freqs / 6, 1. + 0.01 * np.random.randn(n_gaussians)],
                 1).astype('float32'),
        name='sigma')

    # mixing coefficients, one per mode
    mixing = T.Variable(np.ones((modes, 1, 1)).astype('float32'))

    # Parametrization: |sigma| + 0.1 on the diagonal, and a tanh-squashed
    # correlation (scaled by a gradient-stopped coefficient) off-diagonal.
    coeff = T.stop_gradient(T.sqrt((T.abs(sigma) + 0.1).prod(1))) * 0.95
    eye2 = T.eye(2)
    diagonal_part = eye2 * T.expand_dims((T.abs(sigma) + 0.1), 1)
    off_diagonal_part = T.flip(eye2, 0) * (T.tanh(cor) * coeff).reshape(
        (-1, 1, 1))
    cov_inv = T.linalg.inv(diagonal_part + off_diagonal_part)

    # Evaluation grid; each row stacks the second meshgrid output first,
    # then the first.
    time_axis = T.linspace(-5, 5, M)
    freq_axis = T.linspace(0, J * 10, N)
    grid_x, grid_y = T.meshgrid(time_axis, freq_axis)
    grid = T.stack([grid_y.flatten(), grid_x.flatten()], 1)
    centered = grid - T.expand_dims(mu, 1)

    # Gaussian response, normalised per filter before mixing.
    response = T.exp(-(T.matmul(centered, cov_inv)**2).sum(-1))
    scale = T.linalg.norm(response, 2, 1, keepdims=True)
    filters = T.abs(mixing) * T.reshape(response / scale, (J, modes, N, M))
    return filters.sum(1, keepdims=True), mu, cor, sigma, mixing
def RNTK_first(self, x):
    """Stack the per-layer GP (gamma) and RNTK (phi) kernels for input ``x``.

    Hyper-parameters (``sh``, ``sw``, ``su``, ``sb``, ``sv``, ``L``, ``Lf``)
    and the ``VT`` transform are read from ``self``. Intermediate stacks are
    printed for debugging.

    Returns:
        ``(rntk, gp)``: two stacks with ``self.L + 1`` slices along the
        leading axis.
    """

    def _push(stack, kernel):
        # Append ``kernel`` as one more slice along the leading (layer) axis.
        return T.concatenate([stack, T.expand_dims(kernel, axis=0)])

    # alg 1, line 1
    X = x * x[:, None]
    n = X.shape[0]

    # alg 1, line 2: first-layer kernel
    first_kernel = (self.sh ** 2 * self.sw ** 2 * T.eye(n, n)
                    + (self.su ** 2) * X + self.sb ** 2)
    gp_stack = T.expand_dims(first_kernel, axis=0)
    rntk_stack = gp_stack
    print("gp_new 1", gp_stack)
    print("rntk_new 1", rntk_stack)

    # alg 1, lines 3-4: hidden layers 1 .. L-1
    for layer in range(1, self.L):
        print("gp_new", gp_stack[layer - 1])
        S_prev, D_prev = self.VT(gp_stack[layer - 1])
        gp_stack = _push(
            gp_stack,
            self.sh ** 2 * self.sw ** 2 * T.eye(n, n)
            + self.su ** 2 * S_prev + self.sb ** 2)
        rntk_stack = _push(
            rntk_stack,
            gp_stack[layer]
            + (self.Lf <= (layer - 1)) * self.su ** 2
            * rntk_stack[layer - 1] * D_prev)

    # alg 1, line 5: output layer
    S_last, D_last = self.VT(gp_stack[self.L - 1])
    gp_stack = _push(gp_stack, self.sv ** 2 * S_last)
    rntk_stack = _push(
        rntk_stack,
        gp_stack[self.L]
        + (self.Lf != self.L) * self.sv ** 2
        * rntk_stack[self.L - 1] * D_last)

    print("gp_new 2", gp_stack)
    print("rntk_new 2", rntk_stack)
    return rntk_stack, gp_stack
def create_fns(input, in_signs, Ds, x, m0, m1, m2, batch_in_signs,
               alpha=0.1, sigma=1, sigma_x=1, lr=0.0002):
    """Build the forward, per-region and training functions of a leaky-ReLU MLP.

    Besides the single-input affine parameters (``A_w``/``B_w`` from the
    masks of ``input``, ``A_q``/``B_q`` from ``in_signs``) this variant also
    accumulates batched affine parameters from ``batch_in_signs`` and uses
    them, together with ``x``, ``m0``, ``m1`` and ``m2``, to build a loss
    that is minimised by the returned ``train_f``.

    Args:
        input: symbolic network input (argument of ``f`` and ``h``).
        in_signs: symbolic sign vector (argument of ``g`` and ``all_g``).
        Ds: sequence of layer widths, ``Ds[0]`` being the input width.
        x, m0, m1, m2: symbolic inputs of the loss. Their shapes are not
            visible here; the einsum subscripts below are the only
            constraint -- TODO confirm against callers.
        batch_in_signs: symbolic batch of sign vectors; its leading
            dimension is read as ``BS``.
        alpha: mask value used where a pre-activation / sign is not positive.
        sigma: multiplier applied to the initial weights and hidden biases.
        sigma_x: initial value of the ``log_sigma2`` variable.
        lr: third positional argument passed to the optimizer.

    Returns:
        ``(f, g, h, all_g, train_f, sigma2)`` -- four inference functions,
        the training function, and the symbolic ``exp(log_sigma2)``.
    """
    # Offsets of each layer's input units inside the concatenated mask vector.
    cumulative_units = np.concatenate([[0], np.cumsum(Ds[:-1])])
    BS = batch_in_signs.shape[0]

    # Trainable weights; hidden biases are random, the output bias starts
    # at zero.
    Ws = [
        T.Variable(sj.initializers.glorot((j, i)) * sigma)
        for j, i in zip(Ds[1:], Ds[:-1])
    ]
    bs = [T.Variable(sj.initializers.he((j,)) * sigma) for j in Ds[1:-1]]\
        + [T.Variable(T.zeros((Ds[-1],)))]

    # Seeds of the affine recursions: identity slope, zero offset.
    A_w = [T.eye(Ds[0])]
    B_w = [T.zeros(Ds[0])]

    A_q = [T.eye(Ds[0])]
    B_q = [T.zeros(Ds[0])]

    # Batched counterparts, one identity / zero vector per batch element.
    batch_A_q = [T.eye(Ds[0]) * T.ones((BS, 1, 1))]
    batch_B_q = [T.zeros((BS, Ds[0]))]

    maps = [input]
    signs = []
    masks = [T.ones(Ds[0])]

    # The input units are never masked, hence the leading block of ones.
    in_masks = T.where(T.concatenate([T.ones(Ds[0]), in_signs]) > 0, 1., alpha)
    batch_in_masks = T.where(
        T.concatenate([T.ones((BS, Ds[0])), batch_in_signs], 1) > 0, 1., alpha)

    # Forward pass through the hidden layers, recording signs and masks.
    for w, b in zip(Ws[:-1], bs[:-1]):
        pre_activation = T.matmul(w, maps[-1]) + b
        signs.append(T.sign(pre_activation))
        masks.append(T.where(pre_activation > 0, 1., alpha))
        maps.append(pre_activation * masks[-1])

    # Final layer is purely affine.
    maps.append(T.matmul(Ws[-1], maps[-1]) + bs[-1])

    # compute per region A and B
    for start, end, w, b, m in zip(cumulative_units[:-1], cumulative_units[1:],
                                   Ws, bs, masks):

        A_w.append(T.matmul(w * m, A_w[-1]))
        B_w.append(T.matmul(w * m, B_w[-1]) + b)

        A_q.append(T.matmul(w * in_masks[start:end], A_q[-1]))
        B_q.append(T.matmul(w * in_masks[start:end], B_q[-1]) + b)

        # Same recursion with a leading batch axis on the masks.
        batch_A_q.append(
            T.matmul(w * batch_in_masks[:, None, start:end], batch_A_q[-1]))
        # Batched matrix-vector product written as multiply-then-sum.
        batch_B_q.append((w * batch_in_masks[:, None, start:end]\
            * batch_B_q[-1][:, None, :]).sum(2) + b)

    # Only the end-to-end batched parameters are needed for the loss.
    batch_B_q = batch_B_q[-1]
    batch_A_q = batch_A_q[-1]

    signs = T.concatenate(signs)

    # Hidden-layer rows only ([1:-1] drops the seed and the output layer);
    # each row is [offset, slope...] oriented by the corresponding sign.
    inequalities = T.hstack(
        [T.concatenate(B_w[1:-1])[:, None],
         T.vstack(A_w[1:-1])]) * signs[:, None]

    inequalities_code = T.hstack(
        [T.concatenate(B_q[1:-1])[:, None],
         T.vstack(A_q[1:-1])]) * in_signs[:, None]

    #### loss
    # NOTE(review): log_sigma2 is not in the parameter list passed to the
    # optimizer below (only Ws + bs are) -- confirm this is intended.
    log_sigma2 = T.Variable(sigma_x)
    sigma2 = T.exp(log_sigma2)

    Am1 = T.einsum('qds,nqs->nqd', batch_A_q, m1)
    Bm0 = T.einsum('qd,nq->nd', batch_B_q, m0)
    B2m0 = T.einsum('nq,qd->n', m0, batch_B_q**2)
    AAm2 = T.einsum('qds,qdu,nqup->nsp', batch_A_q, batch_A_q, m2)

    inner = -(x * (Am1.sum(1) + Bm0)).sum(1) + (Am1 * batch_B_q).sum((1, 2))

    loss_2 = (x**2).sum(1) + B2m0 + T.trace(AAm2, axis1=1, axis2=2).squeeze()
    loss_z = T.trace(m2.sum(1), axis1=1, axis2=2).squeeze()

    cst = 0.5 * (Ds[0] + Ds[-1]) * T.log(2 * np.pi)

    loss = cst + 0.5 * Ds[-1] * log_sigma2 + inner / sigma2\
        + 0.5 * loss_2 / sigma2 + 0.5 * loss_z

    mean_loss = loss.mean()
    # Despite the variable name, this is a Nesterov-momentum optimizer.
    adam = sj.optimizers.NesterovMomentum(mean_loss, Ws + bs, lr, 0.9)

    train_f = sj.function(batch_in_signs, x, m0, m1, m2, outputs=mean_loss,
                          updates=adam.updates)
    f = sj.function(input, outputs=[maps[-1], A_w[-1], B_w[-1],
                                    inequalities, signs])
    g = sj.function(in_signs, outputs=[A_q[-1], B_q[-1]])
    all_g = sj.function(in_signs, outputs=inequalities_code)
    h = sj.function(input, outputs=maps[-1])
    return f, g, h, all_g, train_f, sigma2
sys.path.insert(0, "../")
import symjax as sj
import symjax.tensor as T
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')

###### 2D GAUSSIAN EXAMPLE
# Evaluate a 2-D normal density (zero mean, identity covariance as passed
# below) on a 5 x 5 grid spanning [-5, 5] in both coordinates.
t = T.linspace(-5, 5, 5)
x, y = T.meshgrid(t, t)
X = T.stack([x.flatten(), y.flatten()], 1)
p = T.pdfs.multivariate_normal.pdf(X, T.zeros(2), T.eye(2))
p = p.reshape((5, 5)).round(2)

print(p)
# Tensor(Op=round_, shape=(5, 5), dtype=float32)

# lazy evaluation (not compiled nor optimized)
print(p.get())
# [[0.   0.   0.   0.   0.  ]
#  [0.   0.   0.01 0.   0.  ]
#  [0.   0.01 0.16 0.01 0.  ]
#  [0.   0.   0.01 0.   0.  ]
#  [0.   0.   0.   0.   0.  ]]

# create the function which internally compiles and optimizes
# the function does not take any arguments and only outputs the
def create_fns(batch_size, R, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1,
               var_x=1):
    """Build every compiled function of the EM-trained piecewise-affine model.

    The function creates the placeholders and variables, forwards a single
    input through the network, derives the per-region affine maps for three
    sign sources (the input's own signs, a single sign vector and a batch of
    ``R`` sign vectors), builds the loss and the closed-form parameter
    updates, and finally returns everything in one dictionary.

    Depends on names that are not defined in this block: ``init_weights``,
    ``relu_mask``, ``get_Abs``, ``get_forward``, ``cov_W`` and ``cov_b``.

    Args:
        batch_size: leading dimension of the ``X``, ``m0``, ``m1``, ``m2``
            placeholders; also used as a normaliser in the slope update.
        R: number of sign vectors in ``SIGNS`` (second axis of ``m0``/``m1``/
            ``m2``).
        Ds: layer widths. ``Ds[0]`` sizes the ``x`` placeholder and
            ``Ds[-1]`` the ``X`` placeholder. It is concatenated with a list
            below, so it presumably has to be a list -- TODO confirm.
        seed, scaler: forwarded to ``init_weights``.
        leakiness: forwarded to ``relu_mask``.
        lr: only stored in the returned ``'kwargs'`` entry.
            NOTE(review): the SGD optimizer below is given the literal
            ``0.001`` instead of ``lr`` -- confirm whether that is intended.
        var_x: initial scale of the ``var_x`` variable (the name is rebound
            to that variable inside the function).

    Returns:
        dict mapping names to ``sj.function`` objects (``'train'``,
        ``'update_var'``, ``'update_vs'``, ``'update_Ws'``, ``'loss'``,
        ``'get_nll'``, ``'g'``, ``'input2all'``, ``'input2signs'``,
        ``'signs2Ab'``, ``'signs2ineq'``, ``'assign'``, ``'varx'``,
        ``'varz'``, ``'prior'``, ``'params'``), a ``'sample'`` callable, and
        the metadata entries ``'S'``, ``'D'``, ``'R'``, ``'model'``, ``'L'``
        and ``'kwargs'``.
    """

    # Mixing factor between the old and the newly solved parameter value.
    alpha = T.Placeholder((1,), 'float32')
    x = T.Placeholder((Ds[0],), 'float32')
    X = T.Placeholder((batch_size, Ds[-1]), 'float32')
    signs = T.Placeholder((np.sum(Ds[1:-1]),), 'float32')
    SIGNS = T.Placeholder((R, np.sum(Ds[1:-1])), 'float32')

    m0 = T.Placeholder((batch_size, R), 'float32')
    m1 = T.Placeholder((batch_size, R, Ds[0]), 'float32')
    m2 = T.Placeholder((batch_size, R, Ds[0], Ds[0]), 'float32')

    Ws, vs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w, name='W' + str(l)) for l, w in enumerate(Ws)]
    vs = [T.Variable(v, name='v' + str(l)) for l, v in enumerate(vs)]

    var_x = T.Variable(T.ones(Ds[-1]) * var_x)
    var_z = T.Variable(T.ones(Ds[0]))

    # create the placeholders
    # (used by the 'assign' function to overwrite every parameter at once)
    Ws_ph = [T.Placeholder(w.shape, w.dtype) for w in Ws]
    vs_ph = [T.Placeholder(v.shape, v.dtype) for v in vs]
    var_x_ph = T.Placeholder(var_x.shape, var_x.dtype)

    ############################################################################
    # Compute the output of g(x)
    ############################################################################

    maps = [x]
    xsigns = []
    masks = []

    for w, v in zip(Ws[:-1], vs[:-1]):

        pre_activation = T.matmul(w, maps[-1]) + v
        xsigns.append(T.sign(pre_activation))
        masks.append(relu_mask(pre_activation, leakiness))
        maps.append(pre_activation * masks[-1])

    xsigns = T.concatenate(xsigns)
    # Final layer is purely affine.
    maps.append(T.matmul(Ws[-1], maps[-1]) + vs[-1])

    ############################################################################
    # compute the masks and then the per layer affine mappings
    ############################################################################

    # Offsets used to slice the concatenated sign vectors layer by layer.
    cumulative_units = np.cumsum([0] + Ds[1:])
    xqs = relu_mask([xsigns[None, cumulative_units[i]:cumulative_units[i + 1]]
                     for i in range(len(Ds) - 2)], leakiness)
    qs = relu_mask([signs[None, cumulative_units[i]:cumulative_units[i + 1]]
                    for i in range(len(Ds) - 2)], leakiness)
    Qs = relu_mask([SIGNS[:, cumulative_units[i]:cumulative_units[i + 1]]
                    for i in range(len(Ds) - 2)], leakiness)

    # One (A, b) list per sign source: input signs, single code, batch of codes.
    Axs, bxs = get_Abs(Ws, vs, xqs)
    Aqs, bqs = get_Abs(Ws, vs, qs)
    AQs, bQs = get_Abs(Ws, vs, Qs)

    all_bxs = T.hstack(bxs[:-1]).transpose()
    all_Axs = T.hstack(Axs[:-1])[0]

    all_bqs = T.hstack(bqs[:-1]).transpose()
    all_Aqs = T.hstack(Aqs[:-1])[0]

    x_inequalities = T.hstack([all_Axs, all_bxs]) * xsigns[:, None]
    q_inequalities = T.hstack([all_Aqs, all_bqs]) * signs[:, None]

    ############################################################################
    # loss (E-step NLL)
    ############################################################################

    Bm0 = T.einsum('nd,Nn->Nd', bQs[-1], m0)
    B2m0 = T.einsum('nd,Nn->Nd', bQs[-1] ** 2, m0)
    Am1 = T.einsum('nds,Nns->Nd', AQs[-1], m1)
    ABm1 = T.einsum('nds,nd,Nns->Nd', AQs[-1], bQs[-1], m1)
    Am2ATdiag = T.diagonal(T.einsum('nds,Nnsc,npc->Ndp', AQs[-1], m2, AQs[-1]),
                           axis1=1, axis2=2)
    xAm1Bm0 = X * (Am1 + Bm0)

    M2diag = T.diagonal(m2.sum(1), axis1=1, axis2=2)

    # The last bias (vs[-1]) is excluded from the prior term.
    prior = sum([T.mean(w**2) for w in Ws], 0.) / cov_W\
        + sum([T.mean(v**2) for v in vs[:-1]], 0.) / cov_b
    loss = - 0.5 * (T.log(var_x).sum() + T.log(var_z).sum()\
        + (M2diag / var_z).sum(1).mean() + ((X ** 2 - 2 * xAm1Bm0 + B2m0\
        + Am2ATdiag + 2 * ABm1) / var_x).sum(1).mean())

    # Quantity that is minimised and reported by every function below.
    mean_loss = - (loss + 0.5 * prior)
    # Despite the variable name this is plain SGD with a literal step size.
    adam = sj.optimizers.SGD(mean_loss, 0.001, params=Ws + vs)

    ############################################################################
    # update of var_x
    ############################################################################

    update_varx = (X ** 2 - 2 * xAm1Bm0 + B2m0 + Am2ATdiag + 2 * ABm1).mean()\
        * T.ones(Ds[-1])
    # NOTE(review): update_varz is computed but not used by any function
    # built below.
    update_varz = M2diag.mean() * T.ones(Ds[0])

    ############################################################################
    # update for biases IT IS DONE FOR ISOTROPIC COVARIANCE MATRIX
    ############################################################################

    FQ = get_forward(Ws, Qs)
    update_vs = {}
    for i in range(len(vs)):

        if i < len(vs) - 1:
            # now we forward each bias to the x-space except the ith
            separated_bs = bQs[-1] - T.einsum('nds,s->nd', FQ[i], vs[i])
            # compute the residual and apply sigma
            residual = (X[:, None, :] - separated_bs) * m0[:, :, None]\
                - T.einsum('nds,Nns->Nnd', AQs[-1], m1)
            back_error = T.einsum('nds,nd->s', FQ[i], residual.mean(0))
            # Regularised normal-equation matrix for the ith bias.
            whiten = T.einsum('ndc,nds,n->cs', FQ[i] , FQ[i], m0.mean(0))\
                + T.eye(back_error.shape[0]) / (Ds[i] * cov_b)
            update_vs[vs[i]] = T.linalg.solve(whiten, back_error)
        else:
            # Last bias: averaged residual, no linear solve.
            back_error = (X - (Am1 + Bm0) + vs[-1])
            update_vs[vs[i]] = back_error.mean(0)

    ############################################################################
    # update for slopes IT IS DONE FOR ISOTROPIC COVARIANCE MATRIX
    ############################################################################

    update_Ws = {}
    for i in range(len(Ws)):

        U = T.einsum('nds,ndc->nsc', FQ[i], FQ[i])
        if i == 0:
            V = m2.mean(0)
        else:
            V1 = T.einsum('nd,nq,Nn->ndq', bQs[i-1], bQs[i-1], m0)
            V2 = T.einsum('nds,nqc,Nnsc->ndq', AQs[i-1], AQs[i-1], m2)
            V3 = T.einsum('nds,nq,Nns->ndq', AQs[i-1], bQs[i-1], m1)
            Q = T.einsum('nd,nq->ndq', Qs[i - 1], Qs[i - 1])
            V = Q * (V1 + V2 + V3 + V3.transpose((0, 2, 1))) / batch_size

        # Sum of per-region Kronecker products, then a ridge term.
        whiten = T.stack([T.kron(U[n], V[n]) for n in range(V.shape[0])]).sum(0)
        whiten = whiten + T.eye(whiten.shape[-1]) / (Ds[i]*Ds[i+1]*cov_W)
        # compute the residual (bottom up)
        if i == len(Ws) - 1:
            bottom_up = (X[:, None, :] - vs[-1])
        else:
            if i == 0:
                residual = (X[:, None, :] - bQs[-1])
            else:
                residual = (X[:, None, :] - bQs[-1]\
                    + T.einsum('nds,ns->nd', FQ[i - 1], bQs[i-1]))
            bottom_up = T.einsum('ndc,Nnd->Nnc', FQ[i], residual)

        # compute the top down vector
        if i == 0:
            top_down = m1
        else:
            top_down = Qs[i - 1] * (T.einsum('nds,Nns->Nnd', AQs[i - 1], m1) +\
                T.einsum('nd,Nn->Nnd', bQs[i - 1], m0))

        vector = T.einsum('Nnc,Nns->cs', bottom_up, top_down) / batch_size
        # NOTE(review): condition is computed but never used afterwards.
        condition = T.diagonal(whiten)
        update_W = T.linalg.solve(whiten, vector.reshape(-1)).reshape(Ws[i].shape)
        update_Ws[Ws[i]] = update_W

    ############################################################################
    # create the io functions
    ############################################################################

    params = sj.function(outputs = Ws + vs + [var_x])
    # ll selects which layer is updated; all other layers keep their value.
    ll = T.Placeholder((), 'int32')
    selector = T.one_hot(ll, len(vs))
    for i in range(len(vs)):
        update_vs[vs[i]] = ((1 - alpha) * vs[i] + alpha * update_vs[vs[i]])\
            * selector[i] + vs[i] * (1 - selector[i])
    for i in range(len(Ws)):
        update_Ws[Ws[i]] = ((1 - alpha) * Ws[i] + alpha * update_Ws[Ws[i]])\
            * selector[i] + Ws[i] * (1 - selector[i])

    # 'loss' and 'get_nll' are built from the same inputs and output.
    output = {'train':sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                  updates=adam.updates),
              'update_var':sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                       updates = {var_x: update_varx}),
              'update_vs':sj.function(alpha, ll, SIGNS, X, m0, m1, m2,
                                      outputs=mean_loss, updates = update_vs),
              'loss':sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss),
              'update_Ws':sj.function(alpha, ll, SIGNS, X, m0, m1, m2,
                                      outputs=mean_loss, updates = update_Ws),
              'signs2Ab': sj.function(signs, outputs=[Aqs[-1][0], bqs[-1][0]]),
              'signs2ineq': sj.function(signs, outputs=q_inequalities),
              'g': sj.function(x, outputs=maps[-1]),
              'input2all': sj.function(x, outputs=[maps[-1], Axs[-1][0],
                                                   bxs[-1][0], x_inequalities,
                                                   xsigns]),
              'get_nll': sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss),
              'assign': sj.function(*Ws_ph, *vs_ph, var_x_ph,
                                    updates=dict(zip(Ws + vs + [var_x],
                                                     Ws_ph + vs_ph + [var_x_ph]))),
              'varx': sj.function(outputs=var_x),
              'prior': sj.function(outputs=prior),
              'varz': sj.function(outputs=var_z),
              'params': params,
              # 'probed' : sj.function(SIGNS, X, m0, m1, m2, outputs=probed),
              'input2signs': sj.function(x, outputs=xsigns),
              'S' : Ds[0], 'D': Ds[-1], 'R': R, 'model': 'EM', 'L':len(Ds)-1,
              'kwargs': {'batch_size': batch_size, 'Ds':Ds, 'seed':seed,
                         'leakiness':leakiness, 'lr':lr, 'scaler':scaler}}

    def sample(n):
        """Return ``n`` network outputs for standard-normal inputs of size Ds[0]."""
        samples = []
        for i in range(n):
            samples.append(output['g'](np.random.randn(Ds[0])))
        return np.array(samples)

    output['sample'] = sample

    return output