def test_stop():
    """Check that ``stop_gradient`` removes a term from the derivative.

    With ``a = 1`` the derivative of ``a + a**2`` is 3. Once the quadratic
    term is wrapped in ``stop_gradient`` only the linear term contributes,
    so the derivative is 1.
    """
    one = symjax.tensor.ones(())

    def grad_value(expression):
        # Differentiate w.r.t. ``one``, compile, and evaluate immediately.
        derivative = symjax.gradients(expression, [one])[0]
        compiled = symjax.function(outputs=derivative)
        return compiled()

    assert grad_value(one + one**2) == 3
    assert grad_value(one + symjax.tensor.stop_gradient(one**2)) == 1
def test_g():
    """Gradient of ``a * b`` w.r.t. ``a`` must follow the updated variable.

    The compiled function increments the variable by 1.0 on every call, so
    successive calls are expected to return 1, 2 and 3.
    """
    const = symjax.tensor.ones(())
    counter = symjax.tensor.Variable(1.0)
    product = const * counter
    derivative = symjax.gradients(product, [const])[0]
    step = symjax.function(outputs=derivative,
                           updates={counter: counter + 1.0})
    for expected in (1, 2, 3):
        assert step() == expected
def test_grad():
    """Gradient of ``w * v + 2`` w.r.t. the variable equals the fed input."""
    fed_input = tt.Placeholder((), "float32")
    weight = tt.Variable(1.0, dtype="float32")
    affine = fed_input * weight + 2
    derivative = symjax.gradients(affine.sum(), [weight])[0]
    evaluate = symjax.function(fed_input, outputs=derivative)
    # d(w * v + 2)/dv == w, so the output mirrors whatever is fed in.
    for value in (1, 10):
        assert evaluate(value) == value
def test_vectorize_sgd():
    """Run SGD with a growing batch dimension and check two properties.

    Each step feeds a batch with one more row than the previous one. The
    returned loss must be strictly lower than the previous entry (the
    history is seeded with 10), and the gradient of ``x.dot(w).sum()``
    w.r.t. ``w`` must equal the number of rows in the batch.
    """
    sj.current_graph().reset()
    inputs = symjax.tensor.Placeholder((0, 2), "float32")
    targets = symjax.tensor.Placeholder((0, ), "float32")
    weights = symjax.tensor.Variable((1, 1), dtype="float32")

    mse = ((inputs.dot(weights) - targets)**2).mean()
    mse_grad = symjax.gradients(mse, [weights])[0]
    sum_grad = symjax.gradients(inputs.dot(weights).sum(), [weights])[0]

    sgd_step = symjax.function(
        inputs,
        targets,
        outputs=mse,
        updates={weights: weights - 0.1 * mse_grad},
    )
    eval_sum_grad = symjax.function(inputs, outputs=sum_grad)

    history = [10]
    for n_rows in range(1, 11):
        history.append(sgd_step(np.ones((n_rows, 2)), -1 * np.ones(n_rows)))
        assert history[-1] < history[-2]
        expected = [float(n_rows), float(n_rows)]
        assert np.array_equal(eval_sum_grad(np.ones((n_rows, 2))), expected)
def test_grad_map():
    """Differentiate through ``T.map`` w.r.t. a non-sequence variable.

    The mapped function computes ``w * a * u`` for ``a`` in ``range(3)``, so
    the gradient of the summed output w.r.t. ``w`` is ``(0 + 1 + 2) * u``.
    """
    sj.current_graph().reset()
    weight = T.Variable(1.0, dtype="float32")
    scale = T.Placeholder((), "float32", name="u")
    mapped = T.map(
        lambda item, w_, u_: w_ * item * u_,
        (T.range(3), ),
        non_sequences=(weight, scale),
    )
    derivative = sj.gradients(mapped.sum(), weight)
    evaluate = sj.function(scale, outputs=derivative)
    for fed, expected in ((0, 0), (1, 3)):
        assert np.array_equal(evaluate(fed), expected)
