def SJ(x, y, N, lr, model, preallocate=False):
    """Fit a linear least-squares model with symjax and return the loss history.

    The current graph is reset, a ``input.dot(W) + b`` model with a mean
    squared error is built, and the optimizer named by ``model`` ("SGD" or
    "Adam") is attached; any other value attaches no optimizer.  The compiled
    step is then called ``N`` times on the same ``(x, y)`` pair.

    ``preallocate`` is accepted but not used here.  Placeholder shapes come
    from the free names ``BS`` and ``D``.

    Returns:
        list with the loss value produced by each of the ``N`` calls.
    """
    symjax.current_graph().reset()

    inputs = T.Placeholder(dtype=np.float32, shape=[BS, D])
    targets = T.Placeholder(dtype=np.float32, shape=[BS, 1])

    # Fixed seed: the initial parameters are identical on every call.
    np.random.seed(0)
    weights = T.Variable(np.random.randn(D, 1).astype("float32"))
    bias = T.Variable(np.random.randn(1).astype("float32"))

    residual = inputs.dot(weights) + bias - targets
    loss = (residual ** 2).mean()

    if model == "SGD":
        optimizers.SGD(loss, lr)
    elif model == "Adam":
        optimizers.Adam(loss, lr)

    step = symjax.function(
        inputs, targets, outputs=loss, updates=symjax.get_updates()
    )
    return [step(x, y) for _ in tqdm(range(N))]
def SJ(x, y, N, preallocate=False):
    """Time ``N`` Adam training steps of a linear least-squares model.

    The graph is reset and rebuilt on every call.  When ``preallocate`` is
    true, ``x`` and ``y`` are first moved with ``jax.device_put`` so that the
    timed loop does not include that transfer.  The learning rate is read
    from the free name ``lr``; shapes come from ``BS`` and ``D``.

    Returns:
        Elapsed wall-clock seconds spent in the training loop.
    """
    symjax.current_graph().reset()

    inputs = T.Placeholder(dtype=np.float32, shape=[BS, D])
    targets = T.Placeholder(dtype=np.float32, shape=[BS, 1])

    # Fixed seed: the initial parameters are identical on every call.
    np.random.seed(0)
    weights = T.Variable(np.random.randn(D, 1).astype("float32"))
    bias = T.Variable(np.random.randn(1).astype("float32"))

    residual = inputs.dot(weights) + bias - targets
    loss = (residual ** 2).mean()
    optimizers.Adam(loss, lr)

    step = symjax.function(inputs, targets, updates=symjax.get_updates())

    if preallocate:
        import jax

        x = jax.device_put(x)
        y = jax.device_put(y)

    started = time.time()
    for _ in range(N):
        step(x, y)
    return time.time() - started
def test_bn():
    """Compare BatchNormalization output and its running mean with NumPy."""
    sj.current_graph().reset()
    BATCH_SIZE = 5
    DIM = 2

    inputs = T.Placeholder((BATCH_SIZE, DIM), "float32", name="input")
    deterministic = T.Placeholder((1,), "bool", name="deterministic")
    layer = nn.layers.BatchNormalization(inputs, [1], deterministic=deterministic)

    update = sj.function(
        inputs, deterministic, outputs=layer, updates=sj.get_updates()
    )
    get_stats = sj.function(inputs, outputs=layer.avg_mean)

    data = np.random.randn(50, DIM) * 4 + 2

    expected_means = []
    observed_means = []
    for step in range(10):
        lo, hi = BATCH_SIZE * step, BATCH_SIZE * (step + 1)
        batch = data[lo:hi]
        batch_mean = batch.mean(0)

        normalized = update(batch, 0)
        reference = (batch - batch_mean) / (1e-4 + batch.std(0))
        assert np.allclose(normalized, reference, 1e-4)

        observed_means.append(get_stats(batch))
        if expected_means:
            # exponential moving average with a 0.9 / 0.1 split
            expected_means.append(0.9 * expected_means[-1] + 0.1 * batch_mean)
        else:
            # the very first step takes the batch mean as is
            expected_means.append(batch_mean)

    expected_means = np.array(expected_means)
    observed_means = np.array(observed_means).squeeze()
    assert np.allclose(expected_means, observed_means, 1e-4)
