def create_vae(batch_size, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1):
    """Build the symbolic graph and compiled functions of a VAE.

    The decoder is a leaky-ReLU MLP whose weights come from
    ``init_weights(Ds, seed, scaler)``; the encoder is whatever
    ``encoder(x, Ds[0])`` builds (both helpers, as well as ``relu_mask``,
    ``cov_W`` and ``cov_b``, are defined outside this function).

    Args:
        batch_size: number of rows of every placeholder / latent draw.
        Ds: layer widths; ``Ds[0]`` is the latent width and ``Ds[-1]`` the
            data width.
        seed: forwarded to ``init_weights``.
        leakiness: forwarded to ``relu_mask`` for the hidden activations.
        lr: learning rate given to Adam.
        scaler: forwarded to ``init_weights``.

    Returns:
        dict with the compiled functions ``'train'`` (one Adam step, returns
        the loss), ``'g'`` (decoder applied to a latent batch), ``'params'``,
        ``'varx'``, the callable ``'sample'``, the tag ``'model'`` and the
        construction arguments under ``'kwargs'``.
    """
    x = T.Placeholder([batch_size, Ds[-1]], 'float32')

    # ENCODER
    # The last encoder output is split column-wise: the first Ds[0] columns
    # are used as the mean, the remaining ones as the log-variance.
    enc = encoder(x, Ds[0])
    mu = enc[-1][:, :Ds[0]]
    logvar = enc[-1][:, Ds[0]:]
    var = T.exp(logvar)
    # reparametrisation: z = mu + std * eps
    z = mu + T.exp(0.5 * logvar) * T.random.randn((batch_size, Ds[0]))
    z_ph = T.Placeholder((batch_size, Ds[0]), 'float32')

    # DECODER
    Ws, bs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w) for w in Ws]
    bs = [T.Variable(b) for b in bs]
    logvar_x = T.Variable(T.zeros(1), name='logvar_x')
    var_x = T.exp(logvar_x)

    # Two parallel forward passes sharing the same weights: ``h`` starts from
    # the encoder sample (training), ``h_ph`` from a user-fed latent batch.
    h, h_ph = [z], [z_ph]
    for w, b in zip(Ws[:-1], bs[:-1]):
        h.append(T.matmul(h[-1], w.transpose()) + b)
        h.append(h[-1] * relu_mask(h[-1], leakiness))
        h_ph.append(T.matmul(h_ph[-1], w.transpose()) + b)
        h_ph.append(h_ph[-1] * relu_mask(h_ph[-1], leakiness))
    # the output layer is affine only
    h.append(T.matmul(h[-1], Ws[-1].transpose()) + bs[-1])
    h_ph.append(T.matmul(h_ph[-1], Ws[-1].transpose()) + bs[-1])

    # quadratic penalty on all weights and on every bias but the last one
    prior = sum([T.mean(w**2) for w in Ws], 0.) / cov_W\
        + sum([T.mean(v**2) for v in bs[:-1]], 0.) / cov_b
    # ``kl`` holds the *negative* KL term, hence ``px + kl`` below
    kl = 0.5 * (1 + logvar - var - mu ** 2).sum(1)
    px = - 0.5 * (logvar_x + ((x - h[-1])**2 / var_x)).sum(1)
    loss = - (px + kl).mean() + prior

    variables = Ws + bs + sj.layers.get_variables(enc) + [logvar_x]
    opti = sj.optimizers.Adam(loss, lr, params=variables)

    train = sj.function(x, outputs=loss, updates=opti.updates)
    g = sj.function(z_ph, outputs=h_ph[-1])
    params = sj.function(
        outputs=Ws + bs + [T.exp(logvar_x) * T.ones(Ds[-1])])
    get_varx = sj.function(outputs=var_x)

    output = {'train': train, 'g': g, 'params': params}
    output['model'] = 'VAE'
    output['varx'] = get_varx
    output['kwargs'] = {'batch_size': batch_size, 'Ds': Ds, 'seed': seed,
                        'leakiness': leakiness, 'lr': lr, 'scaler': scaler,
                        'prior': sj.function(outputs=prior)}

    def sample(n):
        """Return exactly ``n`` decoder outputs for standard-normal latents.

        ``g`` only accepts batches of ``batch_size`` rows, so enough full
        batches are generated to cover ``n`` and the surplus is dropped.
        """
        if n <= 0:
            return np.empty((0, Ds[-1]), dtype='float32')
        n_batches = -(-n // batch_size)  # ceil(n / batch_size)
        samples = [g(np.random.randn(batch_size, Ds[0]))
                   for _ in range(n_batches)]
        return np.concatenate(samples)[:n]

    output['sample'] = sample
    return output
def create_glo(batch_size, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1, GLO=False):
    """Build a decoder-only model whose latent codes are free variables.

    The decoder is a leaky-ReLU MLP initialised by
    ``init_weights(Ds, seed, scaler)``; ``relu_mask``, ``cov_W`` and
    ``cov_b`` are defined outside this function.

    Args:
        batch_size: number of rows of the data placeholder and of ``z``.
        Ds: layer widths; ``Ds[0]`` is the latent width, ``Ds[-1]`` the data
            width.
        seed: forwarded to ``init_weights``.
        leakiness: forwarded to ``relu_mask``.
        lr: learning rate of both Adam optimizers.
        scaler: forwarded to ``init_weights``.
        GLO: if True use the plain squared reconstruction error; otherwise
            use the variance-weighted error plus a squared-norm penalty on
            ``z``.

    Returns:
        dict with compiled functions ``'train'`` (Adam step on the decoder
        parameters), ``'estimate'`` (Adam step on ``z``), ``'params'``,
        ``'loss'``, the callable ``'reset'`` (assigns a new value to ``z``),
        the tag ``'model'`` (``'GLO'`` or ``'HARD'``) and ``'kwargs'``.
    """
    x = T.Placeholder([batch_size, Ds[-1]], 'float32')
    z = T.Variable(T.random.randn((batch_size, Ds[0])))
    logvar_x = T.Variable(T.ones(1))

    # DECODER
    Ws, bs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w) for w in Ws]
    bs = [T.Variable(b) for b in bs]

    h = [z]
    for w, b in zip(Ws[:-1], bs[:-1]):
        h.append(T.matmul(h[-1], w.transpose()) + b)
        h.append(h[-1] * relu_mask(h[-1], leakiness))
    # the output layer is affine only
    h.append(T.matmul(h[-1], Ws[-1].transpose()) + bs[-1])

    # LOSS
    # quadratic penalty on all weights and on every bias but the last one
    prior = sum([T.sum(w**2) for w in Ws], 0.) / cov_W\
        + sum([T.sum(v**2) for v in bs[:-1]], 0.) / cov_b
    if GLO:
        loss = T.sum((x - h[-1])**2) / batch_size + prior
    else:
        loss = (Ds[-1] * logvar_x.sum()
                + T.sum((x - h[-1])**2 / T.exp(logvar_x)) / batch_size
                + (z**2).sum() / batch_size
                + prior)

    # NOTE(review): logvar_x is not part of ``variables`` so it is never
    # updated, although it appears in the non-GLO loss — confirm intent.
    variables = Ws + bs

    # ``loss`` already contains the prior; it used to be recomputed and added
    # a second time here, so the optimised objective did not match the value
    # reported by ``train``.
    opti = sj.optimizers.Adam(loss, lr, params=variables)
    infer = sj.optimizers.Adam(loss, lr, params=[z])

    estimate = sj.function(x, outputs=z, updates=infer.updates)
    train = sj.function(x, outputs=loss, updates=opti.updates)
    lossf = sj.function(x, outputs=loss)
    params = sj.function(
        outputs=Ws + bs + [T.ones(Ds[-1]) * T.exp(logvar_x)])

    output = {'train': train, 'estimate': estimate, 'params': params}
    output['reset'] = lambda v: z.assign(v)
    output['model'] = 'GLO' if GLO else 'HARD'
    output['loss'] = lossf
    output['kwargs'] = {'batch_size': batch_size, 'Ds': Ds, 'seed': seed,
                        'leakiness': leakiness, 'lr': lr, 'scaler': scaler}
    return output
