def SJ(x, y, N, preallocate=False):
    # BS, D and lr are expected to be module-level constants of the benchmark script
    symjax.current_graph().reset()
    sj_input = T.Placeholder(dtype=np.float32, shape=[BS, D])
    sj_output = T.Placeholder(dtype=np.float32, shape=[BS, 1])

    np.random.seed(0)
    sj_W = T.Variable(np.random.randn(D, 1).astype("float32"))
    sj_b = T.Variable(np.random.randn(1).astype("float32"))

    sj_loss = ((sj_input.dot(sj_W) + sj_b - sj_output) ** 2).mean()

    optimizers.Adam(sj_loss, lr)

    train = symjax.function(sj_input, sj_output, updates=symjax.get_updates())

    if preallocate:
        import jax

        x = jax.device_put(x)
        y = jax.device_put(y)

    t = time.time()
    for i in range(N):
        train(x, y)
    return time.time() - t
def SJ(x, y, N, lr, model, preallocate=False):
    symjax.current_graph().reset()
    sj_input = T.Placeholder(dtype=np.float32, shape=[BS, D])
    sj_output = T.Placeholder(dtype=np.float32, shape=[BS, 1])

    np.random.seed(0)
    sj_W = T.Variable(np.random.randn(D, 1).astype("float32"))
    sj_b = T.Variable(np.random.randn(1).astype("float32"))

    sj_loss = ((sj_input.dot(sj_W) + sj_b - sj_output) ** 2).mean()

    if model == "SGD":
        optimizers.SGD(sj_loss, lr)
    elif model == "Adam":
        optimizers.Adam(sj_loss, lr)

    train = symjax.function(
        sj_input, sj_output, outputs=sj_loss, updates=symjax.get_updates()
    )

    losses = []
    for i in tqdm(range(N)):
        losses.append(train(x, y))
    return losses
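# Illustrative driver for the benchmark above (not part of the original
# source): BS and D are the module-level constants the functions rely on, and
# the data, step count and learning rate below are placeholders.
if __name__ == "__main__":
    BS, D = 32, 16
    x = np.random.randn(BS, D).astype("float32")
    y = np.random.randn(BS, 1).astype("float32")
    adam_losses = SJ(x, y, 1000, 0.001, "Adam")
    sgd_losses = SJ(x, y, 1000, 0.001, "SGD")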
def create_vae(batch_size, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1):

    x = T.Placeholder([batch_size, Ds[-1]], 'float32')

    # ENCODER
    enc = encoder(x, Ds[0])
    mu = enc[-1][:, :Ds[0]]
    logvar = enc[-1][:, Ds[0]:]
    var = T.exp(logvar)

    z = mu + T.exp(0.5 * logvar) * T.random.randn((batch_size, Ds[0]))
    z_ph = T.Placeholder((batch_size, Ds[0]), 'float32')

    # DECODER
    Ws, bs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w) for w in Ws]
    bs = [T.Variable(b) for b in bs]

    logvar_x = T.Variable(T.zeros(1), name='logvar_x')
    var_x = T.exp(logvar_x)

    h, h_ph = [z], [z_ph]
    for w, b in zip(Ws[:-1], bs[:-1]):
        h.append(T.matmul(h[-1], w.transpose()) + b)
        h.append(h[-1] * relu_mask(h[-1], leakiness))

        h_ph.append(T.matmul(h_ph[-1], w.transpose()) + b)
        h_ph.append(h_ph[-1] * relu_mask(h_ph[-1], leakiness))

    h.append(T.matmul(h[-1], Ws[-1].transpose()) + bs[-1])
    h_ph.append(T.matmul(h_ph[-1], Ws[-1].transpose()) + bs[-1])

    prior = sum([T.mean(w ** 2) for w in Ws], 0.) / cov_W \
        + sum([T.mean(v ** 2) for v in bs[:-1]], 0.) / cov_b

    kl = 0.5 * (1 + logvar - var - mu ** 2).sum(1)
    px = -0.5 * (logvar_x + ((x - h[-1]) ** 2 / var_x)).sum(1)
    loss = -(px + kl).mean() + prior

    variables = Ws + bs + sj.layers.get_variables(enc) + [logvar_x]

    opti = sj.optimizers.Adam(loss, lr, params=variables)

    train = sj.function(x, outputs=loss, updates=opti.updates)
    g = sj.function(z_ph, outputs=h_ph[-1])
    params = sj.function(outputs=Ws + bs + [T.exp(logvar_x) * T.ones(Ds[-1])])
    get_varx = sj.function(outputs=var_x)

    output = {'train': train, 'g': g, 'params': params}
    output['model'] = 'VAE'
    output['varx'] = get_varx
    output['kwargs'] = {'batch_size': batch_size, 'Ds': Ds, 'seed': seed,
                        'leakiness': leakiness, 'lr': lr, 'scaler': scaler,
                        'prior': sj.function(outputs=prior)}

    def sample(n):
        samples = []
        for i in range(n // batch_size):
            samples.append(g(np.random.randn(batch_size, Ds[0])))
        return np.concatenate(samples)

    output['sample'] = sample

    return output
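# Hedged usage sketch (not part of the original source): the returned
# dictionary is driven from numpy batches; Ds lists the layer widths from the
# latent dimension Ds[0] to the observed dimension Ds[-1], and every value
# below (sizes, epochs, the `data` array) is a placeholder.
#
#     vae = create_vae(batch_size=50, Ds=[2, 32, 32, 784], seed=0)
#     for epoch in range(10):
#         for batch in data.reshape((-1, 50, 784)):
#             vae['train'](batch)
#     samples = vae['sample'](500)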
def generate_learnmorlet_filterbank(N, J, Q):
    freqs = T.Variable(np.ones(J * Q) * 5)
    scales = 2 ** (np.linspace(-0.5, np.log2(2 * np.pi * np.log2(N)), J * Q))
    scales = T.Variable(scales)

    filters = T.signal.morlet(N, s=0.01 + T.abs(scales.reshape((-1, 1))),
                              w=freqs.reshape((-1, 1)))
    filters_norm = filters / T.linalg.norm(filters, 2, 1, keepdims=True)

    return T.expand_dims(filters_norm, 1), freqs, scales
def test_stack():
    u = tt.Variable(tt.ones((2,)))
    output = tt.stack([u, 2 * u, 3 * u])
    f = symjax.function(outputs=output)
    assert np.allclose(f(), (np.arange(3)[:, None] + 1) * np.ones((3, 2)))
    print(f())
    print(f())
def generate_sinc_filterbank(f0, f1, J, N):

    # get the center frequencies
    freqs = get_scaled_freqs(f0, f1, J + 1)

    # turn the bounds into (start, width) pairs and make them a variable
    freqs = np.stack([freqs[:-1], freqs[1:]], 1)
    freqs[:, 1] -= freqs[:, 0]
    freqs = T.Variable(freqs, name='c_freq')

    # parametrize the frequencies
    f0 = T.abs(freqs[:, 0])
    f1 = f0 + T.abs(freqs[:, 1])

    # sample the bandpass filters
    time = T.linspace(-N // 2, N // 2 - 1, N)
    time_matrix = time.reshape((-1, 1))
    sincs = T.signal.sinc_bandpass(time_matrix, f0, f1)

    # apodize
    apod_filters = sincs * T.signal.hanning(N).reshape((-1, 1))

    # normalize
    normed_filters = apod_filters / T.linalg.norm(apod_filters, 2, 0,
                                                  keepdims=True)

    filters = T.transpose(T.expand_dims(normed_filters, 1), [2, 1, 0])
    return filters, freqs
def create_updates(self, grads_or_loss, learning_rate, momentum, params=None):

    if isinstance(grads_or_loss, list):
        assert params

    if params is None:
        params = self._get_variables(grads_or_loss)
    elif type(params) != list:
        raise RuntimeError("given params should be a list")

    grads = self._get_grads(grads_or_loss, params)

    updates = dict()
    variables = []
    for param, grad in zip(params, grads):
        velocity = tensor.Variable(
            numpy.zeros(param.shape, dtype=param.dtype), trainable=False
        )
        variables.append(velocity)
        update = param - learning_rate * grad
        x = momentum * velocity + update - param
        updates[velocity] = x
        updates[param] = momentum * x + update

    self.add_updates(updates)
def create_updates(self, grads_or_loss, learning_rate, momentum, params=None):

    if params is None:
        params = [
            v for k, v in get_graph().variables.items() if v.trainable
        ]

    grads = self._get_grads(grads_or_loss, params)

    if not numpy.isscalar(learning_rate) and not isinstance(
            learning_rate, tensor.Placeholder):
        learning_rate = learning_rate()

    updates = dict()
    variables = []
    for param, grad in zip(params, grads):
        velocity = tensor.Variable(
            numpy.zeros(param.shape, dtype=param.dtype), trainable=False
        )
        variables.append(velocity)
        update = param - learning_rate * grad
        x = momentum * velocity + update - param
        updates[velocity] = x
        updates[param] = momentum * x + update

    self.add_updates(updates)
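# Hedged note (not part of the original source): this update rule is normally
# reached through the optimizer's constructor rather than called directly,
# following the pattern used in the other snippets of this file, e.g.
#
#     optimizers.SGD(loss, learning_rate, momentum)
#     train = symjax.function(inputs, outputs=loss, updates=symjax.get_updates())
#
# The exact argument order of the constructor is an assumption here; only the
# loss-then-learning-rate pattern is taken from the surrounding code.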
def create_variable(
    name,
    tensor_or_func,
    shape,
    trainable,
    inplace=False,
    dtype="float32",
    preprocessor=None,
):
    if tensor_or_func is None:
        return None

    if inplace:
        assert not callable(tensor_or_func)
        return tensor_or_func

    variable = T.Variable(
        tensor_or_func,
        name=name,
        shape=symjax.current_graph().get(shape),
        dtype=dtype,
        trainable=trainable,
    )

    if preprocessor is not None:
        return preprocessor(variable)
    else:
        return variable
def create_glo(batch_size, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1,
               GLO=False):

    x = T.Placeholder([batch_size, Ds[-1]], 'float32')
    z = T.Variable(T.random.randn((batch_size, Ds[0])))
    logvar_x = T.Variable(T.ones(1))

    # DECODER
    Ws, bs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w) for w in Ws]
    bs = [T.Variable(b) for b in bs]

    h = [z]
    for w, b in zip(Ws[:-1], bs[:-1]):
        h.append(T.matmul(h[-1], w.transpose()) + b)
        h.append(h[-1] * relu_mask(h[-1], leakiness))
    h.append(T.matmul(h[-1], Ws[-1].transpose()) + bs[-1])

    # LOSS
    prior = sum([T.sum(w ** 2) for w in Ws], 0.) / cov_W \
        + sum([T.sum(v ** 2) for v in bs[:-1]], 0.) / cov_b

    if GLO:
        loss = T.sum((x - h[-1]) ** 2) / batch_size + prior
        variables = Ws + bs
    else:
        loss = Ds[-1] * logvar_x.sum() \
            + T.sum((x - h[-1]) ** 2 / T.exp(logvar_x)) / batch_size \
            + (z ** 2).sum() / batch_size + prior
        variables = Ws + bs

    prior = sum([(b ** 2).sum() for b in bs], 0.) / cov_b \
        + sum([(w ** 2).sum() for w in Ws], 0.) / cov_W

    opti = sj.optimizers.Adam(loss + prior, lr, params=variables)
    infer = sj.optimizers.Adam(loss, lr, params=[z])

    estimate = sj.function(x, outputs=z, updates=infer.updates)
    train = sj.function(x, outputs=loss, updates=opti.updates)
    lossf = sj.function(x, outputs=loss)
    params = sj.function(outputs=Ws + bs + [T.ones(Ds[-1]) * T.exp(logvar_x)])

    output = {'train': train, 'estimate': estimate, 'params': params}
    output['reset'] = lambda v: z.assign(v)
    if GLO:
        output['model'] = 'GLO'
    else:
        output['model'] = 'HARD'
    output['loss'] = lossf
    output['kwargs'] = {'batch_size': batch_size, 'Ds': Ds, 'seed': seed,
                        'leakiness': leakiness, 'lr': lr, 'scaler': scaler}
    return output
def create_updates(
    self,
    grads_or_loss,
    learning_rate,
    amsgrad=False,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-7,
    params=None,
):

    if isinstance(grads_or_loss, list):
        assert params

    if params is None:
        params = self._get_variables(grads_or_loss)
    elif type(params) != list:
        raise RuntimeError("given params should be a list")

    if len(params) == 0:
        raise RuntimeError(
            "no parameters are given for the gradients, this can be due to "
            "passing explicitly an empty list or to passing a loss connected "
            "to no trainable weights"
        )

    grads = self._get_grads(grads_or_loss, params)

    local_step = tensor.Variable(1, dtype="int32", trainable=False)
    updates = {local_step: local_step + 1}

    beta_1_t = tensor.power(beta_1, local_step)
    beta_2_t = tensor.power(beta_2, local_step)
    lr = learning_rate * (tensor.sqrt(1 - beta_2_t) / (1 - beta_1_t))

    for param, grad in zip(params, grads):
        m = ExponentialMovingAverage(grad, beta_1, debias=False)[0]
        v = ExponentialMovingAverage(grad ** 2, beta_2, debias=False)[0]
        if amsgrad:
            v_hat = tensor.Variable(
                tensor.zeros_like(param), name="v_hat", trainable=False
            )
            updates[v_hat] = tensor.maximum(v_hat, v)
            update = m / (tensor.sqrt(updates[v_hat]) + epsilon)
        else:
            update = m / (tensor.sqrt(v) + epsilon)

        update = tensor.where(local_step == 1, grad, update)
        updates[param] = param - lr * update

    self.add_updates(updates)
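# Hedged sketch (not part of the original source) of how the Adam rule above
# is typically reached, following the pattern of the other snippets in this
# file; shapes and the learning rate are illustrative.
#
#     x = T.Placeholder((4, 3), "float32")
#     W = T.Variable(np.random.randn(3, 1).astype("float32"))
#     loss = (T.matmul(x, W) ** 2).mean()
#     symjax.nn.optimizers.Adam(loss, 0.001)
#     step = symjax.function(x, outputs=loss, updates=symjax.get_updates())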
def test_grad():
    w = tt.Placeholder((), "float32")
    v = tt.Variable(1.0, dtype="float32")
    x = w * v + 2
    # symjax.nn.optimizers.Adam(x, 0.001)
    g = symjax.gradients(x.sum(), [v])[0]
    f = symjax.function(w, outputs=g)
    assert f(1) == 1
    assert f(10) == 10
def generate_gaussian_filterbank(N, M, J, f0, f1, modes=1):

    # gaussian parameters
    freqs = get_scaled_freqs(f0, f1, J)
    freqs *= (J - 1) * 10

    if modes > 1:
        other_modes = np.random.randint(0, J, J * (modes - 1))
        freqs = np.concatenate([freqs, freqs[other_modes]])

    # create the average vectors
    mu_init = np.stack([freqs, 0.1 * np.random.randn(J * modes)], 1)
    mu = T.Variable(mu_init.astype('float32'), name='mu')

    # create the covariance matrix
    cor = T.Variable(0.01 * np.random.randn(J * modes).astype('float32'),
                     name='cor')
    sigma_init = np.stack([freqs / 6, 1. + 0.01 * np.random.randn(J * modes)], 1)
    sigma = T.Variable(sigma_init.astype('float32'), name='sigma')

    # create the mixing coefficients
    mixing = T.Variable(np.ones((modes, 1, 1)).astype('float32'))

    # now apply our parametrization
    coeff = T.stop_gradient(T.sqrt((T.abs(sigma) + 0.1).prod(1))) * 0.95
    Id = T.eye(2)
    cov = Id * T.expand_dims((T.abs(sigma) + 0.1), 1) \
        + T.flip(Id, 0) * (T.tanh(cor) * coeff).reshape((-1, 1, 1))
    cov_inv = T.linalg.inv(cov)

    # get the gaussian filters
    time = T.linspace(-5, 5, M)
    freq = T.linspace(0, J * 10, N)
    x, y = T.meshgrid(time, freq)
    grid = T.stack([y.flatten(), x.flatten()], 1)
    centered = grid - T.expand_dims(mu, 1)

    gaussian = T.exp(-(T.matmul(centered, cov_inv) ** 2).sum(-1))
    norm = T.linalg.norm(gaussian, 2, 1, keepdims=True)
    gaussian_2d = T.abs(mixing) * T.reshape(gaussian / norm, (J, modes, N, M))

    return gaussian_2d.sum(1, keepdims=True), mu, cor, sigma, mixing
def test_map():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32")
    out = T.map(lambda a, w, u: (u - w) * a, [T.range(3)], non_sequences=[w, u])
    f = sj.function(u, outputs=out, updates={w: w + 1})
    assert np.array_equal(f(2), np.arange(3))
    assert np.array_equal(f(2), np.zeros(3))
    assert np.array_equal(f(0), -np.arange(3) * 3)
def test_grad_map():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32", name="u")
    out = T.map(lambda a, w, u: w * a * u, (T.range(3),), non_sequences=(w, u))
    g = sj.gradients(out.sum(), w)
    f = sj.function(u, outputs=g)
    assert np.array_equal(f(0), 0)
    assert np.array_equal(f(1), 3)
def ExponentialMovingAverage(value, alpha):

    with Scope("ExponentialMovingAverage"):
        first_step = T.Variable(True, trainable=False, name="first_step",
                                dtype="bool")
        var = T.Variable(T.zeros(value.shape), trainable=False,
                         dtype="float32", name="EMA")

        new_value = T.where(first_step, value, var * alpha + (1 - alpha) * value)
        current_graph().add({var: new_value, first_step: False})

    return new_value, var
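# Hedged usage sketch (not part of the original source), mirroring the doctest
# of the documented ExponentialMovingAverage further below: the recorded
# updates are collected by symjax and applied each time the compiled function
# is called.
#
#     input = symjax.tensor.Placeholder((2,), "float32")
#     ema, ema_var = ExponentialMovingAverage(input, 0.9)
#     f = symjax.function(input, outputs=ema, updates=symjax.get_updates())
#     f(np.ones(2, "float32"))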
def test_clone_0():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    with sj.Scope("placing"):
        u = T.Placeholder((), "float32", name="u")
    value = 2 * w * u
    c = value.clone({w: u})
    f = sj.function(u, outputs=value)
    g = sj.function(u, outputs=c)

    assert np.array_equal([f(1), g(1), f(2), g(2)], [2, 2, 4, 8])
def test_while():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    v = T.Placeholder((), "float32")
    out = T.while_loop(
        lambda i, u: i[0] + u < 5,
        lambda i: (i[0] + 1.0, i[0] ** 2),
        (w, 1.0),
        non_sequences_cond=(v,),
    )
    f = sj.function(v, outputs=out)
    assert np.array_equal(np.array(f(0)), [5, 16])
    assert np.array_equal(f(2), [3, 4])
def test_clone_base():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    w2 = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32", name="u")
    uu = T.Placeholder((), "float32", name="uu")
    aa = T.Placeholder((), "float32")
    bb = T.Placeholder((), "float32")

    l = 2 * w * u * w2
    g = sj.gradients(l, w)
    guu = T.clone(l, {u: uu})
    guuu = T.clone(l, {w: uu})

    f = sj.function(u, outputs=g, updates={w2: w2 + 1})
    fuu = sj.function(uu, outputs=guu, updates={w2: w2 + 1})
    fuuu = sj.function(u, uu, outputs=guuu, updates={w2: w2 + 1})

    # print(f(2))
    assert np.array_equal(f(2), 4.0)
    assert np.array_equal(fuu(1), 4)
    assert np.array_equal(fuuu(0, 0), 0)
def create_variable(self, name, tensor_or_func, shape, trainable, dtype=None):
    t = self.create_tensor(tensor_or_func, shape, dtype)
    if not trainable:
        self.__dict__[name] = t
    else:
        self.__dict__[name] = T.Variable(t, name=name, trainable=True)
        self.add_variable(self.__dict__[name])
def __call__(self, action, episode):

    with symjax.Scope("OUProcess"):
        self.episode = T.Variable(1, "float32", name="episode", trainable=False)

    self.noise_scale = self.initial_noise_scale * self.noise_decay ** episode

    x = (self.process
         + self.theta * (self.mean - self.process) * self.dt
         + self.std_dev * np.sqrt(self.dt) * np.random.normal(size=action.shape))

    # Store x into process
    # Makes next noise dependent on current one
    self.process = x

    return action + self.noise_scale * self.process
def PiecewiseConstant(init, steps_and_values):

    with Scope("PiecewiseConstant"):

        all_steps = T.stack([0] + list(steps_and_values.keys()))
        all_values = T.stack([init] + list(steps_and_values.values()))

        step = T.Variable(
            T.zeros(1),
            trainable=False,
            name="step",
            dtype="float32",
        )

        value = all_values[(step < all_steps).argmin() - 1]

        current_graph().add({step: step + 1})

    return value
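# Hedged usage sketch (not part of the original source): the same schedule is
# used elsewhere in this file as a step-dependent learning rate, its internal
# step counter advancing each time the recorded updates are applied.
#
#     lr = symjax.nn.schedules.PiecewiseConstant(0.05, {5000: 0.01, 8000: 0.005})
#     symjax.nn.optimizers.Adam(loss, lr)
#     train = symjax.function(outputs=loss, updates=symjax.get_updates())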
def test_cond5():
    sj.current_graph().reset()
    v = T.ones((10, 10)) * 3
    W = T.Variable(1)
    u = T.Placeholder((), "int32")
    out = T.cond(
        u > 0,
        lambda a, u: a * u[0],
        lambda a, u: a + u[1],
        true_inputs=(W, v),
        false_inputs=(W, v),
    )
    f = sj.function(u, outputs=out, updates={W: W + 1})
    assert np.array_equal(f(1), 3 * np.ones(10))
    assert np.array_equal(f(0), 5 * np.ones(10))
    assert np.array_equal(f(1), 9 * np.ones(10))
def create_fns(input, in_signs, Ds, x, m0, m1, m2, batch_in_signs,
               alpha=0.1, sigma=1, sigma_x=1, lr=0.0002):

    cumulative_units = np.concatenate([[0], np.cumsum(Ds[:-1])])
    BS = batch_in_signs.shape[0]

    Ws = [T.Variable(sj.initializers.glorot((j, i)) * sigma)
          for j, i in zip(Ds[1:], Ds[:-1])]
    bs = [T.Variable(sj.initializers.he((j,)) * sigma) for j in Ds[1:-1]] \
        + [T.Variable(T.zeros((Ds[-1],)))]

    A_w = [T.eye(Ds[0])]
    B_w = [T.zeros(Ds[0])]

    A_q = [T.eye(Ds[0])]
    B_q = [T.zeros(Ds[0])]

    batch_A_q = [T.eye(Ds[0]) * T.ones((BS, 1, 1))]
    batch_B_q = [T.zeros((BS, Ds[0]))]

    maps = [input]
    signs = []
    masks = [T.ones(Ds[0])]

    in_masks = T.where(T.concatenate([T.ones(Ds[0]), in_signs]) > 0, 1., alpha)
    batch_in_masks = T.where(
        T.concatenate([T.ones((BS, Ds[0])), batch_in_signs], 1) > 0, 1., alpha)

    for w, b in zip(Ws[:-1], bs[:-1]):
        pre_activation = T.matmul(w, maps[-1]) + b
        signs.append(T.sign(pre_activation))
        masks.append(T.where(pre_activation > 0, 1., alpha))
        maps.append(pre_activation * masks[-1])

    maps.append(T.matmul(Ws[-1], maps[-1]) + bs[-1])

    # compute per region A and B
    for start, end, w, b, m in zip(cumulative_units[:-1], cumulative_units[1:],
                                   Ws, bs, masks):

        A_w.append(T.matmul(w * m, A_w[-1]))
        B_w.append(T.matmul(w * m, B_w[-1]) + b)

        A_q.append(T.matmul(w * in_masks[start:end], A_q[-1]))
        B_q.append(T.matmul(w * in_masks[start:end], B_q[-1]) + b)

        batch_A_q.append(
            T.matmul(w * batch_in_masks[:, None, start:end], batch_A_q[-1]))
        batch_B_q.append((w * batch_in_masks[:, None, start:end]
                          * batch_B_q[-1][:, None, :]).sum(2) + b)

    batch_B_q = batch_B_q[-1]
    batch_A_q = batch_A_q[-1]

    signs = T.concatenate(signs)

    inequalities = T.hstack(
        [T.concatenate(B_w[1:-1])[:, None], T.vstack(A_w[1:-1])]) * signs[:, None]

    inequalities_code = T.hstack(
        [T.concatenate(B_q[1:-1])[:, None], T.vstack(A_q[1:-1])]) * in_signs[:, None]

    #### loss
    log_sigma2 = T.Variable(sigma_x)
    sigma2 = T.exp(log_sigma2)

    Am1 = T.einsum('qds,nqs->nqd', batch_A_q, m1)
    Bm0 = T.einsum('qd,nq->nd', batch_B_q, m0)
    B2m0 = T.einsum('nq,qd->n', m0, batch_B_q ** 2)
    AAm2 = T.einsum('qds,qdu,nqup->nsp', batch_A_q, batch_A_q, m2)

    inner = -(x * (Am1.sum(1) + Bm0)).sum(1) + (Am1 * batch_B_q).sum((1, 2))

    loss_2 = (x ** 2).sum(1) + B2m0 + T.trace(AAm2, axis1=1, axis2=2).squeeze()
    loss_z = T.trace(m2.sum(1), axis1=1, axis2=2).squeeze()

    cst = 0.5 * (Ds[0] + Ds[-1]) * T.log(2 * np.pi)

    loss = cst + 0.5 * Ds[-1] * log_sigma2 + inner / sigma2 \
        + 0.5 * loss_2 / sigma2 + 0.5 * loss_z

    mean_loss = loss.mean()
    adam = sj.optimizers.NesterovMomentum(mean_loss, Ws + bs, lr, 0.9)

    train_f = sj.function(batch_in_signs, x, m0, m1, m2,
                          outputs=mean_loss, updates=adam.updates)
    f = sj.function(input,
                    outputs=[maps[-1], A_w[-1], B_w[-1], inequalities, signs])
    g = sj.function(in_signs, outputs=[A_q[-1], B_q[-1]])
    all_g = sj.function(in_signs, outputs=inequalities_code)
    h = sj.function(input, outputs=maps[-1])

    return f, g, h, all_g, train_f, sigma2
""" Basic gradient descent (and reset) ================================== demonstration on how to compute a gradient and apply a basic gradient update rule to minimize some loss function """ import symjax import symjax.tensor as T import matplotlib.pyplot as plt # GRADIENT DESCENT z = T.Variable(3.0, dtype="float32") loss = (z - 1) ** 2 g_z = symjax.gradients(loss, [z])[0] symjax.current_graph().add_updates({z: z - 0.1 * g_z}) train = symjax.function(outputs=[loss, z], updates=symjax.get_updates()) losses = list() values = list() for i in range(200): if (i + 1) % 50 == 0: symjax.reset_variables("*") a, b = train() losses.append(a) values.append(b) plt.figure()
def __init__(
    self,
    env,
    actor,
    critic,
    lr=1e-4,
    batch_size=32,
    n=1,
    clip_ratio=0.2,
    entcoeff=0.01,
    target_kl=0.01,
    train_pi_iters=4,
    train_v_iters=4,
):

    num_states = env.observation_space.shape[0]
    continuous = type(env.action_space) == gym.spaces.box.Box

    if continuous:
        num_actions = env.action_space.shape[0]
        action_min = env.action_space.low
        action_max = env.action_space.high
    else:
        num_actions = env.action_space.n
        action_max = 1

    self.batch_size = batch_size
    self.num_states = num_states
    self.num_actions = num_actions
    self.state_dim = (batch_size, num_states)
    self.action_dim = (batch_size, num_actions)
    self.continuous = continuous
    self.lr = lr
    self.train_pi_iters = train_pi_iters
    self.train_v_iters = train_v_iters
    self.clip_ratio = clip_ratio
    self.target_kl = target_kl
    self.extras = {"logprob": ()}
    self.entcoeff = entcoeff

    state_ph = T.Placeholder((batch_size, num_states), "float32")
    ret_ph = T.Placeholder((batch_size,), "float32")
    adv_ph = T.Placeholder((batch_size,), "float32")
    act_ph = T.Placeholder((batch_size, num_actions), "float32")

    with symjax.Scope("actor"):
        logits = actor(state_ph)
        if continuous:
            logstd = T.Variable(
                -0.5 * np.ones(num_actions, dtype=np.float32),
                name="logstd",
            )

    with symjax.Scope("old_actor"):
        old_logits = actor(state_ph)
        if continuous:
            old_logstd = T.Variable(
                -0.5 * np.ones(num_actions, dtype=np.float32),
                name="logstd",
            )

    if not continuous:
        pi = Categorical(logits=logits)
    else:
        pi = MultivariateNormal(mean=logits, diag_log_std=logstd)

    actions = T.clip(pi.sample(), -2, 2)  # pi

    actor_params = symjax.get_variables(scope="/actor/")
    old_actor_params = symjax.get_variables(scope="/old_actor/")

    self.update_target = symjax.function(
        updates={o: a for o, a in zip(old_actor_params, actor_params)})

    # PPO objectives
    # pi(a|s) / pi_old(a|s)
    pi_log_prob = pi.log_prob(act_ph)
    old_pi_log_prob = pi_log_prob.clone({
        logits: old_logits,
        logstd: old_logstd
    })

    ratio = T.exp(pi_log_prob - old_pi_log_prob)
    surr1 = ratio * adv_ph
    surr2 = adv_ph * T.clip(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)
    pi_loss = -T.minimum(surr1, surr2).mean()

    # ent_loss = pi.entropy().mean() * self.entcoeff

    with symjax.Scope("critic"):
        v = critic(state_ph)

    # critic loss
    v_loss = ((ret_ph - v) ** 2).mean()

    # Info (useful to watch during learning)

    # a sample estimate for KL
    approx_kl = (old_pi_log_prob - pi_log_prob).mean()

    # a sample estimate for entropy
    # approx_ent = -logprob_given_actions.mean()
    # clipped = T.logical_or(
    #     ratio > (1 + clip_ratio), ratio < (1 - clip_ratio)
    # )
    # clipfrac = clipped.astype("float32").mean()

    with symjax.Scope("update_pi"):
        print(len(actor_params), "actor parameters")
        nn.optimizers.Adam(
            pi_loss,
            self.lr,
            params=actor_params,
        )

    with symjax.Scope("update_v"):
        critic_params = symjax.get_variables(scope="/critic/")
        print(len(critic_params), "critic parameters")
        nn.optimizers.Adam(
            v_loss,
            self.lr,
            params=critic_params,
        )

    self.get_params = symjax.function(outputs=critic_params)

    self.learn_pi = symjax.function(
        state_ph,
        act_ph,
        adv_ph,
        outputs=[pi_loss, approx_kl],
        updates=symjax.get_updates(scope="/update_pi/"),
    )
    self.learn_v = symjax.function(
        state_ph,
        ret_ph,
        outputs=v_loss,
        updates=symjax.get_updates(scope="/update_v/"),
    )

    single_state = T.Placeholder((1, num_states), "float32")
    single_v = v.clone({state_ph: single_state})
    single_sample = actions.clone({state_ph: single_state})

    self._act = symjax.function(single_state, outputs=single_sample)
    self._get_v = symjax.function(single_state, outputs=single_v)

    single_action = T.Placeholder((1, num_actions), "float32")
os.environ["DATASET_PATH"] = "/home/vrael/DATASETS/" symjax.current_graph().reset() mnist = symjax.data.mnist() # 2d image images = mnist["train_set/images"][mnist["train_set/labels"] == 2][:2, 0] images /= images.max() np.random.seed(0) coordinates = T.meshgrid(T.range(28), T.range(28)) coordinates = T.Variable( T.stack([coordinates[1].flatten(), coordinates[0].flatten()]).astype("float32") ) interp = T.interpolation.map_coordinates(images[0], coordinates, order=1).reshape( (28, 28) ) loss = ((interp - images[1]) ** 2).mean() lr = symjax.nn.schedules.PiecewiseConstant(0.05, {5000: 0.01, 8000: 0.005}) symjax.nn.optimizers.Adam(loss, lr) train = symjax.function(outputs=loss, updates=symjax.get_updates()) rec = symjax.function(outputs=interp) losses = list()
# map
xx = T.ones(10)
a = T.map(lambda a: a * 2, xx)
g = symjax.gradients(a.sum(), xx)[0]
f = symjax.function(outputs=[a, g])

# scan
xx = T.ones(10) * 2
a = T.scan(lambda c, x: (c * x, c * x), T.ones(1), xx)
g = symjax.gradients(a[1][-1], xx)[0]
f = symjax.function(outputs=[a, g])

# scan with updates
xx = T.range(5)
uu = T.ones((10, 2))
vvar = T.Variable(T.zeros((10, 2)))
vv = T.index_add(vvar, 1, 1)
a = T.scan(lambda c, x, p: (T.index_update(c, x, p[x]), 1), vv, xx, [vv])
# a = T.scan(lambda c, x: (c * x, c * x), T.ones(1), xx)
# a = T.scan(lambda c, x: (T.square(c), c[0]), uu, xx)
# g = symjax.gradients(a[1][-1], xx)
f = symjax.function(outputs=a[0], updates={vvar: vvar + 1})
print(f(), f(), f())

# fori loop
b = T.Placeholder((), 'int32')
xx = T.ones(1)
a = T.fori_loop(0, b, lambda i, x: i * x, xx)
f = symjax.function(b, outputs=a)
print(f(0), f(1), f(2), f(3))
def create_fns(batch_size, R, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1,
               var_x=1):

    alpha = T.Placeholder((1,), 'float32')
    x = T.Placeholder((Ds[0],), 'float32')
    X = T.Placeholder((batch_size, Ds[-1]), 'float32')

    signs = T.Placeholder((np.sum(Ds[1:-1]),), 'float32')
    SIGNS = T.Placeholder((R, np.sum(Ds[1:-1])), 'float32')

    m0 = T.Placeholder((batch_size, R), 'float32')
    m1 = T.Placeholder((batch_size, R, Ds[0]), 'float32')
    m2 = T.Placeholder((batch_size, R, Ds[0], Ds[0]), 'float32')

    Ws, vs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w, name='W' + str(l)) for l, w in enumerate(Ws)]
    vs = [T.Variable(v, name='v' + str(l)) for l, v in enumerate(vs)]

    var_x = T.Variable(T.ones(Ds[-1]) * var_x)
    var_z = T.Variable(T.ones(Ds[0]))

    # create the placeholders
    Ws_ph = [T.Placeholder(w.shape, w.dtype) for w in Ws]
    vs_ph = [T.Placeholder(v.shape, v.dtype) for v in vs]
    var_x_ph = T.Placeholder(var_x.shape, var_x.dtype)

    ############################################################################
    # Compute the output of g(x)
    ############################################################################

    maps = [x]
    xsigns = []
    masks = []

    for w, v in zip(Ws[:-1], vs[:-1]):
        pre_activation = T.matmul(w, maps[-1]) + v
        xsigns.append(T.sign(pre_activation))
        masks.append(relu_mask(pre_activation, leakiness))
        maps.append(pre_activation * masks[-1])

    xsigns = T.concatenate(xsigns)
    maps.append(T.matmul(Ws[-1], maps[-1]) + vs[-1])

    ############################################################################
    # compute the masks and then the per layer affine mappings
    ############################################################################

    cumulative_units = np.cumsum([0] + Ds[1:])
    xqs = relu_mask([xsigns[None, cumulative_units[i]:cumulative_units[i + 1]]
                     for i in range(len(Ds) - 2)], leakiness)
    qs = relu_mask([signs[None, cumulative_units[i]:cumulative_units[i + 1]]
                    for i in range(len(Ds) - 2)], leakiness)
    Qs = relu_mask([SIGNS[:, cumulative_units[i]:cumulative_units[i + 1]]
                    for i in range(len(Ds) - 2)], leakiness)

    Axs, bxs = get_Abs(Ws, vs, xqs)
    Aqs, bqs = get_Abs(Ws, vs, qs)
    AQs, bQs = get_Abs(Ws, vs, Qs)

    all_bxs = T.hstack(bxs[:-1]).transpose()
    all_Axs = T.hstack(Axs[:-1])[0]

    all_bqs = T.hstack(bqs[:-1]).transpose()
    all_Aqs = T.hstack(Aqs[:-1])[0]

    x_inequalities = T.hstack([all_Axs, all_bxs]) * xsigns[:, None]
    q_inequalities = T.hstack([all_Aqs, all_bqs]) * signs[:, None]

    ############################################################################
    # loss (E-step NLL)
    ############################################################################

    Bm0 = T.einsum('nd,Nn->Nd', bQs[-1], m0)
    B2m0 = T.einsum('nd,Nn->Nd', bQs[-1] ** 2, m0)
    Am1 = T.einsum('nds,Nns->Nd', AQs[-1], m1)
    ABm1 = T.einsum('nds,nd,Nns->Nd', AQs[-1], bQs[-1], m1)
    Am2ATdiag = T.diagonal(T.einsum('nds,Nnsc,npc->Ndp', AQs[-1], m2, AQs[-1]),
                           axis1=1, axis2=2)
    xAm1Bm0 = X * (Am1 + Bm0)

    M2diag = T.diagonal(m2.sum(1), axis1=1, axis2=2)

    prior = sum([T.mean(w ** 2) for w in Ws], 0.) / cov_W \
        + sum([T.mean(v ** 2) for v in vs[:-1]], 0.) / cov_b

    loss = -0.5 * (T.log(var_x).sum() + T.log(var_z).sum()
                   + (M2diag / var_z).sum(1).mean()
                   + ((X ** 2 - 2 * xAm1Bm0 + B2m0 + Am2ATdiag + 2 * ABm1)
                      / var_x).sum(1).mean())

    mean_loss = -(loss + 0.5 * prior)
    adam = sj.optimizers.SGD(mean_loss, 0.001, params=Ws + vs)

    ############################################################################
    # update of var_x
    ############################################################################

    update_varx = (X ** 2 - 2 * xAm1Bm0 + B2m0 + Am2ATdiag + 2 * ABm1).mean() \
        * T.ones(Ds[-1])
    update_varz = M2diag.mean() * T.ones(Ds[0])

    ############################################################################
    # update for biases IT IS DONE FOR ISOTROPIC COVARIANCE MATRIX
    ############################################################################

    FQ = get_forward(Ws, Qs)
    update_vs = {}
    for i in range(len(vs)):

        if i < len(vs) - 1:
            # now we forward each bias to the x-space except the ith
            separated_bs = bQs[-1] - T.einsum('nds,s->nd', FQ[i], vs[i])
            # compute the residual and apply sigma
            residual = (X[:, None, :] - separated_bs) * m0[:, :, None] \
                - T.einsum('nds,Nns->Nnd', AQs[-1], m1)
            back_error = T.einsum('nds,nd->s', FQ[i], residual.mean(0))
            whiten = T.einsum('ndc,nds,n->cs', FQ[i], FQ[i], m0.mean(0)) \
                + T.eye(back_error.shape[0]) / (Ds[i] * cov_b)
            update_vs[vs[i]] = T.linalg.solve(whiten, back_error)
        else:
            back_error = (X - (Am1 + Bm0) + vs[-1])
            update_vs[vs[i]] = back_error.mean(0)

    ############################################################################
    # update for slopes IT IS DONE FOR ISOTROPIC COVARIANCE MATRIX
    ############################################################################

    update_Ws = {}
    for i in range(len(Ws)):

        U = T.einsum('nds,ndc->nsc', FQ[i], FQ[i])
        if i == 0:
            V = m2.mean(0)
        else:
            V1 = T.einsum('nd,nq,Nn->ndq', bQs[i - 1], bQs[i - 1], m0)
            V2 = T.einsum('nds,nqc,Nnsc->ndq', AQs[i - 1], AQs[i - 1], m2)
            V3 = T.einsum('nds,nq,Nns->ndq', AQs[i - 1], bQs[i - 1], m1)
            Q = T.einsum('nd,nq->ndq', Qs[i - 1], Qs[i - 1])
            V = Q * (V1 + V2 + V3 + V3.transpose((0, 2, 1))) / batch_size

        whiten = T.stack([T.kron(U[n], V[n]) for n in range(V.shape[0])]).sum(0)
        whiten = whiten + T.eye(whiten.shape[-1]) / (Ds[i] * Ds[i + 1] * cov_W)

        # compute the residual (bottom up)
        if i == len(Ws) - 1:
            bottom_up = (X[:, None, :] - vs[-1])
        else:
            if i == 0:
                residual = (X[:, None, :] - bQs[-1])
            else:
                residual = (X[:, None, :] - bQs[-1]
                            + T.einsum('nds,ns->nd', FQ[i - 1], bQs[i - 1]))
            bottom_up = T.einsum('ndc,Nnd->Nnc', FQ[i], residual)

        # compute the top down vector
        if i == 0:
            top_down = m1
        else:
            top_down = Qs[i - 1] * (T.einsum('nds,Nns->Nnd', AQs[i - 1], m1)
                                    + T.einsum('nd,Nn->Nnd', bQs[i - 1], m0))

        vector = T.einsum('Nnc,Nns->cs', bottom_up, top_down) / batch_size
        condition = T.diagonal(whiten)
        update_W = T.linalg.solve(whiten, vector.reshape(-1)).reshape(Ws[i].shape)
        update_Ws[Ws[i]] = update_W

    ############################################################################
    # create the io functions
    ############################################################################

    params = sj.function(outputs=Ws + vs + [var_x])
    ll = T.Placeholder((), 'int32')
    selector = T.one_hot(ll, len(vs))

    for i in range(len(vs)):
        update_vs[vs[i]] = ((1 - alpha) * vs[i] + alpha * update_vs[vs[i]]) \
            * selector[i] + vs[i] * (1 - selector[i])

    for i in range(len(Ws)):
        update_Ws[Ws[i]] = ((1 - alpha) * Ws[i] + alpha * update_Ws[Ws[i]]) \
            * selector[i] + Ws[i] * (1 - selector[i])

    output = {'train': sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                   updates=adam.updates),
              'update_var': sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss,
                                        updates={var_x: update_varx}),
              'update_vs': sj.function(alpha, ll, SIGNS, X, m0, m1, m2,
                                       outputs=mean_loss, updates=update_vs),
              'loss': sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss),
              'update_Ws': sj.function(alpha, ll, SIGNS, X, m0, m1, m2,
                                       outputs=mean_loss, updates=update_Ws),
              'signs2Ab': sj.function(signs, outputs=[Aqs[-1][0], bqs[-1][0]]),
              'signs2ineq': sj.function(signs, outputs=q_inequalities),
              'g': sj.function(x, outputs=maps[-1]),
              'input2all': sj.function(x, outputs=[maps[-1], Axs[-1][0],
                                                   bxs[-1][0], x_inequalities,
                                                   xsigns]),
              'get_nll': sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss),
              'assign': sj.function(*Ws_ph, *vs_ph, var_x_ph,
                                    updates=dict(zip(Ws + vs + [var_x],
                                                     Ws_ph + vs_ph + [var_x_ph]))),
              'varx': sj.function(outputs=var_x),
              'prior': sj.function(outputs=prior),
              'varz': sj.function(outputs=var_z),
              'params': params,
              # 'probed': sj.function(SIGNS, X, m0, m1, m2, outputs=probed),
              'input2signs': sj.function(x, outputs=xsigns),
              'S': Ds[0], 'D': Ds[-1], 'R': R, 'model': 'EM', 'L': len(Ds) - 1,
              'kwargs': {'batch_size': batch_size, 'Ds': Ds, 'seed': seed,
                         'leakiness': leakiness, 'lr': lr, 'scaler': scaler}}

    def sample(n):
        samples = []
        for i in range(n):
            samples.append(output['g'](np.random.randn(Ds[0])))
        return np.array(samples)

    output['sample'] = sample

    return output
def ExponentialMovingAverage(
    value,
    alpha,
    init=None,
    decay_min=False,
    debias=True,
    name="ExponentialMovingAverage",
):
    """exponential moving average of a given value

    This method allows to obtain an EMA of a given variable (or any Tensor)
    with internal state automatically updating its values as new samples are
    observed and the internal updates are applied as part of a function.

    At each iteration the new value is given by

    .. math::

        v(0) = value(0) \\text{ (or init)}

        v(t) = v(t-1) * alpha + value(t) * (1 - alpha)

    Args
    ----

    value: Tensor-like
        the value to use for the EMA

    alpha: scalar
        the decay of the EMA

    init: Tensor-like (same shape as value), optional
        the initialization of the EMA, if not given uses the value
        allowing for unbiased estimate

    decay_min: bool
        at early stages, clip the decay to avoid erratic behaviors

    Returns
    -------

    ema: Tensor-like
        the current (latest) value of the EMA incorporating information
        of the latest observation of value

    fixed_ema: Tensor-like
        the value of the EMA of the previous pass. This is useful if one
        wants to keep the estimate of the EMA fixed for new observations:
        simply do not apply any more updates (using a new function) and use
        this fixed variable during testing (while ema will keep using the
        latest observed value)

    Example
    -------

    .. doctest::

        >>> import symjax
        >>> import numpy as np
        >>> np.random.seed(0)
        >>> symjax.current_graph().reset()
        >>> # suppose we want to do an EMA of a vector user-input
        >>> input = symjax.tensor.Placeholder((2,), 'float32')
        >>> ema, var = symjax.nn.schedules.ExponentialMovingAverage(input, 0.9)
        >>> # in the background, symjax automatically records the needed updates
        >>> print(symjax.get_updates())
        {Variable(name=EMA, shape=(2,), dtype=float32, trainable=False, scope=/ExponentialMovingAverage/): Op(name=where, fn=where, shape=(2,), dtype=float32, scope=/ExponentialMovingAverage/), Variable(name=first_step, shape=(), dtype=bool, trainable=False, scope=/ExponentialMovingAverage/): False}
        >>> # example of use:
        >>> f = symjax.function(input, outputs=ema, updates=symjax.get_updates())
        >>> for i in range(25):
        ...     print(f(np.ones(2) + np.random.randn(2) * 0.3))
        [1.5292157 1.1200472]
        [1.5056562 1.1752692]
        [1.5111173 1.1284239]
        [1.4885082 1.1110408]
        [1.4365609 1.1122546]
        [1.3972261 1.1446574]
        [1.3803346 1.1338419]
        [1.355617 1.1304679]
        [1.3648777 1.1112664]
        [1.3377819 1.0745169]
        [1.227414 1.0866737]
        [1.2306056 1.0557414]
        [1.2756376 1.0065362]
        [1.2494465 1.000267 ]
        [1.2704852 1.0443211]
        [1.2480851 1.0512339]
        [1.196643 0.9866866]
        [1.1665413 0.9927084]
        [1.186796 1.029509]
        [1.1564965 1.017489 ]
        [1.1093903 0.97313946]
        [1.0472631 1.0343488]
        [1.0272473 1.0177717]
        [0.9869387 1.0393193]
        [0.93982786 1.029005 ]
    """
    with Scope(name):

        init = init if init is not None else T.zeros_like(value, detach=True)

        num_steps = T.Variable(0, trainable=False, name="num_steps",
                               dtype="int32")
        var = T.Variable(init, trainable=False, dtype="float32", name="EMA")

        if decay_min:
            decay = T.minimum(alpha, (1.0 + num_steps) / (10.0 + num_steps))
        else:
            decay = alpha

        ema = decay * var + (1 - decay) * value
        var_update = T.where(T.equal(num_steps, 0), init, ema)

        current_graph().add_updates({var: ema, num_steps: num_steps + 1})

        if debias:
            debiased_ema = ema_debias(ema, init, decay, num_steps + 1)
            debiased_var = T.Variable(
                init, trainable=False, dtype="float32", name="debiased_EMA"
            )
            current_graph().add_updates({debiased_var: debiased_ema})

    if debias:
        return debiased_ema, debiased_var
    else:
        return ema, var