def test_update():
    sj.current_graph().reset()

    # index_update with a plain integer index
    w = symjax.tensor.zeros(10)
    for i in range(10):
        w = symjax.tensor.index_update(w, i, i)
    f = symjax.function(outputs=w)
    assert np.array_equal(f(), np.arange(10))

    # index_update with a tuple index
    w2 = symjax.tensor.zeros(10)
    for i in range(10):
        w2 = symjax.tensor.index_update(w2, (i,), i)
    f = symjax.function(outputs=w2)
    assert np.array_equal(f(), np.arange(10))

    # index_update with the symjax.tensor.index helper
    w3 = symjax.tensor.zeros(10)
    for i in range(10):
        w3 = symjax.tensor.index_update(w3, symjax.tensor.index[i], i)
    f = symjax.function(outputs=w3)
    assert np.array_equal(f(), np.arange(10))

    # updating a Variable through function updates
    w4 = symjax.tensor.Variable(symjax.tensor.zeros(10))
    i = symjax.tensor.Variable(0, dtype="int32")
    update = symjax.tensor.index_update(w4, i, i)
    f = symjax.function(updates={w4: update, i: i + 1})
    for i in range(10):
        f()
    assert np.array_equal(w4.value, np.arange(10))
def test_bn():
    sj.current_graph().reset()

    BATCH_SIZE = 5
    DIM = 2
    input = T.Placeholder((BATCH_SIZE, DIM), "float32", name="input")
    deterministic = T.Placeholder((1,), "bool", name="deterministic")

    bn = nn.layers.BatchNormalization(input, [1], deterministic=deterministic)

    update = sj.function(input, deterministic, outputs=bn, updates=sj.get_updates())
    get_stats = sj.function(input, outputs=bn.avg_mean)

    data = np.random.randn(50, DIM) * 4 + 2

    true_means = []
    actual_means = []
    for i in range(10):
        batch = data[BATCH_SIZE * i : BATCH_SIZE * (i + 1)]
        output = update(batch, 0)
        assert np.allclose(
            output, (batch - batch.mean(0)) / (1e-4 + batch.std(0)), 1e-4
        )
        actual_means.append(get_stats(batch))
        if i == 0:
            true_means.append(batch.mean(0))
        else:
            true_means.append(0.9 * true_means[-1] + 0.1 * batch.mean(0))

    true_means = np.array(true_means)
    actual_means = np.array(actual_means).squeeze()
    assert np.allclose(true_means, actual_means, 1e-4)
def create_vae(batch_size, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1):

    x = T.Placeholder([batch_size, Ds[-1]], 'float32')

    # ENCODER
    enc = encoder(x, Ds[0])
    mu = enc[-1][:, :Ds[0]]
    logvar = enc[-1][:, Ds[0]:]
    var = T.exp(logvar)

    # reparametrization trick
    z = mu + T.exp(0.5 * logvar) * T.random.randn((batch_size, Ds[0]))
    z_ph = T.Placeholder((batch_size, Ds[0]), 'float32')

    # DECODER
    Ws, bs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w) for w in Ws]
    bs = [T.Variable(b) for b in bs]

    logvar_x = T.Variable(T.zeros(1), name='logvar_x')
    var_x = T.exp(logvar_x)

    h, h_ph = [z], [z_ph]
    for w, b in zip(Ws[:-1], bs[:-1]):
        h.append(T.matmul(h[-1], w.transpose()) + b)
        h.append(h[-1] * relu_mask(h[-1], leakiness))
        h_ph.append(T.matmul(h_ph[-1], w.transpose()) + b)
        h_ph.append(h_ph[-1] * relu_mask(h_ph[-1], leakiness))

    h.append(T.matmul(h[-1], Ws[-1].transpose()) + bs[-1])
    h_ph.append(T.matmul(h_ph[-1], Ws[-1].transpose()) + bs[-1])

    # weight prior (cov_W and cov_b are module-level constants)
    prior = sum([T.mean(w ** 2) for w in Ws], 0.) / cov_W \
        + sum([T.mean(v ** 2) for v in bs[:-1]], 0.) / cov_b

    # ELBO terms
    kl = 0.5 * (1 + logvar - var - mu ** 2).sum(1)
    px = -0.5 * (logvar_x + ((x - h[-1]) ** 2 / var_x)).sum(1)
    loss = -(px + kl).mean() + prior

    variables = Ws + bs + sj.layers.get_variables(enc) + [logvar_x]

    opti = sj.optimizers.Adam(loss, lr, params=variables)

    train = sj.function(x, outputs=loss, updates=opti.updates)
    g = sj.function(z_ph, outputs=h_ph[-1])
    params = sj.function(outputs=Ws + bs + [T.exp(logvar_x) * T.ones(Ds[-1])])
    get_varx = sj.function(outputs=var_x)

    output = {'train': train, 'g': g, 'params': params}
    output['model'] = 'VAE'
    output['varx'] = get_varx
    output['kwargs'] = {'batch_size': batch_size, 'Ds': Ds, 'seed': seed,
                        'leakiness': leakiness, 'lr': lr, 'scaler': scaler,
                        'prior': sj.function(outputs=prior)}

    def sample(n):
        samples = []
        for i in range(n // batch_size):
            samples.append(g(np.random.randn(batch_size, Ds[0])))
        return np.concatenate(samples)

    output['sample'] = sample

    return output
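# Illustrative usage sketch (not from the original file; `data` and the globals
# `cov_W`, `cov_b`, `encoder`, `init_weights`, `relu_mask` required by create_vae
# are assumed to be defined elsewhere in the project):
#
#   vae = create_vae(batch_size=16, Ds=[2, 32, 784], seed=0)
#   for epoch in range(10):
#       for i in range(len(data) // 16):
#           batch_loss = vae['train'](data[16 * i: 16 * (i + 1)])
#   samples = vae['sample'](64)          # decode 64 random latent draws
#   decoder_weights = vae['params']()    # current Ws, bs and output variance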
def test_placeholders():
    a = symjax.tensor.ones(1) * 2
    x = symjax.tensor.Placeholder((), "int32")
    f = symjax.function(x, outputs=x * a)
    y = symjax.tensor.Placeholder((), "int32")
    g = symjax.function(y, outputs=y * a)
    assert np.isclose(f(1), 2)
    assert np.isclose(g(2), 4)
def create_fns(input, in_signs, Ds):

    cumulative_units = np.concatenate([[0], np.cumsum(Ds[:-1])])

    Ws = [sj.initializers.he((j, i)) for j, i in zip(Ds[1:], Ds[:-1])]
    bs = [sj.initializers.he((j,)) for j in Ds[1:]]

    A_w = [T.eye(Ds[0])]
    B_w = [T.zeros(Ds[0])]

    A_q = [T.eye(Ds[0])]
    B_q = [T.zeros(Ds[0])]

    maps = [input]
    signs = []
    masks = [T.ones(Ds[0])]

    in_masks = T.where(T.concatenate([T.ones(Ds[0]), in_signs]) > 0, 1., 0.1)

    for w, b in zip(Ws[:-1], bs[:-1]):
        pre_activation = T.matmul(w, maps[-1]) + b
        signs.append(T.sign(pre_activation))
        masks.append(T.where(pre_activation > 0, 1., 0.1))
        maps.append(pre_activation * masks[-1])

    maps.append(T.matmul(Ws[-1], maps[-1]) + bs[-1])

    # compute per region A and B
    for start, end, w, b, m in zip(cumulative_units[:-1], cumulative_units[1:],
                                   Ws, bs, masks):
        A_w.append(T.matmul(w * m, A_w[-1]))
        B_w.append(T.matmul(w * m, B_w[-1]) + b)

        A_q.append(T.matmul(w * in_masks[start:end], A_q[-1]))
        B_q.append(T.matmul(w * in_masks[start:end], B_q[-1]) + b)

    signs = T.concatenate(signs)

    ineq_b = T.concatenate(B_w[1:-1])
    ineq_A = T.vstack(A_w[1:-1])

    inequalities = T.hstack([ineq_b[:, None], ineq_A])
    inequalities = inequalities * signs[:, None] / T.linalg.norm(
        ineq_A, 2, 1, keepdims=True)

    inequalities_code = T.hstack([T.concatenate(B_q[1:-1])[:, None],
                                  T.vstack(A_q[1:-1])])
    inequalities_code = inequalities_code * in_signs[:, None]

    f = sj.function(input, outputs=[maps[-1], A_w[-1], B_w[-1], inequalities, signs])
    g = sj.function(in_signs, outputs=[A_q[-1], B_q[-1]])
    all_g = sj.function(in_signs, outputs=inequalities_code)
    h = sj.function(input, outputs=maps[-1])

    return f, g, h, all_g
def test_stop():
    a = symjax.tensor.ones(())
    b = a + a ** 2
    g = symjax.gradients(b, [a])[0]
    f = symjax.function(outputs=g)
    # d/da (a + a**2) = 1 + 2a = 3 at a = 1
    assert f() == 3

    b = a + symjax.tensor.stop_gradient(a ** 2)
    g = symjax.gradients(b, [a])[0]
    f = symjax.function(outputs=g)
    # the a**2 term is treated as a constant, so only the linear term contributes
    assert f() == 1
def test_clone_0():
    sj.current_graph().reset()

    w = T.Variable(1.0, dtype="float32")
    with sj.Scope("placing"):
        u = T.Placeholder((), "float32", name="u")

    value = 2 * w * u
    c = value.clone({w: u})

    f = sj.function(u, outputs=value)
    g = sj.function(u, outputs=c)

    assert np.array_equal([f(1), g(1), f(2), g(2)], [2, 2, 4, 8])
def test_seed():
    a = T.random.randn((), seed=10)
    b = T.random.randn(())
    c = T.random.randn((), seed=10)
    f = symjax.function(outputs=[a, b, c])

    result1 = f()
    result2 = f()
    print(result1)
    print(result2)

    assert result1[0] == result1[2]
    assert result1[0] != result1[1]
    assert result2[0] == result2[2]
    assert result2[0] != result1[0]

    a = T.random.randn((), seed=10)
    b = T.random.randn(())
    c = T.random.randn((), seed=10)
    f = symjax.function(outputs=[a, b, c])

    result12 = f()
    result22 = f()

    assert result12[0] == result12[2]
    assert result12[0] != result12[1]
    assert result22[0] == result22[2]
    assert result22[0] != result12[0]

    assert np.isclose(result1[0], result12[0])
    assert np.isclose(result1[2], result12[2])
    assert not np.isclose(result1[1], result12[1])
    assert np.isclose(result2[0], result22[0])
    assert np.isclose(result2[2], result22[2])
    assert not np.isclose(result2[1], result22[1])

    symjax.current_graph().reset()

    a = T.random.randn((), seed=10)
    b = T.random.randn(())
    c = T.random.randn((), seed=10)
    f = symjax.function(outputs=[a, b, c])

    result12 = f()
    result22 = f()

    assert result12[0] == result12[2]
    assert result12[0] != result12[1]
    assert result22[0] == result22[2]
    assert result22[0] != result12[0]

    assert np.isclose(result1[0], result12[0])
    assert np.isclose(result1[2], result12[2])
    assert not np.isclose(result1[1], result12[1])
    assert np.isclose(result2[0], result22[0])
    assert np.isclose(result2[2], result22[2])
    assert not np.isclose(result2[1], result22[1])
def update_target(self, tau=None):
    if not hasattr(self, "_update_target"):
        with symjax.Scope("update_target"):
            targets = []
            currents = []
            if hasattr(self, "target_actor"):
                targets += self.target_actor.params(True)
                currents += self.actor.params(True)
            if hasattr(self, "target_critic"):
                targets += self.target_critic.params(True)
                currents += self.critic.params(True)

            _tau = T.Placeholder((), "float32")
            # Polyak averaging: target <- (1 - tau) * target + tau * current
            updates = {
                t: t * (1 - _tau) + a * _tau for t, a in zip(targets, currents)
            }
            self._update_target = symjax.function(_tau, updates=updates)

    if tau is None:
        if not hasattr(self, "tau"):
            raise RuntimeError("tau must be specified")
    tau = tau or self.tau
    self._update_target(tau)
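# Minimal sketch of the Polyak-averaging update pattern built above, shown outside
# of any agent class (hypothetical names, for illustration only; `T` and `symjax`
# are the same aliases used throughout this file):
#
#   target = T.Variable(0.0, dtype="float32")
#   current = T.Variable(1.0, dtype="float32")
#   _tau = T.Placeholder((), "float32")
#   step = symjax.function(_tau, updates={target: target * (1 - _tau) + current * _tau})
#   step(0.05)   # target.value moves 5% of the way toward current.value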
def SJ(x, y, N, preallocate=False):
    symjax.current_graph().reset()
    sj_input = T.Placeholder(dtype=np.float32, shape=[BS, D])
    sj_output = T.Placeholder(dtype=np.float32, shape=[BS, 1])

    np.random.seed(0)
    sj_W = T.Variable(np.random.randn(D, 1).astype("float32"))
    sj_b = T.Variable(np.random.randn(1,).astype("float32"))

    sj_loss = ((sj_input.dot(sj_W) + sj_b - sj_output) ** 2).mean()

    optimizers.Adam(sj_loss, lr)

    train = symjax.function(sj_input, sj_output, updates=symjax.get_updates())

    if preallocate:
        import jax

        x = jax.device_put(x)
        y = jax.device_put(y)

    t = time.time()
    for i in range(N):
        train(x, y)

    return time.time() - t
def SJ(x, y, N, lr, model, preallocate=False):
    symjax.current_graph().reset()
    sj_input = T.Placeholder(dtype=np.float32, shape=[BS, D])
    sj_output = T.Placeholder(dtype=np.float32, shape=[BS, 1])

    np.random.seed(0)
    sj_W = T.Variable(np.random.randn(D, 1).astype("float32"))
    sj_b = T.Variable(np.random.randn(1,).astype("float32"))

    sj_loss = ((sj_input.dot(sj_W) + sj_b - sj_output) ** 2).mean()

    if model == "SGD":
        optimizers.SGD(sj_loss, lr)
    elif model == "Adam":
        optimizers.Adam(sj_loss, lr)

    train = symjax.function(sj_input, sj_output, outputs=sj_loss,
                            updates=symjax.get_updates())

    losses = []
    for i in tqdm(range(N)):
        losses.append(train(x, y))

    return losses
def test_stack():
    u = tt.Variable(tt.ones((2,)))
    output = tt.stack([u, 2 * u, 3 * u])
    f = symjax.function(outputs=output)
    assert np.allclose(f(), (np.arange(3)[:, None] + 1) * np.ones((3, 2)))
    print(f())
    print(f())
def test_base():
    a = T.ones((10,))
    b = a.sum()
    print(b.get())
    print(b.get())
    f = symjax.function(outputs=b)
    [f() for i in range(100)]
def create_glo(batch_size, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1, GLO=False):

    x = T.Placeholder([batch_size, Ds[-1]], 'float32')
    z = T.Variable(T.random.randn((batch_size, Ds[0])))
    logvar_x = T.Variable(T.ones(1))

    # DECODER
    Ws, bs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w) for w in Ws]
    bs = [T.Variable(b) for b in bs]

    h = [z]
    for w, b in zip(Ws[:-1], bs[:-1]):
        h.append(T.matmul(h[-1], w.transpose()) + b)
        h.append(h[-1] * relu_mask(h[-1], leakiness))
    h.append(T.matmul(h[-1], Ws[-1].transpose()) + bs[-1])

    # LOSS
    prior = sum([T.sum(w ** 2) for w in Ws], 0.) / cov_W \
        + sum([T.sum(v ** 2) for v in bs[:-1]], 0.) / cov_b
    if GLO:
        loss = T.sum((x - h[-1]) ** 2) / batch_size + prior
        variables = Ws + bs
    else:
        loss = Ds[-1] * logvar_x.sum() \
            + T.sum((x - h[-1]) ** 2 / T.exp(logvar_x)) / batch_size \
            + (z ** 2).sum() / batch_size + prior
        variables = Ws + bs

    prior = sum([(b ** 2).sum() for b in bs], 0.) / cov_b \
        + sum([(w ** 2).sum() for w in Ws], 0.) / cov_W

    opti = sj.optimizers.Adam(loss + prior, lr, params=variables)
    infer = sj.optimizers.Adam(loss, lr, params=[z])

    estimate = sj.function(x, outputs=z, updates=infer.updates)
    train = sj.function(x, outputs=loss, updates=opti.updates)
    lossf = sj.function(x, outputs=loss)
    params = sj.function(outputs=Ws + bs + [T.ones(Ds[-1]) * T.exp(logvar_x)])

    output = {'train': train, 'estimate': estimate, 'params': params}
    output['reset'] = lambda v: z.assign(v)
    if GLO:
        output['model'] = 'GLO'
    else:
        output['model'] = 'HARD'
    output['loss'] = lossf
    output['kwargs'] = {'batch_size': batch_size, 'Ds': Ds, 'seed': seed,
                        'leakiness': leakiness, 'lr': lr, 'scaler': scaler}
    return output
def RNTK_function(self):
    print(f"N, {self.N}, length, {self.length}")
    DATA = T.Placeholder((self.N, self.length), 'float32')
    RNTK, GP = self.RNTK_first(DATA[:, 0])
    v, _ = T.scan(lambda a, b: self.RNTK_middle(a, b),
                  sequences=[T.transpose(DATA[:, 1:])],
                  init=T.stack([RNTK, GP]))
    RNTK_last, RNTK_avg = self.RNTK_output(v)
    f = symjax.function(DATA, outputs=[RNTK_last, RNTK_avg])
    return RNTK_last, RNTK_avg
def RNTK_function(N, length, param):
    DATA = T.Placeholder((N, length), 'float32')
    RNTK, GP = RNTK_first(DATA[:, 0], param['sigmaw'], param['sigmau'],
                          param['sigmab'], param['sigmah'], param['L'],
                          param['Lf'], param['sigmav'])
    v, _ = T.scan(lambda a, b: RNTK_middle(a, b, param['sigmaw'], param['sigmau'],
                                           param['sigmab'], param['L'],
                                           param['Lf'], param['sigmav']),
                  sequences=[T.transpose(DATA[:, 1:])],
                  init=T.stack([RNTK, GP]))
    RNTK_last, RNTK_avg = RNTK_output(v, param['sigmav'], param['L'],
                                      param['Lf'], length)
    f = symjax.function(DATA, outputs=[RNTK_last, RNTK_avg])
    return f
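# Hypothetical usage sketch (assumption, not from the original file): `param` must
# supply every key read above, and the data array must match the (N, length) shape
# of the placeholder. Example values are illustrative only.
#
#   param = dict(sigmaw=1.0, sigmau=1.0, sigmab=0.1, sigmah=0.1,
#                sigmav=1.0, L=1, Lf=0)
#   kernel_fn = RNTK_function(N=100, length=20, param=param)
#   rntk_last, rntk_avg = kernel_fn(data)   # data: np.ndarray of shape (100, 20)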
def __init__(self, states, actions=None):

    self.state_shape = states.shape[1:]
    state = T.Placeholder((1,) + states.shape[1:], "float32", name="critic_state")

    if actions:
        self.action_shape = actions.shape[1:]
        action = T.Placeholder((1,) + actions.shape[1:], "float32",
                               name="critic_action")
        action_shape = action.shape[1:]

        with symjax.Scope("critic"):
            q_values = self.create_network(states, actions)
            if q_values.ndim == 2:
                assert q_values.shape[1] == 1
                q_values = q_values[:, 0]
            q_value = q_values.clone({states: state, actions: action})
            self._params = symjax.get_variables(
                trainable=None, scope=symjax.current_graph().scope_name)

        inputs = [states, actions]
        input = [state, action]
        self.actions = actions
        self.action = action
    else:
        with symjax.Scope("critic"):
            q_values = self.create_network(states)
            if q_values.ndim == 2:
                assert q_values.shape[1] == 1
                q_values = q_values[:, 0]
            q_value = q_values.clone({states: state})
            self._params = symjax.get_variables(
                trainable=None, scope=symjax.current_graph().scope_name)

        inputs = [states]
        input = [state]

    self.q_values = q_values
    self.state = state
    self.states = states

    self._get_q_values = symjax.function(*inputs, outputs=q_values)
    self._get_q_value = symjax.function(*input, outputs=q_value[0])
def test_g():
    a = symjax.tensor.ones(())
    b = symjax.tensor.Variable(1.0)
    l = a * b
    g = symjax.gradients(l, [a])[0]
    f = symjax.function(outputs=g, updates={b: b + 1.0})

    assert f() == 1
    assert f() == 2
    assert f() == 3
def test_grad():
    w = tt.Placeholder((), "float32")
    v = tt.Variable(1.0, dtype="float32")
    x = w * v + 2
    # symjax.nn.optimizers.Adam(x, 0.001)
    g = symjax.gradients(x.sum(), [v])[0]
    f = symjax.function(w, outputs=g)

    assert f(1) == 1
    assert f(10) == 10
def build_net(self, Q):
    # ------------------ all inputs ------------------------
    state = T.Placeholder([self.batch_size, self.n_states], "float32", name="s")
    next_state = T.Placeholder([self.batch_size, self.n_states], "float32", name="s_")
    reward = T.Placeholder([self.batch_size], "float32", name="r")  # input reward
    action = T.Placeholder([self.batch_size], "int32", name="a")  # input action

    with symjax.Scope("eval_net"):
        q_eval = Q(state, self.n_actions)
    with symjax.Scope("test_set"):
        q_next = Q(next_state, self.n_actions)

    q_target = reward + self.reward_decay * q_next.max(1)
    q_target = T.stop_gradient(q_target)

    a_indices = T.stack([T.range(self.batch_size), action], axis=1)
    q_eval_wrt_a = T.take_along_axis(q_eval, action.reshape((-1, 1)), 1).squeeze(1)
    loss = T.mean((q_target - q_eval_wrt_a) ** 2)

    nn.optimizers.Adam(loss, self.lr)

    self.train = symjax.function(state, action, reward, next_state,
                                 updates=symjax.get_updates())
    self.q_eval = symjax.function(state, outputs=q_eval)
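# Hypothetical usage sketch (assumption; the surrounding agent class is not shown
# here). After build_net, a learning step would be driven roughly as:
#
#   self.train(state_batch, action_batch, reward_batch, next_state_batch)
#   q = self.q_eval(state_batch)   # shape (batch_size, n_actions)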
def test_updating_variables():
    sj.current_graph().reset()

    w1 = symjax.tensor.Variable(1.0, dtype="float32")
    input = symjax.tensor.Placeholder((), "float32")
    update = w1 + input + 1
    f = symjax.function(input, updates={w1: update})

    assert w1.value == 1.0
    f(10)
    assert w1.value == 12.0
def test_map():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32")
    out = T.map(lambda a, w, u: (u - w) * a, [T.range(3)], non_sequences=[w, u])
    f = sj.function(u, outputs=out, updates={w: w + 1})

    assert np.array_equal(f(2), np.arange(3))
    assert np.array_equal(f(2), np.zeros(3))
    assert np.array_equal(f(0), -np.arange(3) * 3)
def SJ_EMA(X, debias=True):
    symjax.current_graph().reset()
    x = T.Placeholder((), "float32", name="x")
    value = symjax.nn.schedules.ExponentialMovingAverage(x, 0.9, debias=debias)[0]
    train = symjax.function(x, outputs=value, updates=symjax.get_updates())

    outputs = []
    for i in range(len(X)):
        outputs.append(train(X[i]))
    return outputs
def create_func_for_diag(self, dim1idx, dim2idx, function=False, jmode=False):
    diag = self.make_inputs(dim1idx, dim2idx, jmode=jmode)
    # print('test')

    ## prev_vals - (2,1) - previous phi and lambda values
    ## idx - where we are on the diagonal
    ## d1idx - y value of first dimension diag start
    ## d2idx - x value of second dimension diag start
    ## d1ph - max value of first dimension
    ## d2ph - max value of second dimension

    bc = self.sh ** 2 * self.sw ** 2 * T.eye(self.n, self.n) \
        + (self.su ** 2) * self.X + self.sb ** 2
    single_boundary_condition = T.expand_dims(bc, axis=0)
    # single_boundary_condition = T.expand_dims(
    #     T.Variable((bc), "float32", "boundary_condition"), axis=0)
    boundary_condition = T.concatenate(
        [single_boundary_condition, single_boundary_condition])  # one for phi and lambda

    def fn(prev_vals, idx, Xph):
        ## change - xph must now index the dataset instead of being passed in
        # tiprime_iter = d1idx + idx
        # ti_iter = d2idx + idx
        prev_lambda = prev_vals[0]
        prev_phi = prev_vals[1]

        ## not boundary condition
        S, D = self.VT(prev_lambda)
        new_lambda = self.sw ** 2 * S + self.su ** 2 * Xph + self.sb ** 2  ## took out an X
        new_phi = new_lambda + self.sw ** 2 * prev_phi * D

        lambda_expanded = T.expand_dims(new_lambda, axis=0)
        phi_expanded = T.expand_dims(new_phi, axis=0)
        to_return = T.concatenate([lambda_expanded, phi_expanded])
        # jax.lax.cond(to_return.shape == (2, 10, 10),
        #              lambda _: print(f'{idx}, true'),
        #              lambda _: print(f'{idx}, false'), operand=None)
        return to_return, to_return

    last_ema, all_ema = T.scan(fn, init=boundary_condition,
                               sequences=[diag], non_sequences=[self.X])
    expanded_ema = T.concatenate(
        [T.expand_dims(boundary_condition, axis=0), all_ema])
    print(expanded_ema)

    if function:
        f = symjax.function(diag, outputs=expanded_ema)
        return f
    else:
        return expanded_ema
def test_vectorize_sgd():
    sj.current_graph().reset()

    x = symjax.tensor.Placeholder((0, 2), "float32")
    y = symjax.tensor.Placeholder((0,), "float32")
    w = symjax.tensor.Variable((1, 1), dtype="float32")
    loss = ((x.dot(w) - y) ** 2).mean()

    g = symjax.gradients(loss, [w])[0]
    other_g = symjax.gradients(x.dot(w).sum(), [w])[0]

    f = symjax.function(x, y, outputs=loss, updates={w: w - 0.1 * g})
    other_f = symjax.function(x, outputs=other_g)

    L = [10]
    for i in range(10):
        L.append(f(np.ones((i + 1, 2)), -1 * np.ones(i + 1)))
        assert L[-1] < L[-2]
        assert np.array_equal(other_f(np.ones((i + 1, 2))), [i + 1.0, i + 1.0])
def test_pc():
    a, cpt = symjax.nn.schedules.PiecewiseConstant(0, {4: 1, 8: 2})
    f = symjax.function(outputs=a, updates={cpt: cpt + 1})

    for i in range(10):
        value = f()
        if i < 4:
            assert np.array_equal(value, 0)
        elif i < 8:
            assert np.array_equal(value, 1)
        else:
            assert np.array_equal(value, 2)
def test_grad_map():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32", name="u")
    out = T.map(lambda a, w, u: w * a * u, (T.range(3),), non_sequences=(w, u))
    g = sj.gradients(out.sum(), w)
    f = sj.function(u, outputs=g)

    assert np.array_equal(f(0), 0)
    assert np.array_equal(f(1), 3)
def test_vectorize():
    sj.current_graph().reset()

    x = symjax.tensor.Placeholder((0, 2), "float32")
    w = symjax.tensor.Variable(1.0, dtype="float32")
    p = x.sum(1)
    f = symjax.function(x, outputs=p, updates={w: x.sum()})

    assert np.array_equal(f(np.ones((1, 2))), [2.0])
    assert w.value == 2.0
    assert np.array_equal(f(np.ones((2, 2))), [2.0, 2.0])
    assert w.value == 4.0
def test_sma():
    symjax.current_graph().reset()
    a = symjax.tensor.Placeholder((4,), "float32")
    sma, var = symjax.nn.schedules.SimpleMovingAverage(a, 3)
    f = symjax.function(a, outputs=[sma, var], updates=symjax.get_updates())

    data = np.random.randn(4, 4)
    current = [data[0], data[:2].mean(0), data[:3].mean(0), data[1:4].mean(0)]

    for i in range(data.shape[0]):
        out = f(data[i])
        assert np.allclose(out[0], current[i])
def test_global_pool():
    np.random.seed(0)
    sj.current_graph().reset()

    BATCH_SIZE = 4096
    DIM = 8
    input = T.Placeholder((BATCH_SIZE, DIM), "float32", name="input")

    output = nn.layers.Dense(input, 64)
    output = nn.layers.Dense(output, output.shape[-1] * 2)
    output = nn.layers.Dense(output, output.shape[-1] * 2)

    get = sj.function(input, outputs=output)
    assert get(np.ones((BATCH_SIZE, DIM))).shape == (BATCH_SIZE, 64 * 4)