def softmax(x, axis=-1):
    r"""Softmax function.

    Maps the entries of :code:`x` to the range :math:`[0, 1]` so that they
    sum to :math:`1` along :code:`axis`.

    .. math ::
        \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    Args:
        x: the input tensor.
        axis: the axis or axes along which the softmax should be computed.
            The output summed across these dimensions is :math:`1`. Either
            an integer or a tuple of integers.

    Returns:
        A tensor with the same shape as :code:`x`.
    """
    # Shift by the (gradient-free) maximum so the exponentials cannot
    # overflow; the shift cancels in the ratio.
    peak = T.stop_gradient(x.max(axis, keepdims=True))
    exps = T.exp(x - peak)
    total = exps.sum(axis, keepdims=True)
    return exps / total
def log_softmax(x, axis=-1):
    r"""Log-Softmax function.

    Logarithm of the :code:`softmax` function; the result lies in
    :math:`[-\infty, 0)`.

    .. math ::
        \mathrm{log\_softmax}(x) = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
        \right)

    Args:
        x: the input tensor.
        axis: the axis or axes along which the :code:`log_softmax` should be
            computed. Either an integer or a tuple of integers.

    Returns:
        A tensor with the same shape as :code:`x`.
    """
    # Subtract the (gradient-free) maximum before exponentiating so the
    # log-sum-exp stays finite.
    peak = T.stop_gradient(x.max(axis, keepdims=True))
    centered = x - peak
    log_norm = T.log(T.sum(T.exp(centered), axis, keepdims=True))
    return centered - log_norm
