def __init__(
    self,
    sequence,
    init_h,
    units,
    W=initializers.glorot_uniform,
    H=initializers.orthogonal,
    b=T.zeros,
    trainable_W=True,
    trainable_H=True,
    trainable_b=True,
    activation=nn.sigmoid,
    only_last=False,
):
    W = create_variable("W", W, (sequence.shape[2], units), trainable=trainable_W)
    H = create_variable("H", H, (units, units), trainable=trainable_H)
    b = create_variable("b", b, (units,), trainable=trainable_b)
    last, output = T.scan(
        lambda h, x, W, H, b: RNN.gate(h, x, W, H, b, activation),
        init=init_h,
        sequences=[sequence.transpose((1, 0, 2))],
        non_sequences=[W, H, b],
    )
    if only_last:
        return last
    else:
        return output.transpose((1, 0, 2))
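# Hypothetical usage sketch (not from the source): this assumes the
# constructor above is symjax.nn.layers.RNN and that instantiating the
# layer yields its output tensor of shape (batch, time, units).
import symjax
import symjax.tensor as T
from symjax import nn

seq = T.Placeholder((4, 20, 8), "float32")  # (batch, time, features)
h0 = T.zeros((4, 16))                       # initial hidden state
out = nn.layers.RNN(seq, h0, units=16)
f = symjax.function(seq, outputs=out)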
def forward(
    self,
    sequence,
    init_h,
    units,
    W=initializers.he,
    H=initializers.he,
    b=T.zeros,
    trainable_W=True,
    trainable_H=True,
    trainable_b=True,
    activation=nn.sigmoid,
    only_last=False,
):
    self.create_variable("W", W, (sequence.shape[2], units), trainable=trainable_W)
    self.create_variable("H", H, (units, units), trainable=trainable_H)
    self.create_variable("b", b, (units,), trainable=trainable_b)
    last, output = T.scan(
        lambda h, x, W, H, b: self.gate(h, x, W, H, b, activation),
        init=init_h,
        sequences=[sequence.transpose((1, 0, 2))],
        non_sequences=[self.W, self.H, self.b],
    )
    if only_last:
        return last
    else:
        return output.transpose((1, 0, 2))
def compute_kernels(self, final_ema):
    diag_ends = self.get_ends_of_diags(final_ema)
    S_init, D_init = self.VT(diag_ends[0][0])
    init_Kappa = self.sv**2 * S_init
    init_Theta = init_Kappa + self.sv**2 * diag_ends[0][1] * D_init
    init_list = T.concatenate(
        [T.expand_dims(init_Kappa, axis=0), T.expand_dims(init_Theta, axis=0)]
    )

    def map_test(gp_rntk_sum, gp_rntk):
        S, D = self.VT(gp_rntk[0])
        ret1 = self.sv**2 * S
        ret2 = ret1 + self.sv**2 * gp_rntk[1] * D
        gp_rntk_sum = T.index_add(gp_rntk_sum, 0, ret1)
        gp_rntk_sum = T.index_add(gp_rntk_sum, 1, ret2)
        return gp_rntk_sum, gp_rntk_sum

    final_K_T, inter_results = T.scan(
        map_test, init=init_list, sequences=[diag_ends[1:]]
    )
    return final_K_T
def RNTK_function(self):
    print(f"N, {self.N}, length, {self.length}")
    DATA = T.Placeholder((self.N, self.length), "float32")
    RNTK, GP = self.RNTK_first(DATA[:, 0])
    v, _ = T.scan(
        lambda a, b: self.RNTK_middle(a, b),
        sequences=[T.transpose(DATA[:, 1:])],
        init=T.stack([RNTK, GP]),
    )
    RNTK_last, RNTK_avg = self.RNTK_output(v)
    f = symjax.function(DATA, outputs=[RNTK_last, RNTK_avg])
    return RNTK_last, RNTK_avg
def RNTK_function(N, length, param):
    DATA = T.Placeholder((N, length), "float32")
    RNTK, GP = RNTK_first(
        DATA[:, 0],
        param["sigmaw"],
        param["sigmau"],
        param["sigmab"],
        param["sigmah"],
        param["L"],
        param["Lf"],
        param["sigmav"],
    )
    v, _ = T.scan(
        lambda a, b: RNTK_middle(
            a,
            b,
            param["sigmaw"],
            param["sigmau"],
            param["sigmab"],
            param["L"],
            param["Lf"],
            param["sigmav"],
        ),
        sequences=[T.transpose(DATA[:, 1:])],
        init=T.stack([RNTK, GP]),
    )
    RNTK_last, RNTK_avg = RNTK_output(
        v, param["sigmav"], param["L"], param["Lf"], length
    )
    f = symjax.function(DATA, outputs=[RNTK_last, RNTK_avg])
    return f
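# Usage sketch with hypothetical hyperparameter values: RNTK_function
# compiles the graph once, and the returned callable evaluates the two
# kernel outputs (RNTK_last, RNTK_avg) on a batch of N time series of
# length `length`. Assumes numpy is available as np.
import numpy as np

param = {
    "sigmaw": 1.0, "sigmau": 1.0, "sigmab": 0.1,
    "sigmah": 0.5, "sigmav": 1.0, "L": 1, "Lf": 0,
}
kernel_fn = RNTK_function(N=4, length=16, param=param)
data = np.random.randn(4, 16).astype("float32")
RNTK_last, RNTK_avg = kernel_fn(data)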
def create_func_for_diag(self, dim1idx, dim2idx, function=False, jmode=False):
    diag = self.make_inputs(dim1idx, dim2idx, jmode=jmode)

    ## prev_vals - (2,1) - previous phi and lambda values
    ## idx - where we are on the diagonal
    ## d1idx - y value of first dimension diag start
    ## d2idx - x value of second dimension diag start
    ## d1ph - max value of first dimension
    ## d2ph - max value of second dimension
    bc = (
        self.sh**2 * self.sw**2 * T.eye(self.n, self.n)
        + self.su**2 * self.X
        + self.sb**2
    )
    single_boundary_condition = T.expand_dims(bc, axis=0)
    # one copy each for phi and lambda
    boundary_condition = T.concatenate(
        [single_boundary_condition, single_boundary_condition]
    )

    def fn(prev_vals, idx, Xph):
        ## change - Xph must now index the dataset instead of being passed in
        prev_lambda = prev_vals[0]
        prev_phi = prev_vals[1]
        # not the boundary condition
        S, D = self.VT(prev_lambda)
        new_lambda = self.sw**2 * S + self.su**2 * Xph + self.sb**2
        new_phi = new_lambda + self.sw**2 * prev_phi * D
        lambda_expanded = T.expand_dims(new_lambda, axis=0)
        phi_expanded = T.expand_dims(new_phi, axis=0)
        to_return = T.concatenate([lambda_expanded, phi_expanded])
        return to_return, to_return

    last_ema, all_ema = T.scan(
        fn, init=boundary_condition, sequences=[diag], non_sequences=[self.X]
    )
    # prepend the boundary condition so the output covers the full diagonal
    expanded_ema = T.concatenate(
        [T.expand_dims(boundary_condition, axis=0), all_ema]
    )
    if function:
        f = symjax.function(diag, outputs=expanded_ema)
        return f
    else:
        return expanded_ema
def create_func_for_diag(self):
    ## change - bc should be a function that takes a passed-in X?
    bc = (
        self.sh**2 * self.sw**2 * T.eye(self.n, self.n)
        + self.su**2 * self.X
        + self.sb**2
    )
    single_boundary_condition = T.expand_dims(bc, axis=0)
    boundary_condition = T.concatenate(
        [single_boundary_condition, single_boundary_condition]
    )
    self.boundary_condition = boundary_condition

    ## prev_vals - (2,1) - previous phi and lambda values
    ## idx - where we are on the diagonal
    def fn(prev_vals, idx, Xph):
        prev_lambda = prev_vals[0]
        prev_phi = prev_vals[1]
        # not the boundary condition
        S, D = self.VT(prev_lambda)
        new_lambda = self.sw**2 * S + self.su**2 * Xph + self.sb**2
        new_phi = new_lambda + self.sw**2 * prev_phi * D
        lambda_expanded = T.expand_dims(new_lambda, axis=0)
        phi_expanded = T.expand_dims(new_phi, axis=0)
        to_return = T.concatenate([lambda_expanded, phi_expanded])
        if idx in self.ends_of_calced_diags:
            # at the end of a diagonal, reset the carry to the boundary condition
            return boundary_condition, to_return
        return to_return, to_return

    last_ema, all_ema = T.scan(
        fn,
        init=boundary_condition,
        sequences=[jnp.arange(0, sum(self.dim_lengths) - self.dim_num)],
        non_sequences=[self.X],
    )
    return self.compute_kernels(all_ema)
def compute_q(self, DATA):
    DATAT = T.transpose(DATA)
    xz = DATAT[0]
    # GP variance recursion: q_0 = su^2 * ||x_0|| + sb^2 + sh^2
    # (variances are squared throughout, matching sw**2, sb**2, sh**2 elsewhere)
    init = (
        self.su**2 * T.linalg.norm(T.expand_dims(xz, axis=0), ord=2, axis=0)
        + self.sb**2
        + self.sh**2
    )

    def scan_func(prevq, MINIDATAT):
        # MINIDATAT should be a vector of length N
        # the trick here is to use the original VT
        S = self.alg1_VT(prevq)  # -> M is K3
        # q_{t+1} = sw^2 * S + su^2 * ||x_t|| + sb^2
        newq = (
            self.sw**2 * S
            + self.su**2
            * T.linalg.norm(T.expand_dims(MINIDATAT, axis=0), ord=2, axis=0)
            + self.sb**2
        )
        return newq, newq

    last_ema, all_ema = T.scan(scan_func, init=init, sequences=[DATAT[1:]])
    return T.concatenate([T.expand_dims(init, axis=0), all_ema])
# Assumed setup (not in the original snippet): imports plus placeholders
# for the signal and the EMA coefficient.
import matplotlib.pyplot as plt
import numpy as np
import symjax
import symjax.tensor as T

signal = T.Placeholder((512,), "float32")
alpha = T.Placeholder((), "float32")

# this function outputs the new value of the carry as well as an additional
# output; in our case the carry (the EMA) is also what we want to output at
# each time step
def fn(at, xt, alpha):
    # the first input is the carry, then come the (ordered) values from
    # sequences and non_sequences, similar to Theano
    EMA = at * alpha + (1 - alpha) * xt
    return EMA, EMA

# T.scan returns the last carry (first output) and the carry stacked over
# all time steps (second output); we also need to provide an init
last_ema, all_ema = T.scan(
    fn, init=signal[0], sequences=[signal[1:]], non_sequences=[alpha]
)
f = symjax.function(signal, alpha, outputs=all_ema)

# generate a noisy signal and plot its EMA for several smoothing factors
x = np.cos(np.linspace(-3, 3, 512)) + np.random.randn(512) * 0.2
fig, ax = plt.subplots(3, 1, figsize=(3, 9))
for k, alpha in enumerate([0.1, 0.5, 0.9]):
    ax[k].plot(x, c="b")
    ax[k].plot(f(x, alpha), c="r")
    ax[k].set_title("EMA: {}".format(alpha))
    ax[k].set_xticks([])
plt.show()
def __init__(
    self,
    sequence,
    init_h,
    units,
    Wf=initializers.glorot_uniform,
    Uf=initializers.orthogonal,
    bf=T.zeros,
    Wi=initializers.glorot_uniform,
    Ui=initializers.orthogonal,
    bi=T.zeros,
    Wo=initializers.glorot_uniform,
    Uo=initializers.orthogonal,
    bo=T.zeros,
    Wc=initializers.glorot_uniform,
    Uc=initializers.orthogonal,
    bc=T.zeros,
    trainable_Wf=True,
    trainable_Uf=True,
    trainable_bf=True,
    trainable_Wi=True,
    trainable_Ui=True,
    trainable_bi=True,
    trainable_Wo=True,
    trainable_Uo=True,
    trainable_bo=True,
    trainable_Wc=True,
    trainable_Uc=True,
    trainable_bc=True,
    activation_g=nn.sigmoid,
    activation_c=T.tanh,
    activation_h=T.tanh,
    only_last=False,
    gate="minimal",
):
    self.create_variable("Wf", Wf, (sequence.shape[2], units), trainable=trainable_Wf)
    self.create_variable("Uf", Uf, (units, units), trainable=trainable_Uf)
    self.create_variable("bf", bf, (units,), trainable=trainable_bf)
    self.create_variable("Wi", Wi, (sequence.shape[2], units), trainable=trainable_Wi)
    self.create_variable("Ui", Ui, (units, units), trainable=trainable_Ui)
    self.create_variable("bi", bi, (units,), trainable=trainable_bi)
    self.create_variable("Wo", Wo, (sequence.shape[2], units), trainable=trainable_Wo)
    self.create_variable("Uo", Uo, (units, units), trainable=trainable_Uo)
    self.create_variable("bo", bo, (units,), trainable=trainable_bo)
    self.create_variable("Wc", Wc, (sequence.shape[2], units), trainable=trainable_Wc)
    self.create_variable("Uc", Uc, (units, units), trainable=trainable_Uc)
    self.create_variable("bc", bc, (units,), trainable=trainable_bc)

    def fn(*args):
        return self.gate(*args, activation_g, activation_c, activation_h)

    # the carry stacks the hidden state and the (initially zero) cell state
    init = T.stack((init_h, T.zeros(init_h.shape, init_h.dtype)))
    last, output = T.scan(
        fn,
        init=init,
        sequences=[sequence.transpose((1, 0, 2))],
        non_sequences=[
            self.Wf, self.Uf, self.bf,
            self.Wi, self.Ui, self.bi,
            self.Wo, self.Uo, self.bo,
            self.Wc, self.Uc, self.bc,
        ],
    )
    if only_last:
        return last
    else:
        return output.transpose((1, 0, 2))
def __init__(
    self,
    sequence,
    init_h,
    units,
    Wh=initializers.glorot_uniform,
    Uh=initializers.orthogonal,
    bh=T.zeros,
    Wz=initializers.glorot_uniform,
    Uz=initializers.orthogonal,
    bz=T.zeros,
    Wr=initializers.glorot_uniform,
    Ur=initializers.orthogonal,
    br=T.zeros,
    trainable_Wh=True,
    trainable_Uh=True,
    trainable_bh=True,
    trainable_Wz=True,
    trainable_Uz=True,
    trainable_bz=True,
    trainable_Wr=True,
    trainable_Ur=True,
    trainable_br=True,
    activation=nn.sigmoid,
    phi=T.tanh,
    only_last=False,
    gate="minimal",
):
    Wh = create_variable("Wh", Wh, (sequence.shape[2], units), trainable=trainable_Wh)
    Uh = create_variable("Uh", Uh, (units, units), trainable=trainable_Uh)
    bh = create_variable("bh", bh, (units,), trainable=trainable_bh)
    Wz = create_variable("Wz", Wz, (sequence.shape[2], units), trainable=trainable_Wz)
    Uz = create_variable("Uz", Uz, (units, units), trainable=trainable_Uz)
    bz = create_variable("bz", bz, (units,), trainable=trainable_bz)
    if gate == "full":
        Wr = create_variable("Wr", Wr, (sequence.shape[2], units), trainable=trainable_Wr)
        Ur = create_variable("Ur", Ur, (units, units), trainable=trainable_Ur)
        br = create_variable("br", br, (units,), trainable=trainable_br)

    if gate == "minimal":
        def fn(*args):
            return GRU.minimal_gate(*args, activation, phi)

        last, output = T.scan(
            fn,
            init=init_h,
            sequences=[sequence.transpose((1, 0, 2))],
            non_sequences=[Wh, Uh, bh, Wz, Uz, bz],
        )
    elif gate == "full":
        def fn(*args):
            return GRU.full_gate(*args, activation, phi)

        last, output = T.scan(
            fn,
            init=init_h,
            sequences=[sequence.transpose((1, 0, 2))],
            non_sequences=[Wh, Uh, bh, Wz, Uz, bz, Wr, Ur, br],
        )

    if only_last:
        return last
    else:
        return output.transpose((1, 0, 2))
import sys

sys.path.insert(0, "../")

import numpy as np
import symjax
import symjax.tensor as T

# map
xx = T.ones(10)
a = T.map(lambda a: a * 2, xx)
g = symjax.gradients(a.sum(), xx)[0]
f = symjax.function(outputs=[a, g])

# scan
xx = T.ones(10) * 2
a = T.scan(lambda c, x: (c * x, c * x), T.ones(1), xx)
g = symjax.gradients(a[1][-1], xx)[0]
f = symjax.function(outputs=[a, g])

# scan with updates
xx = T.range(5)
uu = T.ones((10, 2))
vvar = T.Variable(T.zeros((10, 2)))
vv = T.index_add(vvar, 1, 1)
a = T.scan(lambda c, x, p: (T.index_update(c, x, p[x]), 1), vv, xx, [vv])
f = symjax.function(outputs=a[0], updates={vvar: vvar + 1})
print(f(), f(), f())
def forward(
    self,
    sequence,
    init_h,
    units,
    Wh=initializers.he,
    Uh=initializers.he,
    bh=T.zeros,
    Wz=initializers.he,
    Uz=initializers.he,
    bz=T.zeros,
    Wr=initializers.he,
    Ur=initializers.he,
    br=T.zeros,
    trainable_Wh=True,
    trainable_Uh=True,
    trainable_bh=True,
    trainable_Wz=True,
    trainable_Uz=True,
    trainable_bz=True,
    trainable_Wr=True,
    trainable_Ur=True,
    trainable_br=True,
    activation=nn.sigmoid,
    phi=T.tanh,
    only_last=False,
    gate="minimal",
):
    self.create_variable("Wh", Wh, (sequence.shape[2], units), trainable=trainable_Wh)
    self.create_variable("Uh", Uh, (units, units), trainable=trainable_Uh)
    self.create_variable("bh", bh, (units,), trainable=trainable_bh)
    self.create_variable("Wz", Wz, (sequence.shape[2], units), trainable=trainable_Wz)
    self.create_variable("Uz", Uz, (units, units), trainable=trainable_Uz)
    self.create_variable("bz", bz, (units,), trainable=trainable_bz)
    if gate == "full":
        self.create_variable("Wr", Wr, (sequence.shape[2], units), trainable=trainable_Wr)
        self.create_variable("Ur", Ur, (units, units), trainable=trainable_Ur)
        self.create_variable("br", br, (units,), trainable=trainable_br)

    if gate == "minimal":
        def fn(h, x, Wh, Uh, bh, Wz, Uz, bz):
            return self.minimal_gate(h, x, Wh, Uh, bh, Wz, Uz, bz, activation, phi)

        non_sequences = [self.Wh, self.Uh, self.bh, self.Wz, self.Uz, self.bz]
    elif gate == "full":
        def fn(h, x, Wh, Uh, bh, Wz, Uz, bz, Wr, Ur, br):
            return self.full_gate(
                h, x, Wh, Uh, bh, Wz, Uz, bz, Wr, Ur, br, activation, phi
            )

        non_sequences = [
            self.Wh, self.Uh, self.bh,
            self.Wz, self.Uz, self.bz,
            self.Wr, self.Ur, self.br,
        ]

    last, output = T.scan(
        fn,
        init=init_h,
        sequences=[sequence.transpose((1, 0, 2))],
        non_sequences=non_sequences,
    )
    if only_last:
        return last
    else:
        return output.transpose((1, 0, 2))
def create_func_for_diag(self):
    NEW_DATA = self.reorganize_data()
    NEW_DATA_ATTACHED = jnp.array(list(zip(NEW_DATA[:-1], NEW_DATA[1:])))

    x = self.DATA[:, 0]
    X = x * x[:, None]
    boundary_condition = self.make_boundary_condition(X)

    # create the initial kernels - should be starting with the top right
    temp_K, temp_theta = self.compute_kernels(
        self.qt, self.qtprime, 0, self.dim_1,
        boundary_condition, boundary_condition,
        T.empty((self.N, self.N)), T.empty((self.N, self.N)),
    )
    init_K, init_theta = self.compute_kernels(
        self.qt, self.qtprime, 0, self.dim_1 - 1,
        boundary_condition, boundary_condition,
        temp_K, temp_theta,
    )
    initial_conditions = create_T_list(
        [boundary_condition, boundary_condition, init_K, init_theta]
    )

    ## prev_vals - (4, self.N, self.N) - previous phi, lambda, and the two kernel values
    ## idx - where we are on the diagonal
    def fn(prev_vals, idx, data_idxs, DATAPH, DATAPRIMEPH, qtph, qtprimeph):
        xTP = DATAPRIMEPH[data_idxs[0][0]]  # N1
        xT = DATAPH[data_idxs[0][1]]  # N2
        xINNER = T.inner(xT, xTP)  # N1 x N2
        prev_lambda = prev_vals[0]
        prev_phi = prev_vals[1]
        prev_K = prev_vals[2]
        prev_theta = prev_vals[3]
        # not the boundary condition
        S, D = self.alg2_VT(
            qtph[data_idxs[0][1] - 1], qtprimeph[data_idxs[0][0] - 1], prev_lambda
        )
        new_lambda = self.sw**2 * S + self.su**2 * xINNER + self.sb**2
        new_phi = new_lambda + self.sw**2 * prev_phi * D

        # compute kernels
        S_kernel, D_kernel = self.alg2_VT(
            qtph[data_idxs[0][1]], qtprimeph[data_idxs[0][0]], new_lambda
        )
        new_K = prev_K + self.sv**2 * S_kernel
        new_theta = prev_theta + self.sv**2 * S_kernel + self.sv**2 * D_kernel * new_phi

        # check whether idx marks the end of one of the precomputed diagonals
        equal_check = lambda e: T.equal(idx, e)
        equal_result = sum(np.vectorize(equal_check)(self.ends_of_calced_diags)) > 0

        def true_f(k, t, qp, gph, dataph, dataprimeph, di):
            # at the end of a diagonal: restart from a fresh boundary condition
            xTP_NEXT = dataprimeph[di[1][0]]
            xT_NEXT = dataph[di[1][1]]
            xINNER_NEXT = T.inner(xT_NEXT, xTP_NEXT)
            new_bc = self.make_boundary_condition(xINNER_NEXT)
            ret_lambda = ret_phi = new_bc
            S_bc_kernel, D_bc_kernel = self.alg2_VT(
                qp[di[1][1]], gph[di[1][0]], ret_lambda
            )
            ret_K = k + self.sv**2 * S_bc_kernel
            ret_theta = t + self.sv**2 * S_bc_kernel + self.sv**2 * D_bc_kernel * ret_phi
            return ret_lambda, ret_phi, ret_K, ret_theta

        false_f = lambda l, p, k, t: (l, p, k, t)
        ret_lambda, ret_phi, ret_K, ret_theta = T.cond(
            equal_result,
            true_f,
            false_f,
            [new_K, new_theta, qtph, qtprimeph, DATAPH, DATAPRIMEPH, data_idxs],
            [new_lambda, new_phi, new_K, new_theta],
        )
        to_carry = create_T_list([ret_lambda, ret_phi, ret_K, ret_theta])
        return to_carry, np.array(())

    carry_ema, _ = T.scan(
        fn,
        init=initial_conditions,
        sequences=[
            jnp.arange(0, sum(self.dim_lengths) - self.dim_num),
            NEW_DATA_ATTACHED,
        ],
        non_sequences=[
            T.transpose(self.DATA),
            T.transpose(self.DATAPRIME),
            self.qt,
            self.qtprime,
        ],
    )
    # the output is the accumulated kernels, excluding the boundary conditions
    return carry_ema[2:4]
import sys

sys.path.insert(0, "../")

import numpy as np
import symjax as sj
import symjax.tensor as T

__author__ = "Randall Balestriero"

# example of cumulative sum
def func(carry, x):
    return carry + 1, 0

output, _ = T.scan(func, T.zeros(1), T.ones(10), length=10)
f = sj.function(outputs=output)
print(f())
# [10.]

# example of a simple RNN
w = T.Placeholder((3, 10), "float32")
h = T.random.randn((3, 3))
b = T.random.randn((3,))
t_steps = 100
X = T.random.randn((t_steps, 10))

def rnn_cell(carry, x, w):
    # hypothetical completion (the original snippet is truncated here):
    # one tanh RNN step; h (recurrent weights) and b (bias) are captured
    # from the enclosing scope, w (input weights) arrives as a non-sequence
    new_h = T.tanh(T.dot(w, x) + T.dot(h, carry) + b)
    return new_h, new_h
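# Hypothetical usage for the sketch above: scan the cell over the sequence,
# compile, and run on a random input for the placeholder w; hidden should
# have shape (t_steps, 3).
last, hidden = T.scan(rnn_cell, T.zeros(3), X, non_sequences=[w])
rnn_f = sj.function(w, outputs=hidden)
print(rnn_f(np.random.randn(3, 10).astype("float32")).shape)  # (100, 3)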