def SJ(x, y, N, lr, model, preallocate=False):
    symjax.current_graph().reset()
    sj_input = T.Placeholder(dtype=np.float32, shape=[BS, D])
    sj_output = T.Placeholder(dtype=np.float32, shape=[BS, 1])

    np.random.seed(0)
    sj_W = T.Variable(np.random.randn(D, 1).astype("float32"))
    sj_b = T.Variable(np.random.randn(1).astype("float32"))
    sj_loss = ((sj_input.dot(sj_W) + sj_b - sj_output) ** 2).mean()

    if model == "SGD":
        optimizers.SGD(sj_loss, lr)
    elif model == "Adam":
        optimizers.Adam(sj_loss, lr)

    train = symjax.function(
        sj_input, sj_output, outputs=sj_loss, updates=symjax.get_updates()
    )

    losses = []
    for i in tqdm(range(N)):
        losses.append(train(x, y))

    return losses
def SJ(x, y, N, preallocate=False):
    symjax.current_graph().reset()
    sj_input = T.Placeholder(dtype=np.float32, shape=[BS, D])
    sj_output = T.Placeholder(dtype=np.float32, shape=[BS, 1])

    np.random.seed(0)
    sj_W = T.Variable(np.random.randn(D, 1).astype("float32"))
    sj_b = T.Variable(np.random.randn(1).astype("float32"))
    sj_loss = ((sj_input.dot(sj_W) + sj_b - sj_output) ** 2).mean()

    optimizers.Adam(sj_loss, lr)

    train = symjax.function(sj_input, sj_output, updates=symjax.get_updates())

    if preallocate:
        import jax

        x = jax.device_put(x)
        y = jax.device_put(y)

    t = time.time()
    for i in range(N):
        train(x, y)

    return time.time() - t
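A minimal usage sketch for the SJ benchmark helpers above (not part of the original source): the module-level names BS, D, lr and the imports are assumptions here, with illustrative values.

import time

import numpy as np
import symjax
import symjax.tensor as T
from symjax.nn import optimizers

BS, D, lr = 32, 10, 1e-3  # assumed globals used inside the SJ helpers
x = np.random.randn(BS, D).astype("float32")
y = np.random.randn(BS, 1).astype("float32")

# time N training steps with the timing variant defined just above
elapsed = SJ(x, y, N=100, preallocate=True)
print("elapsed seconds:", elapsed)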
def updates(self):
    # updates registered under this object's scope
    return symjax.get_updates(scope=self._scope_name)


def updates(self):
    # lazily create the update dictionary on first access
    if hasattr(self, "_updates"):
        return self._updates
    else:
        self._updates = {}
        return self._updates
def test_bn():
    sj.current_graph().reset()
    BATCH_SIZE = 5
    DIM = 2
    input = T.Placeholder((BATCH_SIZE, DIM), "float32", name="input")
    deterministic = T.Placeholder((1,), "bool", name="deterministic")

    bn = nn.layers.BatchNormalization(input, [1], deterministic=deterministic)

    update = sj.function(input, deterministic, outputs=bn, updates=sj.get_updates())
    get_stats = sj.function(input, outputs=bn.avg_mean)

    data = np.random.randn(50, DIM) * 4 + 2

    true_means = []
    actual_means = []

    for i in range(10):
        batch = data[BATCH_SIZE * i : BATCH_SIZE * (i + 1)]
        output = update(batch, 0)
        assert np.allclose(
            output, (batch - batch.mean(0)) / (1e-4 + batch.std(0)), 1e-4
        )
        actual_means.append(get_stats(batch))
        if i == 0:
            true_means.append(batch.mean(0))
        else:
            true_means.append(0.9 * true_means[-1] + 0.1 * batch.mean(0))

    true_means = np.array(true_means)
    actual_means = np.array(actual_means).squeeze()
    assert np.allclose(true_means, actual_means, 1e-4)
def SJ_EMA(X, debias=True):
    symjax.current_graph().reset()
    x = T.Placeholder((), "float32", name="x")
    value = symjax.nn.schedules.ExponentialMovingAverage(x, 0.9, debias=debias)[0]
    train = symjax.function(x, outputs=value, updates=symjax.get_updates())
    outputs = []
    for i in range(len(X)):
        outputs.append(train(X[i]))
    return outputs
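An illustrative check of SJ_EMA against the direct NumPy recursion (not in the original file; with debias=False the average starts from zero, matching the test_ema snippet below).

import numpy as np

X = np.random.randn(50).astype("float32")
sj_out = SJ_EMA(X, debias=False)

ema = 0.0
for i, v in enumerate(X):
    ema = 0.9 * ema + 0.1 * v  # same recursion as the schedule with alpha=0.9
    assert np.allclose(sj_out[i], ema, atol=1e-4)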
def test_sma():
    symjax.current_graph().reset()
    a = symjax.tensor.Placeholder((4,), "float32")
    sma, var = symjax.nn.schedules.SimpleMovingAverage(a, 3)
    f = symjax.function(a, outputs=[sma, var], updates=symjax.get_updates())

    data = np.random.randn(4, 4)

    current = [data[0], data[:2].mean(0), data[:3].mean(0), data[1:4].mean(0)]

    for i in range(data.shape[0]):
        out = f(data[i])
        assert np.allclose(out[0], current[i])
def test_ema():
    symjax.current_graph().reset()
    a = symjax.tensor.Placeholder((), "float32")
    ema, var = symjax.nn.schedules.ExponentialMovingAverage(a, 0.9, debias=False)
    # t = symjax.get_variables("*num_steps*", trainable=False)
    f = symjax.function(a, outputs=[ema, var], updates=symjax.get_updates())

    current = 0
    for i in range(10):
        out = f(1)
        # var holds the value before the update, ema the value after
        assert np.allclose(out[1], current)
        current = 0.9 * current + 0.1 * 1
        assert np.allclose(out[0], current)
def build_net(self, Q):
    # ------------------ all inputs ------------------------
    state = T.Placeholder([self.batch_size, self.n_states], "float32", name="s")
    next_state = T.Placeholder(
        [self.batch_size, self.n_states], "float32", name="s_"
    )
    reward = T.Placeholder([self.batch_size], "float32", name="r")  # input reward
    action = T.Placeholder([self.batch_size], "int32", name="a")  # input action

    with symjax.Scope("eval_net"):
        q_eval = Q(state, self.n_actions)
    with symjax.Scope("test_set"):
        q_next = Q(next_state, self.n_actions)

    q_target = reward + self.reward_decay * q_next.max(1)
    q_target = T.stop_gradient(q_target)

    a_indices = T.stack([T.range(self.batch_size), action], axis=1)
    q_eval_wrt_a = T.take_along_axis(q_eval, action.reshape((-1, 1)), 1).squeeze(1)
    loss = T.mean((q_target - q_eval_wrt_a) ** 2)
    nn.optimizers.Adam(loss, self.lr)

    self.train = symjax.function(
        state, action, reward, next_state, updates=symjax.get_updates()
    )
    self.q_eval = symjax.function(state, outputs=q_eval)
def __init__(
    self,
    state_shape,
    actions_shape,
    n_episodes,
    episode_length,
    actor,
    lr=1e-3,
    gamma=0.99,
):
    self.actor = actor
    self.gamma = gamma
    self.lr = lr
    self.episode_length = episode_length
    self.n_episodes = n_episodes
    self.batch_size = episode_length * n_episodes

    states = T.Placeholder((self.batch_size,) + state_shape, "float32")
    actions = T.Placeholder((self.batch_size,) + actions_shape, "float32")
    discounted_rewards = T.Placeholder((self.batch_size,), "float32")

    self.actor = actor(states, distribution="gaussian")
    logprobs = self.actor.actions.log_prob(actions)
    actor_loss = -(logprobs * discounted_rewards).sum() / n_episodes

    with symjax.Scope("REINFORCE_optimizer"):
        nn.optimizers.Adam(
            actor_loss,
            lr,
            params=self.actor.params(True),
        )

    # create the update function
    self._train = symjax.function(
        states,
        actions,
        discounted_rewards,
        outputs=actor_loss,
        updates=symjax.get_updates(scope="*/REINFORCE_optimizer"),
    )
def test_bn():
    np.random.seed(0)
    sj.current_graph().reset()
    BATCH_SIZE = 5
    DIM = 2
    input = T.Placeholder((BATCH_SIZE, DIM), "float32", name="input")
    deterministic = T.Placeholder((), "bool", name="deterministic")

    bn = nn.layers.BatchNormalization(input, [1], deterministic=deterministic)

    update = sj.function(input, deterministic, outputs=bn, updates=sj.get_updates())
    avmean = symjax.get_variables(trainable=None)[-3]
    print(avmean)
    get_stats = sj.function(input, outputs=avmean[0])

    data = np.random.randn(50, DIM) * 4 + 2

    true_means = [np.zeros(DIM)]
    actual_means = [np.zeros(DIM)]

    for i in range(10):
        batch = data[BATCH_SIZE * i : BATCH_SIZE * (i + 1)]
        output = update(batch, 0)
        assert np.allclose(
            output,
            (batch - batch.mean(0)) / np.sqrt(0.001 + batch.var(0)),
            1e-4,
        )
        actual_means.append(get_stats(batch))
        true_means.append(0.99 * true_means[-1] + 0.01 * batch.mean(0))

    true_means = np.array(true_means)
    actual_means = np.array(actual_means).squeeze()
    assert np.allclose(true_means, actual_means)
import symjax
from symjax import nn
import symjax.tensor as T
import numpy as np

input = T.Placeholder((64, 3, 32, 32), "float32")
label = T.Placeholder((64,), "int32")
deterministic = T.Placeholder((), "bool")

block = nn.models.ResidualBlockv1(activation=nn.relu, dropout_rate=0.1)
transformed = nn.models.ResNet(
    input,
    widths=[32, 32, 64, 64, 128, 128],
    strides=[1, 1, 2, 1, 2, 1],
    block=block,
    deterministic=deterministic,
)
classifier = nn.layers.Dense(transformed, 10)
loss = nn.losses.sparse_softmax_crossentropy_logits(label, classifier).mean()
nn.optimizers.Adam(loss, 0.001)

train = symjax.function(
    input, label, deterministic, outputs=loss, updates=symjax.get_updates()
)
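A minimal training-step sketch for the ResNet example above, with random data shaped to match the placeholders (illustrative only, not from the original example):

x = np.random.randn(64, 3, 32, 32).astype("float32")
y = np.random.randint(0, 10, size=64).astype("int32")

# one optimization step; 0 disables the deterministic (test-time) behavior
print("loss:", train(x, y, 0))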
def classif_sj(train_x, train_y, test_x, test_y, mlp=True):
    symjax.current_graph().reset()
    from symjax import nn

    batch_size = 128

    input = T.Placeholder((batch_size, 3, 32, 32), "float32")
    labels = T.Placeholder((batch_size,), "int32")
    deterministic = T.Placeholder((), "bool")

    if not mlp:
        out = nn.relu(nn.layers.Conv2D(input, 32, (3, 3)))
        for i in range(3):
            for j in range(3):
                conv = nn.layers.Conv2D(out, 32 * (i + 1), (3, 3), pad="SAME")
                bn = nn.layers.BatchNormalization(
                    conv, [1], deterministic=deterministic
                )
                bn = nn.relu(bn)
                conv = nn.layers.Conv2D(bn, 32 * (i + 1), (3, 3), pad="SAME")
                bn = nn.layers.BatchNormalization(
                    conv, [1], deterministic=deterministic
                )
                out = out + bn
            out = nn.layers.Pool2D(out, (2, 2), pool_type="AVG")
            out = nn.layers.Conv2D(out, 32 * (i + 2), (1, 1))
        # out = out.mean((2, 3))
        out = nn.layers.Pool2D(out, out.shape.get()[-2:], pool_type="AVG")
    else:
        out = input
        for i in range(6):
            out = nn.layers.Dense(out, 4000)
            out = nn.relu(
                nn.layers.BatchNormalization(out, [1], deterministic=deterministic)
            )

    outputs = nn.layers.Dense(out, 10)
    loss = nn.losses.sparse_softmax_crossentropy_logits(labels, outputs).mean()
    nn.optimizers.Adam(loss, 0.001)
    accu = T.equal(outputs.argmax(1), labels).astype("float32").mean()

    train = symjax.function(
        input,
        labels,
        deterministic,
        outputs=[loss, accu, outputs],
        updates=symjax.get_updates(),
    )
    test = symjax.function(input, labels, deterministic, outputs=accu)

    for epoch in range(5):
        accu = 0
        for x, y in symjax.data.utils.batchify(
            train_x, train_y, batch_size=batch_size, option="random"
        ):
            accu += train(x, y, 0)[1]
        print("training", accu / (len(train_x) // batch_size))

        accu = 0
        for x, y in symjax.data.utils.batchify(
            test_x, test_y, batch_size=batch_size, option="continuous"
        ):
            accu += test(x, y, 1)
        print(accu / (len(test_x) // batch_size))
def __init__(
    self,
    state_dim,
    action_dim,
    lr,
    gamma,
    K_epochs,
    eps_clip,
    actor,
    critic,
    batch_size,
    continuous=True,
):
    self.lr = lr
    self.gamma = gamma
    self.eps_clip = eps_clip
    self.K_epochs = K_epochs
    self.batch_size = batch_size

    state = T.Placeholder((batch_size,) + state_dim, "float32")
    reward = T.Placeholder((batch_size,), "float32")
    old_action_logprobs = T.Placeholder((batch_size,), "float32")

    logits = actor(state)

    if not continuous:
        given_action = T.Placeholder((batch_size,), "int32")
        dist = Categorical(logits=logits)
    else:
        mean = T.tanh(logits[:, : logits.shape[1] // 2])
        std = T.exp(logits[:, logits.shape[1] // 2 :])
        given_action = T.Placeholder((batch_size, action_dim), "float32")
        dist = MultivariateNormal(mean=mean, diag_std=std)

    sample = dist.sample()
    sample_logprobs = dist.log_prob(sample)

    self._act = symjax.function(state, outputs=[sample, sample_logprobs])

    given_action_logprobs = dist.log_prob(given_action)

    # finding the ratio (pi_theta / pi_theta_old); the ratio is computed
    # for the action that was actually taken, hence given_action_logprobs
    ratios = T.exp(given_action_logprobs - old_action_logprobs)
    ratios = T.clip(ratios, None, 1 + self.eps_clip)
    state_value = critic(state)
    advantages = reward - T.stop_gradient(state_value)

    loss = (
        -T.mean(ratios * advantages)
        + 0.5 * T.mean((state_value - reward) ** 2)
        - 0.0 * dist.entropy().mean()
    )
    print(loss)
    nn.optimizers.Adam(loss, self.lr)

    self.learn = symjax.function(
        state,
        given_action,
        reward,
        old_action_logprobs,
        outputs=T.mean(loss),
        updates=symjax.get_updates(),
    )
def __init__(
    self,
    state_shape,
    actions_shape,
    n_episodes,
    episode_length,
    actor,
    critic,
    lr=1e-3,
    gamma=0.99,
    train_v_iters=10,
):
    self.actor = actor
    self.critic = critic
    self.gamma = gamma
    self.lr = lr
    self.episode_length = episode_length
    self.n_episodes = n_episodes
    self.batch_size = episode_length * n_episodes
    self.train_v_iters = train_v_iters

    states = T.Placeholder((self.batch_size,) + state_shape, "float32")
    actions = T.Placeholder((self.batch_size,) + actions_shape, "float32")
    discounted_rewards = T.Placeholder((self.batch_size,), "float32")
    advantages = T.Placeholder((self.batch_size,), "float32")

    self.actor = actor(states, distribution="gaussian")
    self.critic = critic(states)

    logprobs = self.actor.actions.log_prob(actions)
    actor_loss = -(logprobs * advantages).sum() / n_episodes
    critic_loss = 0.5 * ((discounted_rewards - self.critic.q_values) ** 2).mean()

    with symjax.Scope("actor_optimizer"):
        nn.optimizers.Adam(
            actor_loss,
            lr,
            params=self.actor.params(True),
        )
    with symjax.Scope("critic_optimizer"):
        nn.optimizers.Adam(
            critic_loss,
            lr,
            params=self.critic.params(True),
        )

    # create the update functions
    self._train_actor = symjax.function(
        states,
        actions,
        advantages,
        outputs=actor_loss,
        updates=symjax.get_updates(scope="*/actor_optimizer"),
    )
    self._train_critic = symjax.function(
        states,
        discounted_rewards,
        outputs=critic_loss,
        updates=symjax.get_updates(scope="*/critic_optimizer"),
    )
def test_learn_bn():
    symjax.current_graph().reset()

    import tensorflow as tf
    from tensorflow.keras import layers
    import symjax.nn as nn

    np.random.seed(0)

    batch_size = 128

    W = np.random.randn(5, 5, 3, 2)
    W2 = np.random.randn(2 * 28 * 28, 10)

    # keras reference model
    inputs = layers.Input(shape=(3, 32, 32))
    out = layers.Permute((2, 3, 1))(inputs)
    out = layers.Conv2D(
        2, 5, activation="linear", kernel_initializer=lambda *args, **kwargs: W
    )(out)
    out = layers.BatchNormalization(-1)(out)
    out = layers.Activation("relu")(out)
    out = layers.Flatten()(out)
    out = layers.Dense(
        10, activation="linear", kernel_initializer=lambda *args, **kwargs: W2
    )(out)
    model = tf.keras.Model(inputs, out)
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)

    # symjax model with the same initial weights
    input = T.Placeholder((batch_size, 3, 32, 32), "float32")
    label = T.Placeholder((batch_size,), "int32")
    deterministic = T.Placeholder((), "bool")

    conv = nn.layers.Conv2D(input, 2, (5, 5), W=W.transpose((3, 2, 0, 1)))
    out = nn.layers.BatchNormalization(conv, [1], deterministic=deterministic)
    out = nn.relu(out)
    out = nn.layers.Dense(out.transpose((0, 2, 3, 1)), 10, W=W2.T)

    loss = nn.losses.sparse_softmax_crossentropy_logits(label, out).mean()
    # V2 = T.Variable(W2)
    # loss = input.sum() * V2.sum()
    # out = loss
    nn.optimizers.SGD(loss, 0.001)

    f = symjax.function(input, label, deterministic, outputs=[loss, out])
    g = symjax.function(
        input,
        label,
        deterministic,
        outputs=symjax.gradients(loss, symjax.get_variables(trainable=True)),
    )
    train = symjax.function(
        input,
        label,
        deterministic,
        outputs=loss,
        updates=symjax.get_updates(),
    )

    for epoch in range(10):
        # generate some random inputs and labels
        x = np.random.randn(batch_size, 3, 32, 32)
        y = np.random.randint(0, 10, size=batch_size)

        # test predictions during testing mode
        preds = model(x, training=False)
        nb = np.isclose(preds, f(x, y, 1)[1], atol=1e-3).mean()
        print("preds not training", nb)
        assert nb > 0.8

        # test predictions during training mode and get the gradients
        with tf.GradientTape() as tape:
            preds = model(x, training=True)
            loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(y, preds)
            )

        nb = np.isclose(preds, f(x, y, 0)[1], atol=1e-3).mean()
        print("preds training", nb)
        assert nb > 0.8

        # test loss function during training
        losss = f(x, y, 0)[0]
        print("losses", losss, loss)
        assert np.abs(losss - loss) < 1e-3

        # test the gradients
        grads = tape.gradient(loss, model.trainable_variables)
        sj_grads = g(x, y, 0)

        for tf_g, sj_g in zip(grads, sj_grads):
            if sj_g.ndim == 4:
                sj_g = sj_g.transpose((2, 3, 1, 0))
            else:
                sj_g = sj_g.transpose()
            print(sj_g.shape, tf_g.shape)
            nb = np.isclose(
                np.reshape(sj_g, -1),
                np.reshape(tf_g, -1),
                atol=1e-3,
            ).mean()
            print("grads training", nb)
            assert nb >= 0.5

        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train(x, y, 0)
import matplotlib.pyplot as plt
import symjax.tensor as T
import symjax
import numpy as np

w = T.Placeholder((3,), "float32", name="w")
alpha = 0.5
new_value, var = symjax.nn.schedules.ExponentialMovingAverage(w, alpha)
train = symjax.function(w, outputs=new_value, updates=symjax.get_updates())

data = np.stack([np.ones(200), np.random.randn(200), np.zeros(200)], 1)

cost = list()
true_ema = [data[0]]
aa = 0.5
for j, i in enumerate(data):
    cost.append(train(i))
    true_ema.append(aa * true_ema[-1] + (1 - aa) * i)

cost = np.asarray(cost)
true = np.asarray(true_ema)[1:]
print("% close values:", 100 * np.mean(np.isclose(cost, true)))

plt.subplot(311)
plt.plot(data[:, 0])
plt.plot(cost[:, 0])
plt.plot(true[:, 0])

plt.subplot(312)
plt.plot(data[:, 1], label="true")
plt.plot(cost[:, 1], label="symjax")
plt.plot(true[:, 1], label="np")
def __init__(
    self,
    env,
    actor,
    critic,
    lr=1e-4,
    batch_size=32,
    train_pi_iters=80,
    train_v_iters=80,
):
    num_states = env.observation_space.shape[0]
    continuous = type(env.action_space) == gym.spaces.box.Box

    if continuous:
        num_actions = env.action_space.shape[0]
        action_min = env.action_space.low
        action_max = env.action_space.high
    else:
        num_actions = env.action_space.n
        action_max = 1

    self.batch_size = batch_size
    self.num_states = num_states
    self.num_actions = num_actions
    self.state_dim = (batch_size, num_states)
    self.action_dim = (batch_size, num_actions)
    self.continuous = continuous
    self.lr = lr
    self.train_pi_iters = train_pi_iters
    self.train_v_iters = train_v_iters
    self.extras = {}

    state_ph = T.Placeholder((batch_size, num_states), "float32")
    rew_ph = T.Placeholder((batch_size,), "float32")

    with symjax.Scope("actor"):
        logits = actor(state_ph)
        if not continuous:
            pi = Categorical(logits=logits)
        else:
            logstd = T.Variable(
                -0.5 * np.ones(num_actions, dtype=np.float32),
                name="logstd",
            )
            pi = MultivariateNormal(mean=logits, diag_log_std=logstd)
        actions = pi.sample()  # pi
        actions_log_prob = pi.log_prob(actions)  # logp

    with symjax.Scope("critic"):
        critic_value = critic(state_ph)

    # AC objectives
    diff = rew_ph - critic_value
    actor_loss = -(actions_log_prob * diff).mean()
    critic_loss = nn.losses.squared_differences(rew_ph, critic_value).mean()

    with symjax.Scope("update_pi"):
        nn.optimizers.Adam(
            actor_loss,
            self.lr,
            params=symjax.get_variables(scope="/actor/"),
        )
    with symjax.Scope("update_v"):
        nn.optimizers.Adam(
            critic_loss,
            self.lr,
            params=symjax.get_variables(scope="/critic/"),
        )

    self.learn_pi = symjax.function(
        state_ph,
        rew_ph,
        outputs=actor_loss,
        updates=symjax.get_updates(scope="/update_pi/"),
    )
    self.learn_v = symjax.function(
        state_ph,
        rew_ph,
        outputs=critic_loss,
        updates=symjax.get_updates(scope="/update_v/*"),
    )

    single_state = T.Placeholder((1, num_states), "float32")
    single_action = actions.clone({state_ph: single_state})[0]
    single_v = critic_value.clone({state_ph: single_state})

    self._act = symjax.function(
        single_state,
        outputs=[single_action, single_v],
    )
def __init__(
    self,
    state_shape,
    actions_shape,
    batch_size,
    actor,
    critic,
    lr=1e-3,
    K_epochs=80,
    eps_clip=0.2,
    gamma=0.99,
    entropy_beta=0.01,
):
    self.actor = actor
    self.critic = critic
    self.gamma = gamma
    self.lr = lr
    self.eps_clip = eps_clip
    self.K_epochs = K_epochs
    self.batch_size = batch_size

    states = T.Placeholder((batch_size,) + state_shape, "float32", name="states")
    actions = T.Placeholder(
        (batch_size,) + actions_shape, "float32", name="actions"
    )
    rewards = T.Placeholder((batch_size,), "float32", name="discounted_rewards")
    advantages = T.Placeholder((batch_size,), "float32", name="advantages")

    self.target_actor = actor(states, distribution="gaussian")
    self.actor = actor(states, distribution="gaussian")
    self.critic = critic(states)

    # finding the ratio (pi_theta / pi_theta_old) and the
    # surrogate loss, https://arxiv.org/pdf/1707.06347.pdf
    with symjax.Scope("policy_loss"):
        ratios = T.exp(
            self.actor.actions.log_prob(actions)
            - self.target_actor.actions.log_prob(actions)
        )
        ratios = T.clip(ratios, 0, 10)
        clipped_ratios = T.clip(ratios, 1 - self.eps_clip, 1 + self.eps_clip)

        surr1 = advantages * ratios
        surr2 = advantages * clipped_ratios

        actor_loss = -(T.minimum(surr1, surr2)).mean()

    with symjax.Scope("monitor"):
        clipfrac = (
            ((ratios > (1 + self.eps_clip)) | (ratios < (1 - self.eps_clip)))
            .astype("float32")
            .mean()
        )
        approx_kl = (
            self.target_actor.actions.log_prob(actions)
            - self.actor.actions.log_prob(actions)
        ).mean()

    with symjax.Scope("critic_loss"):
        critic_loss = T.mean((rewards - self.critic.q_values) ** 2)

    with symjax.Scope("entropy"):
        entropy = self.actor.actions.entropy().mean()

    loss = actor_loss + critic_loss  # - entropy_beta * entropy

    with symjax.Scope("optimizer"):
        nn.optimizers.Adam(
            loss,
            lr,
            params=self.actor.params(True) + self.critic.params(True),
        )

    # create the update function
    self._train = symjax.function(
        states,
        actions,
        rewards,
        advantages,
        outputs=[actor_loss, critic_loss, clipfrac, approx_kl],
        updates=symjax.get_updates(scope="*optimizer"),
    )

    # initialize target as current
    self.update_target(1)
def __init__(
    self,
    env,
    actor,
    critic,
    lr=1e-4,
    batch_size=32,
    n=1,
    clip_ratio=0.2,
    entcoeff=0.01,
    target_kl=0.01,
    train_pi_iters=4,
    train_v_iters=4,
):
    num_states = env.observation_space.shape[0]
    continuous = type(env.action_space) == gym.spaces.box.Box

    if continuous:
        num_actions = env.action_space.shape[0]
        action_min = env.action_space.low
        action_max = env.action_space.high
    else:
        num_actions = env.action_space.n
        action_max = 1

    self.batch_size = batch_size
    self.num_states = num_states
    self.num_actions = num_actions
    self.state_dim = (batch_size, num_states)
    self.action_dim = (batch_size, num_actions)
    self.continuous = continuous
    self.lr = lr
    self.train_pi_iters = train_pi_iters
    self.train_v_iters = train_v_iters
    self.clip_ratio = clip_ratio
    self.target_kl = target_kl
    self.extras = {"logprob": ()}
    self.entcoeff = entcoeff

    state_ph = T.Placeholder((batch_size, num_states), "float32")
    ret_ph = T.Placeholder((batch_size,), "float32")
    adv_ph = T.Placeholder((batch_size,), "float32")
    act_ph = T.Placeholder((batch_size, num_actions), "float32")

    with symjax.Scope("actor"):
        logits = actor(state_ph)
        if continuous:
            logstd = T.Variable(
                -0.5 * np.ones(num_actions, dtype=np.float32),
                name="logstd",
            )

    with symjax.Scope("old_actor"):
        old_logits = actor(state_ph)
        if continuous:
            old_logstd = T.Variable(
                -0.5 * np.ones(num_actions, dtype=np.float32),
                name="logstd",
            )

    if not continuous:
        pi = Categorical(logits=logits)
    else:
        pi = MultivariateNormal(mean=logits, diag_log_std=logstd)

    actions = T.clip(pi.sample(), -2, 2)  # pi

    # one variable per parameter set
    actor_params = symjax.get_variables(scope="/actor/")
    old_actor_params = symjax.get_variables(scope="/old_actor/")

    self.update_target = symjax.function(
        updates={o: a for o, a in zip(old_actor_params, actor_params)}
    )

    # PPO objectives: pi(a|s) / pi_old(a|s)
    pi_log_prob = pi.log_prob(act_ph)
    old_pi_log_prob = pi_log_prob.clone({logits: old_logits, logstd: old_logstd})

    ratio = T.exp(pi_log_prob - old_pi_log_prob)
    surr1 = ratio * adv_ph
    surr2 = adv_ph * T.clip(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)

    pi_loss = -T.minimum(surr1, surr2).mean()
    # ent_loss = pi.entropy().mean() * self.entcoeff

    with symjax.Scope("critic"):
        v = critic(state_ph)

    # critic loss
    v_loss = ((ret_ph - v) ** 2).mean()

    # Info (useful to watch during learning):
    # a sample estimate for KL
    approx_kl = (old_pi_log_prob - pi_log_prob).mean()
    # a sample estimate for entropy
    # approx_ent = -logprob_given_actions.mean()
    # clipped = T.logical_or(
    #     ratio > (1 + clip_ratio), ratio < (1 - clip_ratio)
    # )
    # clipfrac = clipped.astype("float32").mean()

    with symjax.Scope("update_pi"):
        print(len(actor_params), "actor parameters")
        nn.optimizers.Adam(
            pi_loss,
            self.lr,
            params=actor_params,
        )
    with symjax.Scope("update_v"):
        critic_params = symjax.get_variables(scope="/critic/")
        print(len(critic_params), "critic parameters")
        nn.optimizers.Adam(
            v_loss,
            self.lr,
            params=critic_params,
        )

    self.get_params = symjax.function(outputs=critic_params)

    self.learn_pi = symjax.function(
        state_ph,
        act_ph,
        adv_ph,
        outputs=[pi_loss, approx_kl],
        updates=symjax.get_updates(scope="/update_pi/"),
    )
    self.learn_v = symjax.function(
        state_ph,
        ret_ph,
        outputs=v_loss,
        updates=symjax.get_updates(scope="/update_v/"),
    )

    single_state = T.Placeholder((1, num_states), "float32")
    single_v = v.clone({state_ph: single_state})
    single_sample = actions.clone({state_ph: single_state})
    self._act = symjax.function(single_state, outputs=single_sample)
    self._get_v = symjax.function(single_state, outputs=single_v)

    single_action = T.Placeholder((1, num_actions), "float32")
def __init__(
    self,
    env_fn,
    actor,
    critic,
    gamma=0.99,
    tau=0.01,
    lr=1e-3,
    batch_size=32,
    epsilon=0.1,
    epsilon_decay=1 / 1000,
    min_epsilon=0.01,
    reward=None,
):
    # comment out this line if you don't want to record a video of the agent
    # if save_folder is not None:
    #     test_env = gym.wrappers.Monitor(test_env)

    # instantiate the environment, then get the size of the state and action spaces
    env = env_fn()
    num_states = env.observation_space.shape[0]
    continuous = type(env.action_space) == gym.spaces.box.Box

    if continuous:
        num_actions = env.action_space.shape[0]
        action_max = env.action_space.high[0]
    else:
        num_actions = env.action_space.n
        action_max = 1

    self.batch_size = batch_size
    self.num_states = num_states
    self.num_actions = num_actions
    self.state_dim = (batch_size, num_states)
    self.action_dim = (batch_size, num_actions)
    self.gamma = gamma
    self.continuous = continuous
    self.observ_min = np.clip(env.observation_space.low, -20, 20)
    self.observ_max = np.clip(env.observation_space.high, -20, 20)
    self.env = env
    self.reward = reward

    # state
    state = T.Placeholder((batch_size, num_states), "float32")
    gradients = T.Placeholder((batch_size, num_actions), "float32")
    action = T.Placeholder((batch_size, num_actions), "float32")
    target = T.Placeholder((batch_size, 1), "float32")

    with symjax.Scope("actor_critic"):
        scaled_out = action_max * actor(state)
        Q = critic(state, action)

    a_loss = -T.sum(gradients * scaled_out)
    q_loss = T.mean((Q - target) ** 2)

    nn.optimizers.Adam(a_loss + q_loss, lr)

    self.update = symjax.function(
        state,
        action,
        target,
        gradients,
        outputs=[a_loss, q_loss],
        updates=symjax.get_updates(),
    )

    g = symjax.gradients(T.mean(Q), [action])[0]
    self.get_gradients = symjax.function(state, action, outputs=g)

    # also create the target variants
    with symjax.Scope("actor_critic_target"):
        scaled_out_target = action_max * actor(state)
        Q_target = critic(state, action)

    self.actor_predict = symjax.function(state, outputs=scaled_out)
    self.actor_predict_target = symjax.function(state, outputs=scaled_out_target)
    self.critic_predict = symjax.function(state, action, outputs=Q)
    self.critic_predict_target = symjax.function(state, action, outputs=Q_target)

    t_params = symjax.get_variables(scope="/actor_critic_target/*")
    params = symjax.get_variables(scope="/actor_critic/*")

    replacement = {
        t: tau * e + (1 - tau) * t for t, e in zip(t_params, params)
    }
    self.update_target = symjax.function(updates=replacement)

    single_state = T.Placeholder((1, num_states), "float32")
    if not continuous:
        # discrete case: act greedily with respect to the actor output
        scaled_out = scaled_out.argmax(-1)

    self.act = symjax.function(
        single_state, outputs=scaled_out.clone({state: single_state})[0]
    )
"""
demonstration on how to compute a gradient and apply a basic gradient update
rule to minimize some loss function
"""

import symjax
import symjax.tensor as T
import matplotlib.pyplot as plt

# GRADIENT DESCENT
z = T.Variable(3.0, dtype="float32")
loss = (z - 1) ** 2
g_z = symjax.gradients(loss, [z])[0]
symjax.current_graph().add_updates({z: z - 0.1 * g_z})

train = symjax.function(outputs=[loss, z], updates=symjax.get_updates())

losses = list()
values = list()
for i in range(200):
    if (i + 1) % 50 == 0:
        symjax.reset_variables("*")
    a, b = train()
    losses.append(a)
    values.append(b)

plt.figure()
plt.subplot(121)
plt.plot(losses, "-x")
plt.ylabel("loss")
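The same minimization can be written with the optimizer API used by the other snippets in this collection; a sketch, assuming symjax.nn.optimizers.SGD registers its updates in the current graph exactly like the manual add_updates call above:

import symjax.nn as nn

symjax.current_graph().reset()  # fresh graph so only the SGD updates are registered
z2 = T.Variable(3.0, dtype="float32")
loss2 = (z2 - 1) ** 2
nn.optimizers.SGD(loss2, 0.1)  # registers z2 <- z2 - 0.1 * d(loss2)/dz2
train2 = symjax.function(outputs=[loss2, z2], updates=symjax.get_updates())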
def __init__(
    self,
    state_shape,
    actions_shape,
    batch_size,
    actor,
    critic,
    lr=1e-3,
    gamma=0.99,
    tau=0.01,
):
    self.gamma = gamma
    self.tau = tau
    self.lr = lr
    self.batch_size = batch_size

    states = T.Placeholder((batch_size,) + state_shape, "float32")
    actions = T.Placeholder((batch_size,) + actions_shape, "float32")

    self.critic = critic(states, actions)
    self.target_critic = critic(states, actions)

    # create critic loss
    targets = T.Placeholder(self.critic.q_values.shape, "float32")
    critic_loss = ((self.critic.q_values - targets) ** 2).mean()

    # create optimizer
    with symjax.Scope("critic_optimizer"):
        nn.optimizers.Adam(critic_loss, lr, params=self.critic.params(True))

    # create the update function
    self._train_critic = symjax.function(
        states,
        actions,
        targets,
        outputs=critic_loss,
        updates=symjax.get_updates(scope="*/critic_optimizer"),
    )

    # now create utility function to get the gradients
    grad = symjax.gradients(self.critic.q_values.sum(), actions)
    self._get_critic_gradients = symjax.function(states, actions, outputs=grad)

    # create actor loss
    self.actor = actor(states)
    self.target_actor = actor(states)
    gradients = T.Placeholder(actions.shape, "float32")
    actor_loss = -(self.actor.actions * gradients).mean()

    # create optimizer
    with symjax.Scope("actor_optimizer"):
        nn.optimizers.Adam(actor_loss, lr, params=self.actor.params(True))

    # create the update function
    self._train_actor = symjax.function(
        states,
        gradients,
        outputs=actor_loss,
        updates=symjax.get_updates(scope="*/actor_optimizer"),
    )

    # initialize both networks as the same
    self.update_target(1)