def generate_gaussian_filterbank(N, M, J, f0, f1, modes=1):
    """Build a bank of 2-D Gaussian filters with learnable parameters.

    Each of the ``J * modes`` Gaussians has a learnable 2-D mean ``mu``, a
    per-axis scale ``sigma`` and an off-diagonal term ``cor``. They are
    evaluated on an ``N x M`` grid, L2-normalised, weighted by ``|mixing|``
    and summed over the mode axis.

    Args:
        N: number of grid points along the ``freq`` axis
            (``T.linspace(0, J * 10, N)``).
        M: number of grid points along the ``time`` axis
            (``T.linspace(-5, 5, M)``).
        J: number of filters returned.
        f0: forwarded to ``get_scaled_freqs`` (presumably the lowest
            frequency — confirm against that helper).
        f1: forwarded to ``get_scaled_freqs`` (presumably the highest
            frequency — confirm against that helper).
        modes: number of Gaussian components per filter.

    Returns:
        Tuple ``(filters, mu, cor, sigma, mixing)`` where ``filters`` is the
        ``(J, modes, N, M)`` tensor summed over axis 1 with ``keepdims=True``
        and the remaining items are the ``T.Variable`` parameters.
    """
    # gaussian parameters
    freqs = get_scaled_freqs(f0, f1, J)
    freqs *= (J - 1) * 10
    if modes > 1:
        # extra modes reuse randomly chosen centre frequencies
        other_modes = np.random.randint(0, J, J * (modes - 1))
        freqs = np.concatenate([freqs, freqs[other_modes]])
    # create the average vectors
    mu_init = np.stack([freqs, 0.1 * np.random.randn(J * modes)], 1)
    mu = T.Variable(mu_init.astype('float32'), name='mu')
    # create the covariance matrix
    cor = T.Variable(0.01 * np.random.randn(J * modes).astype('float32'), name='cor')
    sigma_init = np.stack([freqs / 6, 1. + 0.01 * np.random.randn(J * modes)], 1)
    sigma = T.Variable(sigma_init.astype('float32'), name='sigma')
    # create the mixing coefficients
    mixing = T.Variable(np.ones((modes, 1, 1)).astype('float32'))
    # now apply our parametrization
    # off-diagonal magnitude is capped at 0.95 * sqrt(prod(|sigma| + 0.1));
    # the cap itself receives no gradient
    coeff = T.stop_gradient(T.sqrt((T.abs(sigma) + 0.1).prod(1))) * 0.95
    Id = T.eye(2)
    # diagonal part from |sigma| + 0.1, anti-diagonal part from tanh(cor)
    cov = Id * T.expand_dims((T.abs(sigma)+0.1),1) +\
        T.flip(Id, 0) * (T.tanh(cor) * coeff).reshape((-1, 1, 1))
    cov_inv = T.linalg.inv(cov)
    # get the gaussian filters
    time = T.linspace(-5, 5, M)
    freq = T.linspace(0, J * 10, N)
    x, y = T.meshgrid(time, freq)
    grid = T.stack([y.flatten(), x.flatten()], 1)
    centered = grid - T.expand_dims(mu, 1)
    # exp(-||centered @ cov_inv||^2) at every grid point
    gaussian = T.exp(-(T.matmul(centered, cov_inv)**2).sum(-1))
    norm = T.linalg.norm(gaussian, 2, 1, keepdims=True)
    # NOTE(review): the parameters are laid out as [J base filters, then the
    # J * (modes - 1) extra ones] while this reshape reads them as
    # (J, modes, ...) — verify the grouping when modes > 1.
    gaussian_2d = T.abs(mixing) * T.reshape(gaussian / norm, (J, modes, N, M))
    return gaussian_2d.sum(1, keepdims=True), mu, cor, sigma, mixing