def test_clone_base():
    """Exercise ``T.clone`` with placeholder and variable substitutions.

    The base expression is ``2 * w * u * w2``. Three compiled functions are
    built: the gradient w.r.t. ``w``, a clone with ``u`` replaced by ``uu``,
    and a clone with ``w`` replaced by ``uu``. Each function also increments
    ``w2`` on every call.
    """
    sj.current_graph().reset()
    weight = T.Variable(1.0, dtype="float32")
    counter = T.Variable(1.0, dtype="float32")
    scale = T.Placeholder((), "float32", name="u")
    alt_scale = T.Placeholder((), "float32", name="uu")
    # Two extra placeholders are created but never used afterwards; they are
    # kept so the graph receives exactly the same nodes.
    _spare_a = T.Placeholder((), "float32")
    _spare_b = T.Placeholder((), "float32")

    expression = 2 * weight * scale * counter
    grad_wrt_weight = sj.gradients(expression, weight)
    scale_swapped = T.clone(expression, {scale: alt_scale})
    weight_swapped = T.clone(expression, {weight: alt_scale})

    bump_counter = {counter: counter + 1}
    grad_fn = sj.function(scale, outputs=grad_wrt_weight, updates=bump_counter)
    scale_swapped_fn = sj.function(alt_scale,
                                   outputs=scale_swapped,
                                   updates=bump_counter)
    weight_swapped_fn = sj.function(scale,
                                    alt_scale,
                                    outputs=weight_swapped,
                                    updates=bump_counter)

    assert np.array_equal(grad_fn(2), 4.0)
    assert np.array_equal(scale_swapped_fn(1), 4)
    assert np.array_equal(weight_swapped_fn(0, 0), 0)
def __init__(
    self,
    env_fn,
    actor,
    critic,
    gamma=0.99,
    tau=0.01,
    lr=1e-3,
    batch_size=32,
    epsilon=0.1,
    epsilon_decay=1 / 1000,
    min_epsilon=0.01,
    reward=None,
):
    """Build the symbolic actor/critic graph and its compiled functions.

    Args:
        env_fn: not referenced anywhere in this method.
            NOTE(review): the body reads a name ``env`` that is never bound
            here; presumably ``env`` was meant to be derived from
            ``env_fn`` -- confirm before relying on this constructor.
        actor: callable applied to the state placeholder; its output is
            multiplied by ``action_max``.
        critic: callable applied to the ``(state, action)`` placeholders.
        gamma: stored on ``self.gamma``.
        tau: interpolation weight of the target update
            ``tau * e + (1 - tau) * t``.
        lr: learning rate handed to ``nn.optimizers.Adam``.
        batch_size: leading dimension of the training placeholders.
        epsilon, epsilon_decay, min_epsilon: accepted but not used in this
            method.
        reward: stored on ``self.reward``.
    """
    # comment out this line if you don't want to record a video of the agent
    # if save_folder is not None:
    # test_env = gym.wrappers.Monitor(test_env)

    # get size of state space and action space
    # NOTE(review): ``env`` is not defined in this method (see docstring).
    num_states = env.observation_space.shape[0]
    continuous = type(env.action_space) == gym.spaces.box.Box

    if continuous:
        num_actions = env.action_space.shape[0]
        action_max = env.action_space.high[0]
    else:
        num_actions = env.action_space.n
        action_max = 1

    self.batch_size = batch_size
    self.num_states = num_states
    self.num_actions = num_actions
    self.state_dim = (batch_size, num_states)
    self.action_dim = (batch_size, num_actions)
    self.gamma = gamma
    self.continuous = continuous
    # Observation bounds are clipped to [-20, 20] before being stored.
    self.observ_min = np.clip(env.observation_space.low, -20, 20)
    self.observ_max = np.clip(env.observation_space.high, -20, 20)
    self.env = env
    self.reward = reward

    # state
    state = T.Placeholder((batch_size, num_states), "float32")
    gradients = T.Placeholder((batch_size, num_actions), "float32")
    action = T.Placeholder((batch_size, num_actions), "float32")
    target = T.Placeholder((batch_size, 1), "float32")

    with symjax.Scope("actor_critic"):
        scaled_out = action_max * actor(state)
        Q = critic(state, action)
        # Actor loss weights the actor output by externally fed gradients;
        # critic loss is the mean squared error against ``target``.
        a_loss = -T.sum(gradients * scaled_out)
        q_loss = T.mean((Q - target)**2)
        # A single Adam optimizer is created on the sum of both losses.
        nn.optimizers.Adam(a_loss + q_loss, lr)

    # One training step: returns both losses and applies all registered
    # updates.
    self.update = symjax.function(
        state,
        action,
        target,
        gradients,
        outputs=[a_loss, q_loss],
        updates=symjax.get_updates(),
    )

    # Gradient of the mean Q value with respect to the fed action.
    g = symjax.gradients(T.mean(Q), [action])[0]
    self.get_gradients = symjax.function(state, action, outputs=g)

    # also create the target variants
    with symjax.Scope("actor_critic_target"):
        scaled_out_target = action_max * actor(state)
        Q_target = critic(state, action)

    self.actor_predict = symjax.function(state, outputs=scaled_out)
    self.actor_predict_target = symjax.function(state, outputs=scaled_out_target)
    self.critic_predict = symjax.function(state, action, outputs=Q)
    self.critic_predict_target = symjax.function(state, action, outputs=Q_target)

    t_params = symjax.get_variables(scope="/actor_critic_target/*")
    params = symjax.get_variables(scope="/actor_critic/*")
    # Target variables move toward the online ones with weight ``tau``.
    # NOTE(review): ``zip`` pairs the two lists positionally and truncates to
    # the shorter one -- confirm both scopes list matching variables in the
    # same order.
    replacement = {
        t: tau * e + (1 - tau) * t
        for t, e in zip(t_params, params)
    }
    self.update_target = symjax.function(updates=replacement)

    single_state = T.Placeholder((1, num_states), "float32")
    if not continuous:
        # NOTE(review): ``clean_action`` is not defined in this method; the
        # discrete branch cannot be verified from here.
        scaled_out = clean_action.argmax(-1)

    # Acting function: the actor graph is cloned onto a batch-of-one
    # placeholder and the first row of the result is returned.
    self.act = symjax.function(single_state, outputs=scaled_out.clone({state: single_state})[0])