def create_vae(batch_size, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1):
    """Build a VAE graph and return a dict of compiled functions.

    ``Ds[0]`` is used as the latent width and ``Ds[-1]`` as the data width.
    The decoder is built twice on the same weights: once on the sampled
    latent ``z`` (for training) and once on a latent placeholder (for
    generation).  Relies on the free names ``encoder``, ``init_weights``,
    ``relu_mask``, ``cov_W`` and ``cov_b``.

    Returns:
        dict with keys 'train', 'g', 'params', 'model', 'varx', 'kwargs'
        and 'sample'.
    """
    latent_dim = Ds[0]
    x = T.Placeholder([batch_size, Ds[-1]], 'float32')

    # ENCODER: the last encoder output holds the mean in its first
    # ``latent_dim`` columns and the log-variance in the remaining ones.
    enc = encoder(x, latent_dim)
    mu = enc[-1][:, :latent_dim]
    logvar = enc[-1][:, latent_dim:]
    var = T.exp(logvar)

    noise_scale = T.exp(0.5 * logvar)
    z = mu + noise_scale * T.random.randn((batch_size, latent_dim))
    z_ph = T.Placeholder((batch_size, latent_dim), 'float32')

    # DECODER
    Ws, bs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w) for w in Ws]
    bs = [T.Variable(b) for b in bs]

    logvar_x = T.Variable(T.zeros(1), name='logvar_x')
    var_x = T.exp(logvar_x)

    h = [z]
    h_ph = [z_ph]
    for w, b in zip(Ws[:-1], bs[:-1]):
        pre = T.matmul(h[-1], w.transpose()) + b
        h.append(pre)
        h.append(pre * relu_mask(pre, leakiness))

        pre_ph = T.matmul(h_ph[-1], w.transpose()) + b
        h_ph.append(pre_ph)
        h_ph.append(pre_ph * relu_mask(pre_ph, leakiness))

    # final layer is affine only
    h.append(T.matmul(h[-1], Ws[-1].transpose()) + bs[-1])
    h_ph.append(T.matmul(h_ph[-1], Ws[-1].transpose()) + bs[-1])

    # LOSS
    weight_term = sum([T.mean(w**2) for w in Ws], 0.) / cov_W
    bias_term = sum([T.mean(v**2) for v in bs[:-1]], 0.) / cov_b
    prior = weight_term + bias_term

    kl = 0.5 * (1 + logvar - var - mu ** 2).sum(1)
    px = - 0.5 * (logvar_x + ((x - h[-1])**2 / var_x)).sum(1)
    loss = - (px + kl).mean() + prior

    variables = Ws + bs + sj.layers.get_variables(enc) + [logvar_x]
    opti = sj.optimizers.Adam(loss, lr, params=variables)

    train = sj.function(x, outputs=loss, updates=opti.updates)
    g = sj.function(z_ph, outputs=h_ph[-1])
    params = sj.function(outputs=Ws + bs + [T.exp(logvar_x) * T.ones(Ds[-1])])
    get_varx = sj.function(outputs=var_x)

    def sample(n):
        # draws n // batch_size full batches from the generator
        batches = [
            g(np.random.randn(batch_size, Ds[0]))
            for _ in range(n // batch_size)
        ]
        return np.concatenate(batches)

    output = {
        'train': train,
        'g': g,
        'params': params,
        'model': 'VAE',
        'varx': get_varx,
        'kwargs': {
            'batch_size': batch_size,
            'Ds': Ds,
            'seed': seed,
            'leakiness': leakiness,
            'lr': lr,
            'scaler': scaler,
            'prior': sj.function(outputs=prior),
        },
        'sample': sample,
    }
    return output
def update_target(self, tau=None):
    """Blend the current network parameters into the target networks.

    On first use, a symjax function is compiled (and cached on
    ``self._update_target``) that applies ``t <- t * (1 - tau) + a * tau``
    to every (target, current) parameter pair of the actor and/or critic,
    depending on which ``target_actor`` / ``target_critic`` attributes exist.

    Args:
        tau: interpolation coefficient.  When ``None``, ``self.tau`` is used.
            An explicit ``0`` is honoured and is not replaced by ``self.tau``.

    Raises:
        RuntimeError: if ``tau`` is ``None`` and the instance has no ``tau``.
    """
    if not hasattr(self, "_update_target"):
        with symjax.Scope("update_target"):
            targets = []
            currents = []
            if hasattr(self, "target_actor"):
                targets += self.target_actor.params(True)
                currents += self.actor.params(True)
            if hasattr(self, "target_critic"):
                targets += self.target_critic.params(True)
                currents += self.critic.params(True)

            _tau = T.Placeholder((), "float32")
            updates = {
                t: t * (1 - _tau) + a * _tau
                for t, a in zip(targets, currents)
            }
        self._update_target = symjax.function(_tau, updates=updates)

    if tau is None:
        if not hasattr(self, "tau"):
            raise RuntimeError("tau must be specified")
        # Fall back only when nothing was passed: a falsy-but-valid value
        # such as 0.0 must not be overridden by the instance default.
        tau = self.tau

    self._update_target(tau)
def make_inputs(self, dim1idx, dim2idx, jmode=False):
    """Create the int32 placeholder used as input for one diagonal.

    Only the length of the diagonal is needed (not its index), so the
    placeholder is sized as ``min(dim_1, dim_2) - (dim1idx + dim2idx)``.
    ``jmode`` is accepted but not used.
    """
    shortest_side = jnp.min(jnp.array([self.dim_1, self.dim_2]))
    offset = dim1idx + dim2idx
    diag_length = shortest_side - offset
    return T.Placeholder((diag_length,), "int32")
def RNTK_function(N, length, param):
    """Compile a symjax function mapping an (N, length) input to two kernels.

    The first column seeds the recursion through ``RNTK_first``; the
    remaining columns are scanned with ``RNTK_middle``; ``RNTK_output``
    turns the scan result into the two returned outputs.

    Returns:
        The compiled function taking ``DATA`` and returning
        ``[RNTK_last, RNTK_avg]``.
    """
    sigmaw = param['sigmaw']
    sigmau = param['sigmau']
    sigmab = param['sigmab']
    sigmah = param['sigmah']
    L = param['L']
    Lf = param['Lf']
    sigmav = param['sigmav']

    DATA = T.Placeholder((N, length), 'float32')

    RNTK, GP = RNTK_first(
        DATA[:, 0], sigmaw, sigmau, sigmab, sigmah, L, Lf, sigmav
    )

    def step(carry, column):
        return RNTK_middle(carry, column, sigmaw, sigmau, sigmab, L, Lf, sigmav)

    # scan over time steps: sequences are iterated along the first axis,
    # hence the transpose of the remaining columns
    v, _ = T.scan(
        step,
        sequences=[T.transpose(DATA[:, 1:])],
        init=T.stack([RNTK, GP]),
    )

    RNTK_last, RNTK_avg = RNTK_output(v, sigmav, L, Lf, length)
    return symjax.function(DATA, outputs=[RNTK_last, RNTK_avg])
def RNTK_function(self):
    """Build the RNTK graph for an (N, length) input.

    The first column seeds the recursion through ``self.RNTK_first``; the
    remaining columns are scanned with ``self.RNTK_middle``.  A symjax
    function is compiled from the graph, but the symbolic outputs (not the
    compiled function) are what is returned.

    Returns:
        Tuple ``(RNTK_last, RNTK_avg)`` of symbolic outputs.
    """
    print(f"N, {self.N}, length, {self.length}")
    DATA = T.Placeholder((self.N, self.length), 'float32')

    RNTK, GP = self.RNTK_first(DATA[:, 0])

    def step(carry, column):
        return self.RNTK_middle(carry, column)

    remaining_columns = T.transpose(DATA[:, 1:])
    v, _ = T.scan(step, sequences=[remaining_columns], init=T.stack([RNTK, GP]))

    RNTK_last, RNTK_avg = self.RNTK_output(v)
    # compiled but not returned by this method
    f = symjax.function(DATA, outputs=[RNTK_last, RNTK_avg])
    return RNTK_last, RNTK_avg
def __init__(self, states, actions=None):
    """Build the critic graph and compile its evaluation functions.

    The network produced by ``self.create_network`` is evaluated on the
    given batch tensors, and a clone of it is wired to placeholders with a
    leading dimension of 1 so a single sample can be scored.

    Args:
        states: symbolic batch of states; ``states.shape[1:]`` gives the
            per-sample state shape.
        actions: optional symbolic batch of actions.  When given, the
            network is built as ``create_network(states, actions)``;
            otherwise as ``create_network(states)``.
    """
    self.state_shape = states.shape[1:]
    state = T.Placeholder(
        (1,) + states.shape[1:], "float32", name="critic_state"
    )

    # Test for presence explicitly: ``actions`` is a tensor, and relying on
    # its truth value is not a reliable way to tell whether it was passed.
    if actions is not None:
        self.action_shape = actions.shape[1:]
        action = T.Placeholder(
            (1,) + actions.shape[1:], "float32", name="critic_action"
        )

        with symjax.Scope("critic"):
            q_values = self.create_network(states, actions)
            if q_values.ndim == 2:
                # a trailing singleton axis is dropped so that one scalar
                # per sample remains
                assert q_values.shape[1] == 1
                q_values = q_values[:, 0]
            q_value = q_values.clone({states: state, actions: action})
            self._params = symjax.get_variables(
                trainable=None, scope=symjax.current_graph().scope_name
            )

        inputs = [states, actions]
        input = [state, action]
        self.actions = actions
        self.action = action
    else:
        with symjax.Scope("critic"):
            q_values = self.create_network(states)
            if q_values.ndim == 2:
                assert q_values.shape[1] == 1
                q_values = q_values[:, 0]
            q_value = q_values.clone({states: state})
            self._params = symjax.get_variables(
                trainable=None, scope=symjax.current_graph().scope_name
            )

        inputs = [states]
        input = [state]

    self.q_values = q_values
    self.state = state
    self.states = states

    # batched evaluation, and single-sample evaluation returning a scalar
    self._get_q_values = symjax.function(*inputs, outputs=q_values)
    self._get_q_value = symjax.function(*input, outputs=q_value[0])
def test_grad():
    """The gradient of ``w * v + 2`` with respect to ``v`` must equal ``w``."""
    w = tt.Placeholder((), "float32")
    v = tt.Variable(1.0, dtype="float32")
    x = w * v + 2

    (grad_v,) = symjax.gradients(x.sum(), [v])[:1]
    f = symjax.function(w, outputs=grad_v)

    assert f(1) == 1
    assert f(10) == 10
def build_net(self, Q):
    """Build the Q-learning graph and compile ``self.train`` / ``self.q_eval``.

    ``Q`` is called twice, each time in its own scope: once on the current
    states and once on the next states.  The squared difference between the
    gradient-stopped target ``r + reward_decay * max(Q(s_))`` and the Q-value
    of the taken action is minimised with Adam.
    """
    batch = self.batch_size

    # ------------------ all inputs ------------------------
    state = T.Placeholder([batch, self.n_states], "float32", name="s")
    next_state = T.Placeholder([batch, self.n_states], "float32", name="s_")
    reward = T.Placeholder([batch], "float32", name="r")  # input reward
    action = T.Placeholder([batch], "int32", name="a")  # input action

    with symjax.Scope("eval_net"):
        q_eval = Q(state, self.n_actions)
    with symjax.Scope("test_set"):
        q_next = Q(next_state, self.n_actions)

    # no gradient flows through the bootstrap target
    q_target = T.stop_gradient(reward + self.reward_decay * q_next.max(1))

    # built but not consumed below
    a_indices = T.stack([T.range(batch), action], axis=1)

    # Q-value of the action that was actually taken, one per sample
    picked = T.take_along_axis(q_eval, action.reshape((-1, 1)), 1)
    q_eval_wrt_a = picked.squeeze(1)

    loss = T.mean((q_target - q_eval_wrt_a) ** 2)
    nn.optimizers.Adam(loss, self.lr)

    self.train = symjax.function(
        state, action, reward, next_state, updates=symjax.get_updates()
    )
    self.q_eval = symjax.function(state, outputs=q_eval)
def SJ_EMA(X, debias=True):
    """Feed ``X`` element by element through a symjax exponential moving average.

    The graph is reset, an ``ExponentialMovingAverage`` with coefficient 0.9
    is built on a scalar placeholder, and its first output is collected after
    each update.

    Returns:
        list with one averaged value per element of ``X``.
    """
    symjax.current_graph().reset()

    x = T.Placeholder((), "float32", name="x")
    ema = symjax.nn.schedules.ExponentialMovingAverage(x, 0.9, debias=debias)
    value = ema[0]

    train = symjax.function(x, outputs=value, updates=symjax.get_updates())
    return [train(X[idx]) for idx in range(len(X))]
def create_func(dic, printbool = False):
    """Compile the diagonal kernel function for a pair of inputs.

    Sizes are read from ``dic`` ("n_patrons1=", "n_entradasTi=",
    "n_entradasTiP="), two placeholders are created, and an ``RNTK`` object
    is built on them.  The time spent creating and compiling the graph is
    printed when ``printbool`` is true.

    Returns:
        Tuple ``(diag_func, rntkod)``: the compiled function and the RNTK
        object it was built from.
    """
    n_rows = int(dic["n_patrons1="])
    n_cols = int(dic["n_entradasTi="])
    n_cols_prime = int(dic["n_entradasTiP="])

    DATA = T.Placeholder((n_rows, n_cols), 'float32', name = "X")
    DATAPRIME = T.Placeholder((n_rows, n_cols_prime), 'float32', name = "X")

    rntkod = RNTK(dic, DATA, DATAPRIME)  # could be flipped

    t0 = time.time()
    kernels_ema = rntkod.create_func_for_diag()
    diag_func = symjax.function(DATA, DATAPRIME, outputs=kernels_ema)
    if printbool:
        print("time to create symjax", time.time() - t0)

    return diag_func, rntkod
def test_map():
    """``T.map`` must see non-sequence values that change between calls."""
    sj.current_graph().reset()

    w = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32")

    def scaled(a, w, u):
        return (u - w) * a

    out = T.map(scaled, [T.range(3)], non_sequences=[w, u])
    # every call increments w by one after evaluating the output
    f = sj.function(u, outputs=out, updates={w: w + 1})

    first = f(2)   # w == 1
    assert np.array_equal(first, np.arange(3))
    second = f(2)  # w == 2
    assert np.array_equal(second, np.zeros(3))
    third = f(0)   # w == 3
    assert np.array_equal(third, -np.arange(3) * 3)
def test_clone_0():
    """Cloning ``2 * w * u`` with ``w`` replaced by ``u`` must give ``2 * u * u``."""
    sj.current_graph().reset()

    w = T.Variable(1.0, dtype="float32")
    with sj.Scope("placing"):
        u = T.Placeholder((), "float32", name="u")

    value = 2 * w * u
    cloned = value.clone({w: u})

    f = sj.function(u, outputs=value)
    g = sj.function(u, outputs=cloned)

    results = [f(1), g(1), f(2), g(2)]
    assert np.array_equal(results, [2, 2, 4, 8])
def test_grad_map():
    """Gradients must flow through ``T.map`` to a non-sequence variable."""
    sj.current_graph().reset()

    w = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32", name="u")

    def product(a, w, u):
        return w * a * u

    out = T.map(product, (T.range(3),), non_sequences=(w, u))
    grad_w = sj.gradients(out.sum(), w)
    f = sj.function(u, outputs=grad_w)

    # d/dw sum_a(w * a * u) = u * (0 + 1 + 2)
    assert np.array_equal(f(0), 0)
    assert np.array_equal(f(1), 3)
def __init__(
    self,
    state_shape,
    actions_shape,
    n_episodes,
    episode_length,
    actor,
    lr=1e-3,
    gamma=0.99,
):
    """Build the REINFORCE loss and compile the training function.

    ``actor`` is called as ``actor(states, distribution="gaussian")`` and the
    result replaces ``self.actor``.  The loss is the negated sum of
    ``log_prob(actions) * discounted_rewards`` divided by ``n_episodes``,
    minimised with Adam inside the "REINFORCE_optimizer" scope.
    """
    self.actor = actor
    self.gamma = gamma
    self.lr = lr
    self.episode_length = episode_length
    self.n_episodes = n_episodes
    self.batch_size = episode_length * n_episodes

    batch = (self.batch_size,)
    states = T.Placeholder(batch + state_shape, "float32")
    actions = T.Placeholder(batch + actions_shape, "float32")
    discounted_rewards = T.Placeholder(batch, "float32")

    self.actor = actor(states, distribution="gaussian")

    logprobs = self.actor.actions.log_prob(actions)
    weighted = logprobs * discounted_rewards
    actor_loss = -weighted.sum() / n_episodes

    with symjax.Scope("REINFORCE_optimizer"):
        nn.optimizers.Adam(actor_loss, lr, params=self.actor.params(True))

    # create the update function; only the updates registered under the
    # optimizer scope are applied
    optimizer_updates = symjax.get_updates(scope="*/REINFORCE_optimizer")
    self._train = symjax.function(
        states,
        actions,
        discounted_rewards,
        outputs=actor_loss,
        updates=optimizer_updates,
    )
def test_global_pool():
    """Three stacked Dense layers (64, x2, x2) must yield a width of 256."""
    np.random.seed(0)
    sj.current_graph().reset()
    BATCH_SIZE = 4096
    DIM = 8

    inputs = T.Placeholder((BATCH_SIZE, DIM), "float32", name="input")

    hidden = nn.layers.Dense(inputs, 64)
    # each following layer doubles the width of the previous one
    hidden = nn.layers.Dense(hidden, hidden.shape[-1] * 2)
    hidden = nn.layers.Dense(hidden, hidden.shape[-1] * 2)

    get = sj.function(inputs, outputs=hidden)
    result = get(np.ones((BATCH_SIZE, DIM)))
    assert result.shape == (BATCH_SIZE, 64 * 4)
def test_dropout():
    """Dropout must be random when stochastic and keep all units when deterministic."""
    np.random.seed(0)
    sj.current_graph().reset()
    BATCH_SIZE = 4096
    DIM = 8

    inputs = T.Placeholder((BATCH_SIZE, DIM), "float32", name="input")
    deterministic = T.Placeholder((), "bool", name="deterministic")
    layer = nn.layers.Dropout(inputs, p=0.2, deterministic=deterministic)
    update = sj.function(inputs, deterministic, outputs=layer)

    data = np.ones((BATCH_SIZE, DIM))
    stochastic_a = update(data, 0)
    stochastic_b = update(data, 0)
    fixed = update(data, 1)

    # two stochastic passes must differ ...
    assert not np.allclose(stochastic_a, stochastic_b, 1e-1)
    # ... while their averaged column means stay close to one
    averaged = stochastic_a.mean(0) / 2 + stochastic_b.mean(0) / 2
    assert np.allclose(averaged, 1, 0.08)
    # the deterministic pass leaves no zero entry
    assert np.all(fixed)
def test_flip():
    """RandomFlip must flip about half the samples and be a no-op when deterministic."""
    np.random.seed(0)
    sj.current_graph().reset()
    BATCH_SIZE = 2048
    DIM = 8

    inputs = T.Placeholder((BATCH_SIZE, DIM, DIM), "float32", name="input")
    deterministic = T.Placeholder((1,), "bool", name="deterministic")
    layer = nn.layers.RandomFlip(
        inputs, axis=2, p=0.5, deterministic=deterministic
    )
    update = sj.function(inputs, deterministic, outputs=layer)

    # left half zeros, right half ones: a flip along axis 2 swaps the halves
    data = np.ones((BATCH_SIZE, DIM, DIM))
    data[:, :, : DIM // 2] = 0

    stochastic_a = update(data, 0)
    stochastic_b = update(data, 0)
    fixed = update(data, 1)

    assert not np.allclose(stochastic_a, stochastic_b, 1e-1)
    averaged = stochastic_a.mean(0) / 2 + stochastic_b.mean(0) / 2
    assert np.allclose(averaged, 0.5, 0.05)
    assert np.allclose(data, fixed, 1e-6)
def test_bn():
    """Compare BatchNormalization output and its tracked mean with NumPy.

    The tracked mean is read from the third-to-last graph variable and is
    expected to follow a 0.99 / 0.01 moving average starting at zero.
    """
    np.random.seed(0)
    sj.current_graph().reset()
    BATCH_SIZE = 5
    DIM = 2

    inputs = T.Placeholder((BATCH_SIZE, DIM), "float32", name="input")
    deterministic = T.Placeholder((), "bool", name="deterministic")
    layer = nn.layers.BatchNormalization(inputs, [1], deterministic=deterministic)

    update = sj.function(
        inputs, deterministic, outputs=layer, updates=sj.get_updates()
    )

    avmean = symjax.get_variables(trainable=None)[-3]
    print(avmean)
    get_stats = sj.function(inputs, outputs=avmean[0])

    data = np.random.randn(50, DIM) * 4 + 2

    expected_means = [np.zeros(DIM)]
    observed_means = [np.zeros(DIM)]
    for step in range(10):
        lo, hi = BATCH_SIZE * step, BATCH_SIZE * (step + 1)
        batch = data[lo:hi]
        batch_mean = batch.mean(0)

        normalized = update(batch, 0)
        reference = (batch - batch_mean) / np.sqrt(0.001 + batch.var(0))
        assert np.allclose(normalized, reference, 1e-4)

        observed_means.append(get_stats(batch))
        expected_means.append(0.99 * expected_means[-1] + 0.01 * batch_mean)

    expected_means = np.array(expected_means)
    observed_means = np.array(observed_means).squeeze()
    assert np.allclose(expected_means, observed_means)
def test_while():
    """``T.while_loop`` must honour a condition that depends on a placeholder."""
    sj.current_graph().reset()

    w = T.Variable(1.0, dtype="float32")
    v = T.Placeholder((), "float32")

    keep_going = lambda i, u: i[0] + u < 5
    advance = lambda i: (i[0] + 1.0, i[0]**2)

    out = T.while_loop(
        keep_going,
        advance,
        (w, 1.0),
        non_sequences_cond=(v,),
    )
    f = sj.function(v, outputs=out)

    assert np.array_equal(np.array(f(0)), [5, 16])
    assert np.array_equal(f(2), [3, 4])
def test_clone_base():
    """Gradients and clones of ``2 * w * u * w2`` must track updates to ``w2``."""
    sj.current_graph().reset()

    w = T.Variable(1.0, dtype="float32")
    w2 = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32", name="u")
    uu = T.Placeholder((), "float32", name="uu")
    # two extra placeholders that are created but not wired to anything
    aa = T.Placeholder((), "float32")
    bb = T.Placeholder((), "float32")

    l = 2 * w * u * w2
    g = sj.gradients(l, w)
    guu = T.clone(l, {u: uu})
    guuu = T.clone(l, {w: uu})

    # every compiled function bumps w2 by one on each call
    bump_w2 = {w2: w2 + 1}
    f = sj.function(u, outputs=g, updates=bump_w2)
    fuu = sj.function(uu, outputs=guu, updates=bump_w2)
    fuuu = sj.function(u, uu, outputs=guuu, updates=bump_w2)

    assert np.array_equal(f(2), 4.0)
    assert np.array_equal(fuu(1), 4)
    assert np.array_equal(fuuu(0, 0), 0)
def test_cond2():
    """``T.cond`` must route each branch its own inputs."""
    sj.current_graph().reset()

    v = T.ones((10, 10))
    u = T.Placeholder((), "int32")

    when_true = lambda u: 4 * u
    when_false = lambda u: u

    out = T.cond(
        u > 0,
        when_true,
        when_false,
        true_inputs=(v,),
        false_inputs=(2 * v,),
    )
    f = sj.function(u, outputs=out)

    ones = np.ones((10, 10))
    assert np.array_equal(f(1), 4 * ones)
    assert np.array_equal(f(0), 2 * ones)
def create_glo(batch_size, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1, GLO=False):
    """Build a decoder-only model with a learnable latent and compile its functions.

    The latent ``z`` is a variable (not an encoder output) that is optimised
    by the 'estimate' function, while the decoder weights are optimised by
    the 'train' function.  Relies on the free names ``init_weights``,
    ``relu_mask``, ``cov_W`` and ``cov_b``.

    Returns:
        dict with keys 'train', 'estimate', 'params', 'reset', 'model'
        ('GLO' or 'HARD'), 'loss' and 'kwargs'.
    """
    x = T.Placeholder([batch_size, Ds[-1]], 'float32')
    z = T.Variable(T.random.randn((batch_size, Ds[0])))
    logvar_x = T.Variable(T.ones(1))

    # DECODER
    Ws, bs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w) for w in Ws]
    bs = [T.Variable(b) for b in bs]
    h = [z]
    for w, b in zip(Ws[:-1], bs[:-1]):
        h.append(T.matmul(h[-1], w.transpose()) + b)
        h.append(h[-1] * relu_mask(h[-1], leakiness))
    # the final layer is affine only
    h.append(T.matmul(h[-1], Ws[-1].transpose()) + bs[-1])

    # LOSS
    prior = sum([T.sum(w**2) for w in Ws], 0.) / cov_W + sum([T.sum(v**2) for v in bs[:-1]], 0.) / cov_b
    if GLO:
        loss = T.sum((x - h[-1])**2) / batch_size + prior
        variables = Ws + bs
    else:
        loss = Ds[-1] * logvar_x.sum() + T.sum((x - h[-1])**2 / T.exp(logvar_x)) / batch_size + (z**2).sum() / batch_size + prior
        variables = Ws + bs

    # NOTE(review): the one-line source does not show whether this second
    # ``prior`` belongs to the ``else`` branch; it is placed after the
    # if/else here -- confirm against the original layout.
    # NOTE(review): ``loss`` already contains the first ``prior``, so the
    # training objective below adds a penalty twice (this one also covers
    # the last bias) -- confirm this is intended.
    prior = sum([(b**2).sum() for b in bs], 0.) / cov_b\
        + sum([(w**2).sum() for w in Ws], 0.) / cov_W

    # two optimizers: one over the decoder parameters, one over the latent
    opti = sj.optimizers.Adam(loss + prior, lr, params=variables)
    infer = sj.optimizers.Adam(loss, lr, params=[z])

    estimate = sj.function(x, outputs=z, updates=infer.updates)
    train = sj.function(x, outputs=loss, updates=opti.updates)
    lossf = sj.function(x, outputs=loss)
    params = sj.function(outputs = Ws + bs + [T.ones(Ds[-1]) * T.exp(logvar_x)])

    output = {'train': train, 'estimate':estimate, 'params':params}
    # 'reset' overwrites the latent variable with a caller-supplied value
    output['reset'] = lambda v: z.assign(v)
    if GLO:
        output['model'] = 'GLO'
    else:
        output['model'] = 'HARD'
    output['loss'] = lossf
    output['kwargs'] = {'batch_size': batch_size, 'Ds':Ds, 'seed':seed, 'leakiness':leakiness, 'lr':lr, 'scaler':scaler}
    return output