def create_fns(input, in_signs, Ds, x, m0, m1, m2, batch_in_signs, alpha=0.1, sigma=1, sigma_x=1, lr=0.0002):
    """Build a leaky-ReLU MLP and the compiled functions working on it.

    The network has layer widths ``Ds``. Besides the forward pass, the
    per-region affine parameters (``A``, ``B``) of the network are built
    three ways: from the activation pattern of ``input`` (``*_w``), from a
    given sign vector ``in_signs`` (``*_q``), and from a batch of sign
    vectors ``batch_in_signs`` (``batch_*_q``).

    Args:
        input: symbolic input of the forward pass.
        in_signs: symbolic sign vector of the hidden units (all hidden
            layers concatenated).
        Ds: layer widths, ``Ds[0]`` input width, ``Ds[-1]`` output width.
        x: symbolic data batch used by the loss.
        m0, m1, m2: symbolic tensors contracted with the batched region
            parameters in the loss (presumably per-region moments of the
            latent variable — TODO confirm against the caller).
        batch_in_signs: symbolic batch of sign vectors; its first dimension
            gives the number of regions ``BS``.
        alpha: slope applied where a unit is not positive.
        sigma: multiplier applied to the initial weights and hidden biases.
        sigma_x: initial value of the ``log_sigma2`` variable.
        lr: learning rate given to the optimizer.

    Returns:
        Tuple ``(f, g, h, all_g, train_f, sigma2)``: ``f`` maps ``input`` to
        ``[output, A, B, inequalities, signs]``; ``g`` maps ``in_signs`` to
        ``[A, B]``; ``h`` maps ``input`` to the output; ``all_g`` maps
        ``in_signs`` to the sign-based inequalities; ``train_f`` performs
        one optimizer step and returns the mean loss; ``sigma2`` is the
        symbolic ``exp(log_sigma2)``.
    """
    # offsets of each hidden layer inside the concatenated sign vector
    cumulative_units = np.concatenate([[0], np.cumsum(Ds[:-1])])
    BS = batch_in_signs.shape[0]
    Ws = [
        T.Variable(sj.initializers.glorot((j, i)) * sigma)
        for j, i in zip(Ds[1:], Ds[:-1])
    ]
    # hidden biases are He-initialised, the output bias starts at zero
    bs = [T.Variable(sj.initializers.he((j,)) * sigma) for j in Ds[1:-1]]\
        + [T.Variable(T.zeros((Ds[-1],)))]
    # running affine maps, all starting from the identity on the input
    A_w = [T.eye(Ds[0])]
    B_w = [T.zeros(Ds[0])]
    A_q = [T.eye(Ds[0])]
    B_q = [T.zeros(Ds[0])]
    batch_A_q = [T.eye(Ds[0]) * T.ones((BS, 1, 1))]
    batch_B_q = [T.zeros((BS, Ds[0]))]
    maps = [input]
    signs = []
    masks = [T.ones(Ds[0])]
    # slopes (1 or alpha) derived from the given signs; the leading ones
    # stand for the input "layer", which is never masked
    in_masks = T.where(T.concatenate([T.ones(Ds[0]), in_signs]) > 0, 1., alpha)
    batch_in_masks = T.where(
        T.concatenate([T.ones((BS, Ds[0])), batch_in_signs], 1) > 0, 1., alpha)
    # forward pass through the hidden layers, recording signs and masks
    for w, b in zip(Ws[:-1], bs[:-1]):
        pre_activation = T.matmul(w, maps[-1]) + b
        signs.append(T.sign(pre_activation))
        masks.append(T.where(pre_activation > 0, 1., alpha))
        maps.append(pre_activation * masks[-1])
    # the output layer is affine only
    maps.append(T.matmul(Ws[-1], maps[-1]) + bs[-1])
    # compute per region A and B
    for start, end, w, b, m in zip(cumulative_units[:-1], cumulative_units[1:], Ws, bs, masks):
        A_w.append(T.matmul(w * m, A_w[-1]))
        B_w.append(T.matmul(w * m, B_w[-1]) + b)
        A_q.append(T.matmul(w * in_masks[start:end], A_q[-1]))
        B_q.append(T.matmul(w * in_masks[start:end], B_q[-1]) + b)
        batch_A_q.append(
            T.matmul(w * batch_in_masks[:, None, start:end], batch_A_q[-1]))
        batch_B_q.append((w * batch_in_masks[:, None, start:end]\
            * batch_B_q[-1][:, None, :]).sum(2) + b)
    # only the full-network maps are needed from here on
    batch_B_q = batch_B_q[-1]
    batch_A_q = batch_A_q[-1]
    signs = T.concatenate(signs)
    # one row [B | A] per hidden unit, oriented by the unit's sign
    inequalities = T.hstack(
        [T.concatenate(B_w[1:-1])[:, None], T.vstack(A_w[1:-1])]) * signs[:, None]
    inequalities_code = T.hstack(
        [T.concatenate(B_q[1:-1])[:, None], T.vstack(A_q[1:-1])]) * in_signs[:, None]
    #### loss
    # sigma_x is used as the initial value of the *log* variance
    log_sigma2 = T.Variable(sigma_x)
    sigma2 = T.exp(log_sigma2)
    Am1 = T.einsum('qds,nqs->nqd', batch_A_q, m1)
    Bm0 = T.einsum('qd,nq->nd', batch_B_q, m0)
    B2m0 = T.einsum('nq,qd->n', m0, batch_B_q**2)
    AAm2 = T.einsum('qds,qdu,nqup->nsp', batch_A_q, batch_A_q, m2)
    inner = -(x * (Am1.sum(1) + Bm0)).sum(1) + (Am1 * batch_B_q).sum((1, 2))
    loss_2 = (x**2).sum(1) + B2m0 + T.trace(AAm2, axis1=1, axis2=2).squeeze()
    loss_z = T.trace(m2.sum(1), axis1=1, axis2=2).squeeze()
    cst = 0.5 * (Ds[0] + Ds[-1]) * T.log(2 * np.pi)
    loss = cst + 0.5 * Ds[-1] * log_sigma2 + inner / sigma2\
        + 0.5 * loss_2 / sigma2 + 0.5 * loss_z
    mean_loss = loss.mean()
    # NOTE: despite its name this is a Nesterov-momentum optimizer, and it
    # only updates Ws and bs (log_sigma2 is not in the parameter list)
    adam = sj.optimizers.NesterovMomentum(mean_loss, Ws + bs, lr, 0.9)
    train_f = sj.function(batch_in_signs, x, m0, m1, m2, outputs=mean_loss, updates=adam.updates)
    f = sj.function(input, outputs=[maps[-1], A_w[-1], B_w[-1], inequalities, signs])
    g = sj.function(in_signs, outputs=[A_q[-1], B_q[-1]])
    all_g = sj.function(in_signs, outputs=inequalities_code)
    h = sj.function(input, outputs=maps[-1])
    return f, g, h, all_g, train_f, sigma2