z = T.Variable(np.random.randn(*SHAPE).astype("float32"), name="z") get_z = symjax.function(outputs=z) for i in range(10): print(get_z()) # print(f_shuffle()[0]) asdasd w = T.Placeholder(SHAPE, "float32", name="w") noise = T.random.uniform(SHAPE, dtype="float32") y = T.cos(symjax.nn.activations.leaky_relu(z, 0.3) + w + noise) cost = T.pool(y, (2, 2)) cost = T.sum(cost) grads = symjax.gradients(cost, [w, z], [1]) print(cost.get({w: np.random.randn(*SHAPE)})) noise.seed = 20 print(cost.get({w: np.random.randn(*SHAPE)})) noise.seed = 40 print(cost.get({w: np.random.randn(*SHAPE)})) updates = {z: z - 0.01 * grads[0]} fn1 = symjax.function(w, outputs=[cost]) fn2 = symjax.function(w, outputs=[cost], updates=updates) print(fn1(np.random.randn(*SHAPE))) print(fn1(np.random.randn(*SHAPE))) cost = list() for i in range(1000):
def test_learn_bn():
    """Compare a SymJAX conv/batch-norm/dense network with a Keras twin.

    Both models are initialised from the same random ``W`` and ``W2`` arrays
    and trained with SGD (learning rate 0.001) for 10 steps on random data.
    At every step the test compares predictions in inference and training
    mode, the loss value, and the per-variable gradients, using fractional
    agreement thresholds instead of exact equality.
    """
    symjax.current_graph().reset()
    import tensorflow as tf
    from tensorflow.keras import layers
    import symjax.nn as nn

    np.random.seed(0)
    batch_size = 128

    # Shared initial weights for both frameworks.
    W = np.random.randn(5, 5, 3, 2)
    W2 = np.random.randn(2 * 28 * 28, 10)

    # Keras reference model. The input is permuted with (2, 3, 1) before the
    # convolution, and batch norm is applied on axis -1.
    inputs = layers.Input(shape=(3, 32, 32))
    out = layers.Permute((2, 3, 1))(inputs)
    out = layers.Conv2D(2, 5, activation="linear", kernel_initializer=lambda *args, **kwargs: W)(out)
    out = layers.BatchNormalization(-1)(out)
    out = layers.Activation("relu")(out)
    out = layers.Flatten()(out)
    out = layers.Dense(10, activation="linear", kernel_initializer=lambda *args, **kwargs: W2)(out)
    model = tf.keras.Model(inputs, out)
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)

    # SymJAX model. ``W`` is fed with axes permuted by (3, 2, 0, 1) and batch
    # norm is applied on axis 1.
    input = T.Placeholder((batch_size, 3, 32, 32), "float32")
    label = T.Placeholder((batch_size, ), "int32")
    deterministic = T.Placeholder((), "bool")
    conv = nn.layers.Conv2D(input, 2, (5, 5), W=W.transpose((3, 2, 0, 1)))
    out = nn.layers.BatchNormalization(conv, [1], deterministic=deterministic)
    out = nn.relu(out)
    # The activations are transposed with (0, 2, 3, 1) before the dense layer.
    out = nn.layers.Dense(out.transpose((0, 2, 3, 1)), 10, W=W2.T)
    loss = nn.losses.sparse_softmax_crossentropy_logits(label, out).mean()

    # V2 = T.Variable(W2)
    # loss = input.sum() * V2.sum()
    # out=loss
    nn.optimizers.SGD(loss, 0.001)

    # f: loss and logits; g: gradients w.r.t. trainable variables;
    # train: loss plus application of the registered updates.
    f = symjax.function(input, label, deterministic, outputs=[loss, out])
    g = symjax.function(
        input,
        label,
        deterministic,
        outputs=symjax.gradients(loss, symjax.get_variables(trainable=True)),
    )
    train = symjax.function(
        input,
        label,
        deterministic,
        outputs=loss,
        updates=symjax.get_updates(),
    )

    for epoch in range(10):
        # generate some random inputs and labels
        x = np.random.randn(batch_size, 3, 32, 32)
        y = np.random.randint(0, 10, size=batch_size)

        # test predictions during testing mode
        # (deterministic=1 is paired with Keras ``training=False``)
        preds = model(x, training=False)
        nb = np.isclose(preds, f(x, y, 1)[1], atol=1e-3).mean()
        print("preds not training", nb)
        assert nb > 0.8

        # test prediction during training mode
        # now get the gradients
        # NOTE: from here on ``loss`` is rebound to the TensorFlow loss; the
        # SymJAX functions above were already compiled from the symbolic one.
        with tf.GradientTape() as tape:
            preds = model(x, training=True)
            loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(y, preds))
        nb = np.isclose(preds, f(x, y, 0)[1], atol=1e-3).mean()
        print("preds training", nb)
        assert nb > 0.8

        # test loss function during training
        losss = f(x, y, 0)[0]
        print("losses", losss, loss)
        assert np.abs(losss - loss) < 1e-3

        # test the gradients
        grads = tape.gradient(loss, model.trainable_variables)
        sj_grads = g(x, y, 0)
        for tf_g, sj_g in zip(grads, sj_grads):
            if sj_g.ndim == 4:
                # (2, 3, 1, 0) is the inverse of the (3, 2, 0, 1) permutation
                # applied to ``W`` above.
                sj_g = sj_g.transpose((2, 3, 1, 0))
            else:
                sj_g = sj_g.transpose()
            print(sj_g.shape, tf_g.shape)
            nb = np.isclose(
                np.reshape(sj_g, -1),
                np.reshape(tf_g, -1),
                atol=1e-3,
            ).mean()
            print("grads training", nb)
            assert nb >= 0.5

        # Advance both models by one SGD step so they stay comparable.
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train(x, y, 0)
print(f(np.ones((1, 2)))) print(w.value) print(f(np.ones((2, 2)))) print(w.value) # [2.] # 2.0 # [2. 2.] # 4.0 x = symjax.tensor.Placeholder((0, 2), "float32") y = symjax.tensor.Placeholder((0,), "float32") w = symjax.tensor.Variable((1, 1), dtype="float32") loss = ((x.dot(w) - y) ** 2).mean() g = symjax.gradients(loss, [w])[0] other_g = symjax.gradients(x.dot(w).sum(), [w])[0] f = symjax.function(x, y, outputs=loss, updates={w: w - 0.1 * g}) other_f = symjax.function(x, outputs=other_g) for i in range(10): print(f(np.ones((i + 1, 2)), -1 * np.ones(i + 1))) print(other_f(np.ones((i + 1, 2)))) # 9.0 # [1. 1.] # 3.2399998 # [2. 2.] # 1.1663998 # [3. 3.] # 0.419904
# Example script: successive derivatives of a Gaussian, followed by the setup
# of a scalar gradient-descent step.
import symjax
import symjax.tensor as T
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