def create_func(dic, printbool=False):
    """Compile the diagonal kernel function for a single input matrix.

    Sizes are read from ``dic`` ("n_patrons1=", "n_entradas="), the outer
    product of the first input column with itself is formed, and an ``RNTK``
    object is built on it.  The time spent creating and compiling the graph
    is printed when ``printbool`` is true.

    Returns:
        Tuple ``(diag_func, rntkod)``: the compiled function and the RNTK
        object it was built from.
    """
    N = int(dic["n_patrons1="])
    n_cols = int(dic["n_entradas="])
    DATA = T.Placeholder((N, n_cols), 'float32', name="X")

    # outer product of the first column with itself
    first_col = DATA[:, 0]
    X = first_col * first_col[:, None]
    n = X.shape[0]
    print(n, N)

    rntkod = RNTK(dic, X, n)  # could be flipped

    t0 = time.time()
    lin_ema = rntkod.create_func_for_diag()
    diag_func = symjax.function(DATA, outputs=lin_ema)
    if printbool:
        print("time to create symjax", time.time() - t0)

    return diag_func, rntkod
def test_cond3():
    """``T.cond`` branches taking two inputs must multiply or add them."""
    sj.current_graph().reset()

    v = T.ones((10, 10)) * 3
    u = T.Placeholder((), "int32")

    multiply = lambda a, u: a * u
    add = lambda a, u: a + u

    out = T.cond(
        u > 0,
        multiply,
        add,
        true_inputs=(2 * T.ones((10, 10)), v),
        false_inputs=(2 * T.ones((10, 10)), v),
    )
    f = sj.function(u, outputs=out)

    ones = np.ones((10, 10))
    assert np.array_equal(f(1), 6 * ones)
    assert np.array_equal(f(0), 5 * ones)