def __init__(
    self,
    env,
    actor,
    critic,
    lr=1e-4,
    batch_size=32,
    n=1,
    clip_ratio=0.2,
    entcoeff=0.01,
    target_kl=0.01,
    train_pi_iters=4,
    train_v_iters=4,
):
    """Build the PPO graph (actor, frozen old actor, critic) for ``env``.

    Args:
        env: environment exposing ``observation_space`` and
            ``action_space``; a ``gym.spaces.box.Box`` action space selects
            the continuous (Gaussian) policy, anything else the categorical
            one.
        actor: callable mapping the state placeholder to the policy output
            (mean or logits). It is called twice, once per scope.
        critic: callable mapping the state placeholder to the value.
        lr: learning rate of both Adam optimizers.
        batch_size: leading dimension of every training placeholder.
        n: accepted but not used in this method.
        clip_ratio: PPO clipping range around a ratio of 1.
        entcoeff: stored on ``self``; the entropy term is currently
            commented out of the loss.
        target_kl: stored on ``self``.
        train_pi_iters: stored on ``self``.
        train_v_iters: stored on ``self``.
    """
    num_states = env.observation_space.shape[0]
    continuous = isinstance(env.action_space, gym.spaces.box.Box)
    if continuous:
        num_actions = env.action_space.shape[0]
    else:
        num_actions = env.action_space.n

    self.batch_size = batch_size
    self.num_states = num_states
    self.num_actions = num_actions
    self.state_dim = (batch_size, num_states)
    self.action_dim = (batch_size, num_actions)
    self.continuous = continuous
    self.lr = lr
    self.train_pi_iters = train_pi_iters
    self.train_v_iters = train_v_iters
    self.clip_ratio = clip_ratio
    self.target_kl = target_kl
    self.extras = {"logprob": ()}
    self.entcoeff = entcoeff

    state_ph = T.Placeholder((batch_size, num_states), "float32")
    ret_ph = T.Placeholder((batch_size, ), "float32")
    adv_ph = T.Placeholder((batch_size, ), "float32")
    act_ph = T.Placeholder((batch_size, num_actions), "float32")

    with symjax.Scope("actor"):
        logits = actor(state_ph)
        if continuous:
            logstd = T.Variable(
                -0.5 * np.ones(num_actions, dtype=np.float32),
                name="logstd",
            )
    # Same architecture under a second scope: a snapshot of the policy that
    # is only changed through ``self.update_target``.
    with symjax.Scope("old_actor"):
        old_logits = actor(state_ph)
        if continuous:
            old_logstd = T.Variable(
                -0.5 * np.ones(num_actions, dtype=np.float32),
                name="logstd",
            )

    if not continuous:
        pi = Categorical(logits=logits)
    else:
        pi = MultivariateNormal(mean=logits, diag_log_std=logstd)

    # NOTE(review): the [-2, 2] range is hard-coded rather than taken from
    # env.action_space — confirm it suits every environment used.
    actions = T.clip(pi.sample(), -2, 2)  # pi

    # Each list must come from its own scope. Previously ``actor_params`` was
    # overwritten with the old actor's variables, which made the target sync
    # a no-op and pointed the policy optimizer at the wrong variables.
    actor_params = symjax.get_variables(scope="/actor/")
    old_actor_params = symjax.get_variables(scope="/old_actor/")

    self.update_target = symjax.function(
        updates=dict(zip(old_actor_params, actor_params)))

    # PPO objectives
    # pi(a|s) / pi_old(a|s)
    pi_log_prob = pi.log_prob(act_ph)
    # ``logstd`` only exists for continuous action spaces
    replacements = {logits: old_logits}
    if continuous:
        replacements[logstd] = old_logstd
    old_pi_log_prob = pi_log_prob.clone(replacements)

    ratio = T.exp(pi_log_prob - old_pi_log_prob)
    surr1 = ratio * adv_ph
    surr2 = adv_ph * T.clip(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)

    pi_loss = -T.minimum(surr1, surr2).mean()
    # ent_loss = pi.entropy().mean() * self.entcoeff

    with symjax.Scope("critic"):
        v = critic(state_ph)
    # critic loss
    v_loss = ((ret_ph - v)**2).mean()

    # Info (useful to watch during learning)
    # a sample estimate for KL
    approx_kl = (old_pi_log_prob - pi_log_prob).mean()
    # a sample estimate for entropy
    # approx_ent = -logprob_given_actions.mean()
    # clipped = T.logical_or(
    #     ratio > (1 + clip_ratio), ratio < (1 - clip_ratio)
    # )
    # clipfrac = clipped.astype("float32").mean()

    with symjax.Scope("update_pi"):
        print(len(actor_params), "actor parameters")
        nn.optimizers.Adam(
            pi_loss,
            self.lr,
            params=actor_params,
        )
    with symjax.Scope("update_v"):
        critic_params = symjax.get_variables(scope="/critic/")
        print(len(critic_params), "critic parameters")
        nn.optimizers.Adam(
            v_loss,
            self.lr,
            params=critic_params,
        )

    self.get_params = symjax.function(outputs=critic_params)

    self.learn_pi = symjax.function(
        state_ph,
        act_ph,
        adv_ph,
        outputs=[pi_loss, approx_kl],
        updates=symjax.get_updates(scope="/update_pi/"),
    )
    self.learn_v = symjax.function(
        state_ph,
        ret_ph,
        outputs=v_loss,
        updates=symjax.get_updates(scope="/update_v/"),
    )

    # single-state versions of the sampling and value graphs, for acting
    single_state = T.Placeholder((1, num_states), "float32")
    single_v = v.clone({state_ph: single_state})
    single_sample = actions.clone({state_ph: single_state})

    self._act = symjax.function(single_state, outputs=single_sample)
    self._get_v = symjax.function(single_state, outputs=single_v)

    # currently not consumed by any compiled function
    single_action = T.Placeholder((1, num_actions), "float32")