# NOTE(review): the backend is selected after ``matplotlib.pyplot`` has been
# imported; confirm this ordering is intended.
matplotlib.use("Agg")

###### DERIVATIVE OF GAUSSIAN EXAMPLE

t = T.Placeholder((1000, ), "float32")
print(t)
# NOTE(review): this assignment is overwritten on the next line, so its
# result is never used.
f = T.meshgrid(t, t)
f = T.exp(-(t**2))
u = f.sum()
# Each gradient is summed to a scalar and differentiated again w.r.t. ``t``,
# giving the first, second and third derivatives.
g = symjax.gradients(u, [t])
g2 = symjax.gradients(g[0].sum(), [t])
g3 = symjax.gradients(g2[0].sum(), [t])

dog = symjax.function(t, outputs=[g[0], g2[0], g3[0]])

# Evaluate on 1000 points in [-10, 10]; the transpose plots one curve per
# derivative.
plt.plot(np.array(dog(np.linspace(-10, 10, 1000))).T)

###### GRADIENT DESCENT
z = T.Variable(3.0)
loss = z**2
g_z = symjax.gradients(loss, [z])
print(loss, z)
# Each call returns the loss and ``z`` and applies ``z <- z - 0.1 * dloss/dz``.
train = symjax.function(outputs=[loss, z], updates={z: z - 0.1 * g_z[0]})