def __init__(self, states, actions_distribution=None, name="actor"):
    """Build the actor graph and compile its action functions.

    The network from ``self.create_network`` is evaluated on the batch tensor
    ``states``, and a clone wired to a placeholder with a leading dimension
    of 1 is used for single-state queries.

    Args:
        states: symbolic batch of states; ``states.shape[1:]`` gives the
            per-sample state shape.
        actions_distribution: ``symjax.probabilities.Normal`` for a
            stochastic policy (the network must then return means and
            covariances), or ``None`` for a deterministic one.
        name: scope under which the network variables are created.
    """
    self.state_shape = states.shape[1:]
    state = T.Placeholder((1,) + states.shape[1:], "float32")
    self.actions_distribution = actions_distribution

    with symjax.Scope(name):
        if actions_distribution == symjax.probabilities.Normal:
            means, covs = self.create_network(states)

            actions = actions_distribution(means, cov=covs)
            samples = actions.sample()
            samples_log_prob = actions.log_prob(samples)

            action = symjax.probabilities.MultivariateNormal(
                means.clone({states: state}),
                cov=covs.clone({states: state}),
            )
            # Sample from the local single-state distribution: the
            # ``self.action`` attribute is only assigned at the end of this
            # constructor and does not exist yet at this point.
            sample = action.sample()
            sample_log_prob = action.log_prob(sample)

            self._get_actions = symjax.function(
                states, outputs=[samples, samples_log_prob]
            )
            self._get_action = symjax.function(
                state,
                outputs=[sample[0], sample_log_prob[0]],
            )
        elif actions_distribution is None:
            actions = self.create_network(states)
            action = actions.clone({states: state})

            self._get_actions = symjax.function(states, outputs=actions)
            self._get_action = symjax.function(state, outputs=action[0])

        # collected inside the scope so that ``scope_name`` refers to it
        self._params = symjax.get_variables(
            trainable=None, scope=symjax.current_graph().scope_name
        )

    self.actions = actions
    self.state = state
    self.action = action