def __init__(
    self,
    state_dim,
    action_dim,
    lr,
    gamma,
    K_epochs,
    eps_clip,
    actor,
    critic,
    batch_size,
    continuous=True,
):
    """Build the policy/value graph and compile ``_act`` and ``learn``.

    Args:
        state_dim: tuple appended to ``(batch_size,)`` to shape the state
            placeholder.
        action_dim: width of the continuous action placeholder.
        lr: learning rate given to Adam.
        gamma: stored on ``self`` (not used in this method).
        K_epochs: stored on ``self`` (not used in this method).
        eps_clip: the probability ratio is clipped from above at
            ``1 + eps_clip``.
        actor: callable mapping the state placeholder to the policy output.
            In the continuous case the first half of its columns is the
            (pre-tanh) mean and the second half the log standard deviation.
        critic: callable mapping the state placeholder to the state value.
        batch_size: leading dimension of every placeholder.
        continuous: selects a diagonal Gaussian policy (True) or a
            categorical one (False).
    """
    self.lr = lr
    self.gamma = gamma
    self.eps_clip = eps_clip
    self.K_epochs = K_epochs
    self.batch_size = batch_size

    state = T.Placeholder((batch_size, ) + state_dim, "float32")
    reward = T.Placeholder((batch_size, ), "float32")
    old_action_logprobs = T.Placeholder((batch_size, ), "float32")

    logits = actor(state)

    if not continuous:
        given_action = T.Placeholder((batch_size, ), "int32")
        dist = Categorical(logits=logits)
    else:
        half = logits.shape[1] // 2
        mean = T.tanh(logits[:, :half])
        std = T.exp(logits[:, half:])
        given_action = T.Placeholder((batch_size, action_dim), "float32")
        dist = MultivariateNormal(mean=mean, diag_std=std)

    # acting: draw an action and report its log-probability
    sample = dist.sample()
    sample_logprobs = dist.log_prob(sample)

    self._act = symjax.function(state, outputs=[sample, sample_logprobs])

    # learning: log-probability, under the current policy, of the action
    # that was actually taken
    given_action_logprobs = dist.log_prob(given_action)

    # Finding the ratio (pi_theta / pi_theta__old):
    # The ratio must be evaluated on the stored action. It previously used
    # ``sample_logprobs`` (a freshly drawn action), which left the
    # ``given_action`` input of ``learn`` without any effect on the loss.
    ratios = T.exp(given_action_logprobs - old_action_logprobs)
    ratios = T.clip(ratios, None, 1 + self.eps_clip)

    state_value = critic(state)
    # NOTE(review): assumes critic(state) has shape (batch_size,) like
    # ``reward``; a (batch_size, 1) output would broadcast — confirm.
    advantages = reward - T.stop_gradient(state_value)

    loss = (-T.mean(ratios * advantages) + 0.5 * T.mean(
        (state_value - reward)**2) - 0.0 * dist.entropy().mean())

    print(loss)

    nn.optimizers.Adam(loss, self.lr)

    self.learn = symjax.function(
        state,
        given_action,
        reward,
        old_action_logprobs,
        outputs=T.mean(loss),
        updates=symjax.get_updates(),
    )