losses = list()
# Example: ``T.map`` with non-sequence arguments, a variable update applied
# on every call, and gradients taken through the mapped function.
w = T.Variable(1.0, dtype="float32")
u = T.Placeholder((), "float32")
# Computes (u - w) * a for a in range(3); ``w`` and ``u`` are passed through
# unchanged as non-sequences.
out = T.map(lambda a, w, u: (u - w) * a, [T.range(3)], non_sequences=[w, u])
# Every call also increments ``w`` by one.
f = sj.function(u, outputs=out, updates={w: w + 1})
print(f(2))  # [0, 1, 2]
print(f(2))  # [0, 0, 0]
print(f(0))  # [0, -3, -6]

# Reset ``w`` before the gradient example.
w.reset()
out = T.map(lambda a, w, u: w * a * u, [T.range(3)], non_sequences=[w, u])
# d/dw of sum(w * a * u) over a in range(3) is 3 * u.
g = sj.gradients(out.sum(), [w])[0]
f = sj.function(u, outputs=g)
print(f(0))  # 0
print(f(1))  # 3

# Mapping over two sequences at once (element-wise product).
out = T.map(lambda a, b: a * b, [T.range(3), T.range(3)])
f = sj.function(outputs=out)
print(f())  # [0, 1, 4]
def __init__(
    self,
    state_shape,
    actions_shape,
    batch_size,
    actor,
    critic,
    lr=1e-3,
    gamma=0.99,
    tau=0.01,
):
    """Wire up the critic and actor networks with their training functions.

    Args:
        state_shape: tuple appended to ``(batch_size,)`` to shape the state
            placeholder.
        actions_shape: tuple appended to ``(batch_size,)`` to shape the
            action placeholder.
        batch_size: leading dimension of every placeholder.
        actor: callable invoked on the state placeholder. It is called twice,
            once for the online network and once for the target network.
        critic: callable invoked on ``(states, actions)``, also called twice.
        lr: learning rate given to both Adam optimizers.
        gamma: stored on ``self.gamma``.
        tau: stored on ``self.tau``.
    """
    self.gamma = gamma
    self.tau = tau
    self.lr = lr
    self.batch_size = batch_size

    batch_prefix = (batch_size, )
    states = T.Placeholder(batch_prefix + state_shape, "float32")
    actions = T.Placeholder(batch_prefix + actions_shape, "float32")

    # Online and target critics are built from the same inputs.
    self.critic = critic(states, actions)
    self.target_critic = critic(states, actions)

    # Critic objective: mean squared error against externally fed targets.
    q_targets = T.Placeholder(self.critic.q_values.shape, "float32")
    q_error = self.critic.q_values - q_targets
    critic_objective = (q_error**2).mean()

    # The critic optimizer lives in its own scope so that its updates can be
    # selected by scope below.
    with symjax.Scope("critic_optimizer"):
        nn.optimizers.Adam(critic_objective,
                           lr,
                           params=self.critic.params(True))

    self._train_critic = symjax.function(
        states,
        actions,
        q_targets,
        outputs=critic_objective,
        updates=symjax.get_updates(scope="*/critic_optimizer"),
    )

    # Utility returning the gradient of the summed Q-values w.r.t. actions.
    q_grad = symjax.gradients(self.critic.q_values.sum(), actions)
    self._get_critic_gradients = symjax.function(states,
                                                 actions,
                                                 outputs=q_grad)

    # Online and target actors.
    self.actor = actor(states)
    self.target_actor = actor(states)

    # Actor objective: negative mean of the actions weighted by fed
    # gradients.
    action_grads = T.Placeholder(actions.shape, "float32")
    actor_objective = -(self.actor.actions * action_grads).mean()

    with symjax.Scope("actor_optimizer"):
        nn.optimizers.Adam(actor_objective,
                           lr,
                           params=self.actor.params(True))

    self._train_actor = symjax.function(
        states,
        action_grads,
        outputs=actor_objective,
        updates=symjax.get_updates(scope="*/actor_optimizer"),
    )

    # Start with both networks identical.
    self.update_target(1)