def test_cond5():
    """``T.cond`` must see a variable input that is updated between calls."""
    sj.current_graph().reset()

    v = T.ones((10, 10)) * 3
    W = T.Variable(1)
    u = T.Placeholder((), "int32")

    # each branch combines the variable with one row of ``v``
    scale_row = lambda a, u: a * u[0]
    shift_row = lambda a, u: a + u[1]

    out = T.cond(
        u > 0,
        scale_row,
        shift_row,
        true_inputs=(W, v),
        false_inputs=(W, v),
    )
    # W is incremented after every call
    f = sj.function(u, outputs=out, updates={W: W + 1})

    row = np.ones(10)
    assert np.array_equal(f(1), 3 * row)  # W == 1
    assert np.array_equal(f(0), 5 * row)  # W == 2
    assert np.array_equal(f(1), 9 * row)  # W == 3
We then demonstrate how to do a simple for loop and then a while loop.
"""
import matplotlib.pyplot as plt
import symjax
import symjax.tensor as T
import numpy as np

# Suppose we are given a time series and we want to compute an
# exponential moving average; the EMA coefficient alpha is also
# provided as a user input.
signal = T.Placeholder((512,), "float32", name="signal")
alpha = T.Placeholder((), "float32", "alpha")

# To use a scan function one needs a function to be applied at each step,
# in our case an exponential moving average function.
# This function should output the new value of the carry as well as an
# additional output; in our case the carry (EMA) is also what we want to
# output at each time step.


def fn(at, xt, alpha):
    # The function's first input is the carry; then come the (ordered)
    # values from sequences and non_sequences, similar to Theano.
    EMA = at * alpha + (1 - alpha) * xt
    return EMA, EMA