import sys

sys.path.insert(0, "../")

import symjax
import symjax.tensor as T

# Example: run Adam with a piecewise-constant learning rate several times,
# resetting the optimizer state, the variable and the schedule between runs
# so that every run prints the same sequence of costs.

# create our variable to be optimized
mu = T.Variable(T.random.normal((1,), seed=1))
cost = T.exp(-((mu - 1) ** 2))

# learning rate: 0.01 at first, then the values given for steps 100 and 150
lr = symjax.schedules.PiecewiseConstant(0.01, {100: 0.003, 150: 0.001})
opt = symjax.optimizers.Adam(cost, lr, params=[mu])
print(opt.updates)

f = symjax.function(outputs=cost, updates=opt.updates)

for _run in range(4):
    for _step in range(10):
        print(f())
    print("done")

    # restore everything that the updates have modified
    to_reset = opt.variables + [mu]
    for variable in to_reset:
        variable.reset()
    lr.reset()

# sample output:
# 0.008471076
# 0.008201109
# 0.007946267
# 0.007705368
# 0.0074773384
# 0.007261208
# 0.0070561105
# 0.006861261
# 0.006675923
import sys

sys.path.insert(0, "../")

import symjax
import symjax.tensor as T
import numpy as np
import matplotlib

# Select the non-interactive backend before pyplot is imported so that the
# choice is in effect from the first pyplot use.
matplotlib.use("Agg")

import matplotlib.pyplot as plt

###### DERIVATIVE OF GAUSSIAN EXAMPLE

t = T.Placeholder((1000, ), "float32")
print(t)

# Gaussian bump and its first three derivatives; each derivative is obtained
# by differentiating the sum of the previous one with respect to ``t``.
f = T.exp(-(t**2))
u = f.sum()
g = symjax.gradients(u, [t])
g2 = symjax.gradients(g[0].sum(), [t])
g3 = symjax.gradients(g2[0].sum(), [t])

dog = symjax.function(t, outputs=[g[0], g2[0], g3[0]])

# one curve per derivative order
plt.plot(np.array(dog(np.linspace(-10, 10, 1000))).T)

###### GRADIENT DESCENT
z = T.Variable(3.0)
loss = z**2
g_z = symjax.gradients(loss, [z])
print(loss, z)
# one call performs a gradient step with step size 0.1 and returns the loss
# and z
train = symjax.function(outputs=[loss, z], updates={z: z - 0.1 * g_z[0]})
def __init__(
    self,
    state_shape,
    actions_shape,
    batch_size,
    actor,
    critic,
    lr=1e-3,
    K_epochs=80,
    eps_clip=0.2,
    gamma=0.99,
    entropy_beta=0.01,
):
    """Build the PPO losses and compile the training function.

    Args:
        state_shape: tuple appended to ``(batch_size,)`` for the state
            placeholder.
        actions_shape: tuple appended to ``(batch_size,)`` for the action
            placeholder.
        batch_size: leading dimension of every placeholder.
        actor: callable invoked as ``actor(states, distribution="gaussian")``;
            the result must expose ``actions`` (with ``log_prob`` and
            ``entropy``) and ``params``.
        critic: callable invoked as ``critic(states)``; the result must
            expose ``q_values`` and ``params``.
        lr: learning rate given to Adam.
        K_epochs: stored on ``self`` (not used in this method).
        eps_clip: PPO clipping range around a ratio of 1.
        gamma: stored on ``self`` (not used in this method).
        entropy_beta: currently unused; the entropy bonus is commented out
            of the loss.
    """
    self.actor = actor
    self.critic = critic
    self.gamma = gamma
    self.lr = lr
    self.eps_clip = eps_clip
    self.K_epochs = K_epochs
    self.batch_size = batch_size

    states = T.Placeholder((batch_size, ) + state_shape,
                           "float32",
                           name="states")
    # was mistakenly named "states", duplicating the placeholder above
    actions = T.Placeholder((batch_size, ) + actions_shape,
                            "float32",
                            name="actions")
    rewards = T.Placeholder((batch_size, ),
                            "float32",
                            name="discounted_rewards")
    advantages = T.Placeholder((batch_size, ),
                               "float32",
                               name="advantages")

    # ``self.actor`` / ``self.critic`` are rebound from the constructors to
    # the instantiated networks
    self.target_actor = actor(states, distribution="gaussian")
    self.actor = actor(states, distribution="gaussian")
    self.critic = critic(states)

    # Finding the ratio (pi_theta / pi_theta__old) and
    # surrogate Loss https://arxiv.org/pdf/1707.06347.pdf
    with symjax.Scope("policy_loss"):
        ratios = T.exp(
            self.actor.actions.log_prob(actions) -
            self.target_actor.actions.log_prob(actions))
        # keep the raw ratio bounded before applying the PPO clip
        ratios = T.clip(ratios, 0, 10)
        clipped_ratios = T.clip(ratios, 1 - self.eps_clip,
                                1 + self.eps_clip)

        surr1 = advantages * ratios
        surr2 = advantages * clipped_ratios

        actor_loss = -(T.minimum(surr1, surr2)).mean()

    with symjax.Scope("monitor"):
        # fraction of samples whose ratio falls outside the clip range
        clipfrac = (((ratios > (1 + self.eps_clip)) |
                     (ratios <
                      (1 - self.eps_clip))).astype("float32").mean())
        approx_kl = (self.target_actor.actions.log_prob(actions) -
                     self.actor.actions.log_prob(actions)).mean()

    with symjax.Scope("critic_loss"):
        critic_loss = T.mean((rewards - self.critic.q_values)**2)

    with symjax.Scope("entropy"):
        entropy = self.actor.actions.entropy().mean()

    loss = actor_loss + critic_loss  # - entropy_beta * entropy

    with symjax.Scope("optimizer"):
        nn.optimizers.Adam(
            loss,
            lr,
            params=self.actor.params(True) + self.critic.params(True),
        )

    # create the update function
    self._train = symjax.function(
        states,
        actions,
        rewards,
        advantages,
        outputs=[actor_loss, critic_loss, clipfrac, approx_kl],
        updates=symjax.get_updates(scope="*optimizer"),
    )

    # initialize target as current
    self.update_target(1)