out1 = randn * value out2 = out1.clone({randn: rand}) f = symjax.function(rand, outputs=out2, updates={value: 2 + value}) for i in range(3): print(f(i)) # 0. # 3. # 10. # we create a simple computational graph var = T.Variable(T.random.randn((16, 8), seed=10)) loss = ((var - T.ones_like(var))**2).sum() g = symjax.gradients(loss, [var]) opt = symjax.optimizers.SGD(loss, 0.01, params=var) f = symjax.function(outputs=loss, updates=opt.updates) for i in range(10): print(f()) # 240.96829 # 231.42595 # 222.26149 # 213.45993 # 205.00691 # 196.88864 # 189.09186 # 181.60382 # 174.41231
def test_pymc():
    """Build a small PyMC-style model and check its joint log-density.

    Random variables are implemented as ``symjax.tensor.Variable``
    subclasses. The symbolic joint log-density of three chained normals is
    compared numerically against the same quantity computed with
    ``jsp.stats.norm.logpdf``.
    """

    class RandomVariable(symjax.tensor.Variable):
        """Variable that also exposes a log-density and a sampler."""

        def __init__(self, name, shape, observed):
            # Unobserved variables start at zeros of ``shape``; observed ones
            # hold the observation and are marked non-trainable.
            if observed is None:
                super().__init__(np.zeros(shape), name=name)
            else:
                super().__init__(observed, name=name, trainable=False)

        def logp(self, value):
            """Log-density at ``value``; must be provided by subclasses."""
            raise NotImplementedError()

        def random(self, sample_shape):
            """Sampler; must be provided by subclasses."""
            raise NotImplementedError()

        @property
        def logpt(self):
            """Log-density evaluated at the variable itself."""
            return self.logp(self)

    class Normal(RandomVariable):
        """Normal random variable with location ``mu`` and scale ``sigma``."""

        def __init__(self, name, mu, sigma, shape=None, observed=None):
            self.mu = mu
            self.sigma = sigma
            super().__init__(name, shape, observed)

        def logp(self, value):
            # Precision form of the Gaussian log-density:
            # (-tau * (v - mu)**2 + log(tau / (2 * pi))) / 2
            tau = self.sigma**-2.0
            return (-tau * (value - self.mu)**2 + tt.log(tau / np.pi / 2.0)) / 2.0

        def random(self, sample_shape):
            return np.random.randn(sample_shape) * self.sigma + self.mu

    # ``y`` is centred on ``x`` with scale ``exp(s)``.
    x = Normal("x", 0, 10.0)
    s = Normal("s", 0.0, 5.0)
    y = Normal("y", x, tt.exp(s))

    # Unobserved variables are initialised to zero.
    assert symjax.current_graph().get(y) == 0.0

    #################
    model_logpt = x.logpt + s.logpt + y.logpt

    f = symjax.function(x, s, y, outputs=model_logpt)

    normal_loglike = jsp.stats.norm.logpdf

    def f_(x, s, y):
        """Reference joint log-density built from ``normal_loglike``."""
        return (normal_loglike(x, 0.0, 10.0) + normal_loglike(s, 0.0, 5.0) + normal_loglike(y, x, jnp.exp(s)))

    for i in range(10):
        x_val = np.random.randn() * 10.0
        s_val = np.random.randn() * 5.0
        y_val = np.random.randn() * 0.1 + x_val
        np.testing.assert_allclose(f(x_val, s_val, y_val), f_(x_val, s_val, y_val), rtol=1e-06)

    # Gradients are evaluated on the last sampled point.
    model_dlogpt = symjax.gradients(model_logpt, [x, s, y])
    f_with_grad = symjax.function(x, s, y, outputs=[model_logpt, model_dlogpt])
    f_with_grad(x_val, s_val, y_val)

    grad_fn = jax.grad(f_, argnums=[0, 1, 2])
    # NOTE(review): this tuple is evaluated and discarded; nothing asserts
    # that the SymJAX gradients match ``grad_fn``.
    f_(x_val, s_val, y_val), grad_fn(x_val, s_val, y_val)