def test_pymc():
    """Check a small PyMC-style model built on symjax Variables against JAX.

    Three chained normal variables are declared, their joint log-density is
    compiled with symjax and compared with ``jax.scipy`` log-pdfs on random
    values; the gradient functions of both versions are then evaluated.
    """

    class RandomVariable(symjax.tensor.Variable):
        """Variable that also carries a log-density."""

        def __init__(self, name, shape, observed):
            if observed is not None:
                # observed values are fixed
                super().__init__(observed, name=name, trainable=False)
            else:
                super().__init__(np.zeros(shape), name=name)

        def logp(self, value):
            raise NotImplementedError()

        def random(self, sample_shape):
            raise NotImplementedError()

        @property
        def logpt(self):
            # log-density evaluated at the variable itself
            return self.logp(self)

    class Normal(RandomVariable):
        """Normal variable with mean ``mu`` and standard deviation ``sigma``."""

        def __init__(self, name, mu, sigma, shape=None, observed=None):
            self.mu = mu
            self.sigma = sigma
            super().__init__(name, shape, observed)

        def logp(self, value):
            tau = self.sigma**-2.0
            quadratic = -tau * (value - self.mu)**2
            log_norm = tt.log(tau / np.pi / 2.0)
            return (quadratic + log_norm) / 2.0

        def random(self, sample_shape):
            noise = np.random.randn(sample_shape)
            return noise * self.sigma + self.mu

    x = Normal("x", 0, 10.0)
    s = Normal("s", 0.0, 5.0)
    y = Normal("y", x, tt.exp(s))

    # unobserved variables start at zero
    assert symjax.current_graph().get(y) == 0.0

    #################
    model_logpt = x.logpt + s.logpt + y.logpt

    f = symjax.function(x, s, y, outputs=model_logpt)

    normal_loglike = jsp.stats.norm.logpdf

    def f_(x, s, y):
        """Reference joint log-density written with JAX."""
        log_px = normal_loglike(x, 0.0, 10.0)
        log_ps = normal_loglike(s, 0.0, 5.0)
        log_py = normal_loglike(y, x, jnp.exp(s))
        return log_px + log_ps + log_py

    for _ in range(10):
        x_val = np.random.randn() * 10.0
        s_val = np.random.randn() * 5.0
        y_val = np.random.randn() * 0.1 + x_val
        np.testing.assert_allclose(f(x_val, s_val, y_val),
                                   f_(x_val, s_val, y_val),
                                   rtol=1e-06)

    # gradients are only evaluated (with the last drawn values), not compared
    model_dlogpt = symjax.gradients(model_logpt, [x, s, y])

    f_with_grad = symjax.function(x,
                                  s,
                                  y,
                                  outputs=[model_logpt, model_dlogpt])

    f_with_grad(x_val, s_val, y_val)

    grad_fn = jax.grad(f_, argnums=[0, 1, 2])
    f_(x_val, s_val, y_val)
    grad_fn(x_val, s_val, y_val)