# Example script: gradients through ``T.map`` and ``T.scan``, and a scan
# whose initial carry depends on a variable updated on every call.
import jax
import numpy as np
import sys

# Makes the parent directory importable before ``symjax`` is imported.
sys.path.insert(0, "../")

import symjax
import symjax.tensor as T

# map
xx = T.ones(10)
a = T.map(lambda a: a * 2, xx)
g = symjax.gradients(a.sum(), xx)[0]
f = symjax.function(outputs=[a, g])

# scan
xx = T.ones(10) * 2
# The step multiplies the carry by the current element and emits the same
# value as the per-step output.
a = T.scan(lambda c, x: (c * x, c * x), T.ones(1), xx)
# Gradient of the last per-step output with respect to the whole sequence.
g = symjax.gradients(a[1][-1], xx)[0]
f = symjax.function(outputs=[a, g])

# scan with updates
xx = T.range(5)
# NOTE(review): ``uu`` is only referenced by the commented-out variants below.
uu = T.ones((10, 2))
vvar = T.Variable(T.zeros((10, 2)))
vv = T.index_add(vvar, 1, 1)
# ``vv`` is both the initial carry and the extra argument ``p``; each step
# writes ``p[x]`` into index ``x`` of the carry.
a = T.scan(lambda c, x, p: (T.index_update(c, x, p[x]), 1), vv, xx, [vv])
#a = T.scan(lambda c, x: (c*x,c*x), T.ones(1), xx)
#a = T.scan(lambda c, x: (T.square(c),c[0]), uu, xx)
#g = symjax.gradients(a[1][-1],xx)
# Each call returns the final carry and increments ``vvar`` by one.
f = symjax.function(outputs=a[0], updates={vvar: vvar + 1})
print(f(), f(), f())
# Example script: one scalar variable updated with a fixed-step gradient rule.
import sys

# Makes the parent directory importable before ``symjax`` is imported.
sys.path.insert(0, "../")

import symjax
import symjax.tensor as T

# create our variable to be optimized
mu = T.Variable(T.random.normal((), seed=1))

# create our cost
cost = T.exp(-(mu-1)**2)

# get the gradient; note that it is itself a tensor that can then
# be manipulated as well
g = symjax.gradients(cost, mu)
print(g)

# (Tensor: shape=(), dtype=float32)

# create the compiled function that will compute the cost and apply
# the update onto the variable (a step of size 0.2 against the gradient)
f = symjax.function(outputs=cost, updates={mu:mu-0.2*g})

for i in range(10):
    print(f())

# 0.008471076
# 0.008201109
# 0.007946267
# 0.007705368
# 0.0074773384
# 0.007261208
Basic gradient descent (and reset) ================================== demonstration on how to compute a gradient and apply a basic gradient update rule to minimize some loss function """ import symjax import symjax.tensor as T import matplotlib.pyplot as plt # GRADIENT DESCENT z = T.Variable(3.0, dtype="float32") loss = (z - 1) ** 2 g_z = symjax.gradients(loss, [z])[0] symjax.current_graph().add_updates({z: z - 0.1 * g_z}) train = symjax.function(outputs=[loss, z], updates=symjax.get_updates()) losses = list() values = list() for i in range(200): if (i + 1) % 50 == 0: symjax.reset_variables("*") a, b = train() losses.append(a) values.append(b) plt.figure()
return carry + 1, 0 output, _ = T.scan(func, T.zeros(1), T.ones(10), length=10) f = sj.function(outputs=output) print(f()) # [10.] # example of simple RNN w = T.Placeholder((3, 10), 'float32') h = T.random.randn((3, 3)) b = T.random.randn((3, )) t_steps = 100 X = T.random.randn((t_steps, 10)) def rnn_cell(carry, x, w): output = T.sigmoid(T.matmul(w, x) + T.matmul(carry, h) + b) return output, output last, hidden = T.scan(rnn_cell, T.zeros(3), X, constants=(w, )) g = sj.gradients(hidden.sum(), w) print(g.get({w: np.ones((3, 10))})) f = sj.function(w, outputs=hidden) print(f(np.ones((3, 10)))) #