def create_vae(batch_size, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1):
    """Build a VAE graph and return its compiled functions in a dict.

    The encoder is ``encoder(x, Ds[0])``; its last output is split
    column-wise into a mean (first ``Ds[0]`` columns) and a log-variance
    (remaining columns).  The decoder is an MLP whose weights come from
    ``init_weights(Ds, seed, scaler)`` and whose hidden activations are
    multiplied by ``relu_mask(., leakiness)``.  The decoder is built twice on
    the same variables: once on the reparameterised latent (used by the
    loss) and once on a free latent placeholder (used for generation).

    Relies on the module-level names ``cov_W`` and ``cov_b`` for the weight
    and bias penalty.

    Args:
        batch_size: leading dimension of every placeholder.
        Ds: layer widths; ``Ds[0]`` is the latent width, ``Ds[-1]`` the
            data width.
        seed: forwarded to ``init_weights``.
        leakiness: forwarded to ``relu_mask``.
        lr: learning rate given to Adam.
        scaler: forwarded to ``init_weights``.

    Returns:
        dict with the keys ``'train'``, ``'g'``, ``'params'``, ``'model'``,
        ``'varx'``, ``'kwargs'``, ``'prior'`` and ``'sample'``.
    """
    x = T.Placeholder([batch_size, Ds[-1]], 'float32')

    # ENCODER
    enc = encoder(x, Ds[0])
    mu = enc[-1][:, :Ds[0]]
    logvar = enc[-1][:, Ds[0]:]
    var = T.exp(logvar)
    # reparameterisation: z = mu + std * noise
    z = mu + T.exp(0.5 * logvar) * T.random.randn((batch_size, Ds[0]))
    z_ph = T.Placeholder((batch_size, Ds[0]), 'float32')

    # DECODER
    Ws, bs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w) for w in Ws]
    bs = [T.Variable(b) for b in bs]
    logvar_x = T.Variable(T.zeros(1), name='logvar_x')
    var_x = T.exp(logvar_x)

    # h follows the sampled latent, h_ph follows the free placeholder; both
    # lists share the same weights and biases.
    h, h_ph = [z], [z_ph]
    for w, b in zip(Ws[:-1], bs[:-1]):
        h.append(T.matmul(h[-1], w.transpose()) + b)
        h.append(h[-1] * relu_mask(h[-1], leakiness))
        h_ph.append(T.matmul(h_ph[-1], w.transpose()) + b)
        h_ph.append(h_ph[-1] * relu_mask(h_ph[-1], leakiness))
    h.append(T.matmul(h[-1], Ws[-1].transpose()) + bs[-1])
    h_ph.append(T.matmul(h_ph[-1], Ws[-1].transpose()) + bs[-1])

    # the last bias is excluded from the penalty
    prior = sum([T.mean(w**2) for w in Ws], 0.) / cov_W \
        + sum([T.mean(v**2) for v in bs[:-1]], 0.) / cov_b

    # `kl` holds the NEGATIVE KL divergence to a standard normal, so that
    # (px + kl) is the quantity being maximised.
    kl = 0.5 * (1 + logvar - var - mu ** 2).sum(1)
    px = - 0.5 * (logvar_x + ((x - h[-1])**2 / var_x)).sum(1)
    loss = - (px + kl).mean() + prior

    variables = Ws + bs + sj.layers.get_variables(enc) + [logvar_x]
    opti = sj.optimizers.Adam(loss, lr, params=variables)

    train = sj.function(x, outputs=loss, updates=opti.updates)
    g = sj.function(z_ph, outputs=h_ph[-1])
    params = sj.function(
        outputs=Ws + bs + [T.exp(logvar_x) * T.ones(Ds[-1])])
    get_varx = sj.function(outputs=var_x)
    prior_fn = sj.function(outputs=prior)

    output = {'train': train, 'g': g, 'params': params}
    output['model'] = 'VAE'
    output['varx'] = get_varx
    # NOTE(review): 'prior' is kept inside 'kwargs' for backward
    # compatibility, although it is not an argument of this function.
    output['kwargs'] = {'batch_size': batch_size, 'Ds': Ds, 'seed': seed,
                        'leakiness': leakiness, 'lr': lr, 'scaler': scaler,
                        'prior': prior_fn}
    # Same compiled function, exposed at the top level as create_fns does.
    output['prior'] = prior_fn

    def sample(n):
        """Return exactly ``n`` generated rows.

        ``g`` only accepts full batches, so enough batches are generated to
        cover ``n`` and the surplus rows are dropped.
        """
        if n <= 0:
            # assumes g returns float32 rows of width Ds[-1] -- TODO confirm
            return np.zeros((0, Ds[-1]), dtype='float32')
        n_batches = -(-n // batch_size)  # ceil(n / batch_size)
        samples = [g(np.random.randn(batch_size, Ds[0]))
                   for _ in range(n_batches)]
        return np.concatenate(samples)[:n]

    output['sample'] = sample
    return output
def normalize(x, axis=-1, mean=None, variance=None, epsilon=1e-5):
    """Standardise ``x``: subtract the mean and scale by ``rsqrt(var + eps)``.

    ``mean`` and ``variance`` are computed over ``axis`` (keeping the reduced
    dimensions) whenever they are not supplied by the caller.
    """
    if mean is None:
        mean = T.mean(x, axis, keepdims=True)

    if variance is None:
        # E[x^2] - E[x]^2: traditionally regarded as less accurate than
        # mean((x - mean(x))**2), but it may be faster and, for typical
        # activation distributions in low-precision arithmetic, even more
        # accurate inside neural-network normalization layers.
        second_moment = T.mean(T.square(x), axis, keepdims=True)
        variance = second_moment - T.square(mean)

    centered = x - mean
    inv_std = T.rsqrt(variance + epsilon)
    return centered * inv_std
def forward(self, input, deterministic=None):
    """Normalise ``input`` with batch or moving-average statistics.

    ``deterministic`` (falling back to ``self.deterministic`` when ``None``)
    is cast to float32 and used as a 0/1 blend weight: 0 selects the
    statistics of the current ``input``, 1 selects the exponential moving
    averages.  The moving averages and their updates are created and
    registered only while ``self.updates`` is still empty.
    """
    if deterministic is None:
        deterministic = self.deterministic
    use_running = T.cast(deterministic, 'float32')

    self.mean = T.mean(input, self.axis, keepdims=True)
    self.var = T.var(input, self.axis, keepdims=True)

    if not self.updates.keys():
        self.avgmean, mean_update, step = T.ExponentialMovingAverage(
            self.mean, self.beta1)
        # the variance average reuses the step counter returned above and
        # starts from ones
        var_init = numpy.ones(self.var.shape).astype('float32')
        self.avgvar, var_update, step = T.ExponentialMovingAverage(
            self.var, self.beta2, step=step, init=var_init)
        self.add_variable(self.avgmean)
        self.add_variable(self.avgvar)
        self.add_update(mean_update)
        self.add_update(var_update)

    self.usemean = (self.mean * (1 - use_running)
                    + self.avgmean * use_running)
    self.usevar = (self.var * (1 - use_running)
                   + self.avgvar * use_running)

    scaled = self.W * (input - self.usemean)
    denominator = T.sqrt(self.usevar) + self.const
    return scaled / denominator + self.b
def build_net(self, Q):
    """Build the Q-learning graph and compile the train / evaluation functions.

    ``Q`` is called twice, once per scope, as ``Q(inputs, self.n_actions)``:
    on the current states (scope ``"eval_net"``) and on the next states
    (scope ``"test_set"``).  The regression target is
    ``reward + self.reward_decay * max_a Q(next_state)`` with gradients
    stopped, and the loss is the mean squared error against the value of the
    action that was taken.

    Sets:
        self.train: ``(state, action, reward, next_state)`` -> applies the
            updates returned by ``symjax.get_updates()``; no outputs.
        self.q_eval: ``state`` -> Q-values from the ``"eval_net"`` scope.
    """
    # ------------------ all inputs ------------------------
    state = T.Placeholder([self.batch_size, self.n_states],
                          "float32", name="s")
    next_state = T.Placeholder([self.batch_size, self.n_states],
                               "float32", name="s_")
    reward = T.Placeholder([self.batch_size], "float32", name="r")
    action = T.Placeholder([self.batch_size], "int32", name="a")

    with symjax.Scope("eval_net"):
        q_eval = Q(state, self.n_actions)
    with symjax.Scope("test_set"):
        q_next = Q(next_state, self.n_actions)

    # the target is treated as a constant by the optimiser
    q_target = reward + self.reward_decay * q_next.max(1)
    q_target = T.stop_gradient(q_target)

    # pick, for every row, the Q-value of the action that was taken
    q_eval_wrt_a = T.take_along_axis(
        q_eval, action.reshape((-1, 1)), 1).squeeze(1)
    loss = T.mean((q_target - q_eval_wrt_a)**2)

    nn.optimizers.Adam(loss, self.lr)

    self.train = symjax.function(state, action, reward, next_state,
                                 updates=symjax.get_updates())
    self.q_eval = symjax.function(state, outputs=q_eval)
def forward(
    self,
    input,
    axis,
    deterministic,
    const=1e-4,
    beta1=0.9,
    beta2=0.9,
    W=T.ones,
    b=T.zeros,
    trainable_W=True,
    trainable_b=True,
):
    """Normalise ``input`` over every axis NOT listed in ``axis``.

    Args:
        input: tensor to normalise; ``input.ndim`` and ``input.shape`` are
            read to build the parameter shape.
        axis: iterable of axes that keep their own scale/shift entries; all
            other axes are reduced.  Negative entries count from the end.
        deterministic: condition given to ``T.where``; selects the
            moving-average statistics when true, the current ones otherwise.
        const: added to the standard deviation before inversion.
        beta1: coefficient passed to the mean's moving average.
        beta2: coefficient passed to the inverse-std's moving average.
        W: initialiser for the scale, forwarded to ``create_variable``.
        b: initialiser for the shift, forwarded to ``create_variable``.
        trainable_W: forwarded to ``create_variable`` for ``W``.
        trainable_b: forwarded to ``create_variable`` for ``b``.

    Returns:
        ``W * (input - mean) * inv_std + b``, where a missing ``W`` / ``b``
        (``None``) falls back to ``1.0`` / ``0.0``.
    """
    self.beta1 = beta1
    self.beta2 = beta2
    self.const = const
    self.axis = axis
    self.deterministic = deterministic

    # Map negative axes to their positive index: without this they never
    # match range(input.ndim) and every axis would be reduced silently.
    ndim = input.ndim
    kept_axes = [a + ndim if a < 0 else a for a in axis]

    parameter_shape = [
        input.shape[i] if i in kept_axes else 1 for i in range(ndim)
    ]
    reduce_axes = [i for i in range(ndim) if i not in kept_axes]

    self.create_variable("W", W, parameter_shape, trainable=trainable_W)
    self.create_variable("b", b, parameter_shape, trainable=trainable_b)

    input_mean = T.mean(input, reduce_axes, keepdims=True)
    input_inv_std = 1 / (T.std(input, reduce_axes, keepdims=True) + const)

    # element [1] of the returned tuple is used as the running statistic
    # (presumably the averaged value -- confirm against `schedules`)
    self.avg_mean = schedules.ExponentialMovingAverage(input_mean, beta1)[1]
    self.avg_inv_std = schedules.ExponentialMovingAverage(
        input_inv_std, beta2)[1]

    use_mean = T.where(deterministic, self.avg_mean, input_mean)
    use_inv_std = T.where(deterministic, self.avg_inv_std, input_inv_std)

    # Explicit None checks: `self.W or 1.0` would rely on the truth value of
    # a tensor, and was inconsistent with how `b` is handled.
    W = self.W if self.W is not None else 1.0
    b = self.b if self.b is not None else 0.0
    return W * (input - use_mean) * use_inv_std + b
def fn(window):
    """Reduce ``window`` to its mean using ``T.mean``.

    NOTE(review): the previous comment described a callback whose first
    input is the current loop index, followed by the ordered sequences and
    non-sequences; this signature only receives ``window`` -- confirm
    against the caller.
    """
    window_mean = T.mean(window)
    return window_mean
def create_fns(batch_size, R, Ds, seed, leakiness=0.1, lr=0.0002, scaler=1, var_x=1):
    """Build the symbolic graph of the 'EM' model and return its functions.

    The generator ``g`` maps a vector of size ``Ds[0]`` to a vector of size
    ``Ds[-1]`` through weights from ``init_weights(Ds, seed, scaler)`` and
    hidden activations multiplied by ``relu_mask(., leakiness)``.  For ``R``
    rows of sign patterns (``SIGNS``) the per-pattern affine maps are obtained
    from ``get_Abs`` and combined with the moment placeholders ``m0``, ``m1``
    and ``m2`` to form the loss and the closed-form parameter updates.

    Relies on the module-level names ``cov_W`` and ``cov_b``.

    Args:
        batch_size: leading dimension of ``X``, ``m0``, ``m1`` and ``m2``.
        R: number of rows of ``SIGNS`` (second dimension of the moments).
        Ds: list of layer widths (it is concatenated with ``[0]``, so it
            must be a list).
        seed: forwarded to ``init_weights``.
        leakiness: forwarded to ``relu_mask``.
        lr: only stored in the returned ``'kwargs'``.
            NOTE(review): the optimiser below uses a hard-coded 0.001.
        scaler: forwarded to ``init_weights``.
        var_x: initial value multiplying ``T.ones(Ds[-1])``.

    Returns:
        dict of compiled functions and metadata (see ``output`` below), plus
        a ``'sample'`` callable.
    """
    alpha = T.Placeholder((1,), 'float32')
    x = T.Placeholder((Ds[0],), 'float32')
    X = T.Placeholder((batch_size, Ds[-1]), 'float32')
    # one sign per hidden unit, all hidden layers concatenated
    signs = T.Placeholder((np.sum(Ds[1:-1]),), 'float32')
    SIGNS = T.Placeholder((R, np.sum(Ds[1:-1])), 'float32')
    # presumably zeroth/first/second moments per sample and per sign
    # pattern -- TODO confirm against the caller
    m0 = T.Placeholder((batch_size, R), 'float32')
    m1 = T.Placeholder((batch_size, R, Ds[0]), 'float32')
    m2 = T.Placeholder((batch_size, R, Ds[0], Ds[0]), 'float32')
    Ws, vs = init_weights(Ds, seed, scaler)
    Ws = [T.Variable(w, name='W' + str(l)) for l, w in enumerate(Ws)]
    vs = [T.Variable(v, name='v' + str(l)) for l, v in enumerate(vs)]
    # from here on `var_x` is the variable, not the scalar argument
    var_x = T.Variable(T.ones(Ds[-1]) * var_x)
    var_z = T.Variable(T.ones(Ds[0]))
    # create the placeholders
    # (used by the 'assign' function to overwrite Ws, vs and var_x)
    Ws_ph = [T.Placeholder(w.shape, w.dtype) for w in Ws]
    vs_ph = [T.Placeholder(v.shape, v.dtype) for v in vs]
    var_x_ph = T.Placeholder(var_x.shape, var_x.dtype)
    ############################################################################
    # Compute the output of g(x)
    ############################################################################
    maps = [x]
    xsigns = []
    masks = []
    for w, v in zip(Ws[:-1], vs[:-1]):
        pre_activation = T.matmul(w, maps[-1]) + v
        xsigns.append(T.sign(pre_activation))
        masks.append(relu_mask(pre_activation, leakiness))
        maps.append(pre_activation * masks[-1])
    xsigns = T.concatenate(xsigns)
    # last layer is affine only
    maps.append(T.matmul(Ws[-1], maps[-1]) + vs[-1])
    ############################################################################
    # compute the masks and then the per layer affine mappings
    ############################################################################
    # offsets used to slice the concatenated sign vectors layer by layer
    cumulative_units = np.cumsum([0] + Ds[1:])
    xqs = relu_mask([xsigns[None, cumulative_units[i]:cumulative_units[i + 1]] for i in range(len(Ds) - 2)], leakiness)
    qs = relu_mask([signs[None, cumulative_units[i]:cumulative_units[i + 1]] for i in range(len(Ds) - 2)], leakiness)
    Qs = relu_mask([SIGNS[:, cumulative_units[i]:cumulative_units[i + 1]] for i in range(len(Ds) - 2)], leakiness)
    # one (A, b) list per mask source: the input x, a single sign vector,
    # and the R sign vectors
    Axs, bxs = get_Abs(Ws, vs, xqs)
    Aqs, bqs = get_Abs(Ws, vs, qs)
    AQs, bQs = get_Abs(Ws, vs, Qs)
    all_bxs = T.hstack(bxs[:-1]).transpose()
    all_Axs = T.hstack(Axs[:-1])[0]
    all_bqs = T.hstack(bqs[:-1]).transpose()
    all_Aqs = T.hstack(Aqs[:-1])[0]
    # rows [A | b] of every hidden unit, multiplied by that unit's sign
    x_inequalities = T.hstack([all_Axs, all_bxs]) * xsigns[:, None]
    q_inequalities = T.hstack([all_Aqs, all_bqs]) * signs[:, None]
    ############################################################################
    # loss (E-step NLL)
    ############################################################################
    Bm0 = T.einsum('nd,Nn->Nd', bQs[-1], m0)
    B2m0 = T.einsum('nd,Nn->Nd', bQs[-1] ** 2, m0)
    Am1 = T.einsum('nds,Nns->Nd', AQs[-1], m1)
    ABm1 = T.einsum('nds,nd,Nns->Nd', AQs[-1], bQs[-1], m1)
    Am2ATdiag = T.diagonal(T.einsum('nds,Nnsc,npc->Ndp', AQs[-1], m2, AQs[-1]), axis1=1, axis2=2)
    xAm1Bm0 = X * (Am1 + Bm0)
    M2diag = T.diagonal(m2.sum(1), axis1=1, axis2=2)
    # the last bias is excluded from the penalty
    prior = sum([T.mean(w**2) for w in Ws], 0.) / cov_W\
        + sum([T.mean(v**2) for v in vs[:-1]], 0.) / cov_b
    loss = - 0.5 * (T.log(var_x).sum() + T.log(var_z).sum()\
        + (M2diag / var_z).sum(1).mean() + ((X ** 2 - 2 * xAm1Bm0 + B2m0\
        + Am2ATdiag + 2 * ABm1) / var_x).sum(1).mean())
    mean_loss = - (loss + 0.5 * prior)
    # NOTE(review): named `adam` but this is SGD with a fixed 0.001 step;
    # the `lr` argument is not used here.
    adam = sj.optimizers.SGD(mean_loss, 0.001, params=Ws + vs)
    ############################################################################
    # update of var_x
    ############################################################################
    update_varx = (X ** 2 - 2 * xAm1Bm0 + B2m0 + Am2ATdiag + 2 * ABm1).mean()\
        * T.ones(Ds[-1])
    # NOTE(review): update_varz is computed but not used below
    update_varz = M2diag.mean() * T.ones(Ds[0])
    ############################################################################
    # update for biases IT IS DONE FOR ISOTROPIC COVARIANCE MATRIX
    ############################################################################
    FQ = get_forward(Ws, Qs)
    update_vs = {}
    for i in range(len(vs)):
        if i < len(vs) - 1:
            # now we forward each bias to the x-space except the ith
            separated_bs = bQs[-1] - T.einsum('nds,s->nd', FQ[i], vs[i])
            # compute the residual and apply sigma
            residual = (X[:, None, :] - separated_bs) * m0[:, :, None]\
                - T.einsum('nds,Nns->Nnd', AQs[-1], m1)
            back_error = T.einsum('nds,nd->s', FQ[i], residual.mean(0))
            # regularised linear system solved for the new bias
            whiten = T.einsum('ndc,nds,n->cs', FQ[i] , FQ[i], m0.mean(0))\
                + T.eye(back_error.shape[0]) / (Ds[i] * cov_b)
            update_vs[vs[i]] = T.linalg.solve(whiten, back_error)
        else:
            # last bias: batch mean of the residual with the current bias
            # added back
            back_error = (X - (Am1 + Bm0) + vs[-1])
            update_vs[vs[i]] = back_error.mean(0)
    ############################################################################
    # update for slopes IT IS DONE FOR ISOTROPIC COVARIANCE MATRIX
    ############################################################################
    update_Ws = {}
    for i in range(len(Ws)):
        U = T.einsum('nds,ndc->nsc', FQ[i], FQ[i])
        if i == 0:
            V = m2.mean(0)
        else:
            V1 = T.einsum('nd,nq,Nn->ndq', bQs[i-1], bQs[i-1], m0)
            V2 = T.einsum('nds,nqc,Nnsc->ndq', AQs[i-1], AQs[i-1], m2)
            V3 = T.einsum('nds,nq,Nns->ndq', AQs[i-1], bQs[i-1], m1)
            Q = T.einsum('nd,nq->ndq', Qs[i - 1], Qs[i - 1])
            V = Q * (V1 + V2 + V3 + V3.transpose((0, 2, 1))) / batch_size
        # sum over the sign patterns of kron(U[n], V[n]), then regularise
        whiten = T.stack([T.kron(U[n], V[n]) for n in range(V.shape[0])]).sum(0)
        whiten = whiten + T.eye(whiten.shape[-1]) / (Ds[i]*Ds[i+1]*cov_W)
        # compute the residual (bottom up)
        if i == len(Ws) - 1:
            bottom_up = (X[:, None, :] - vs[-1])
        else:
            if i == 0:
                residual = (X[:, None, :] - bQs[-1])
            else:
                residual = (X[:, None, :] - bQs[-1]\
                    + T.einsum('nds,ns->nd', FQ[i - 1], bQs[i-1]))
            bottom_up = T.einsum('ndc,Nnd->Nnc', FQ[i], residual)
        # compute the top down vector
        if i == 0:
            top_down = m1
        else:
            top_down = Qs[i - 1] * (T.einsum('nds,Nns->Nnd', AQs[i - 1], m1) +\
                T.einsum('nd,Nn->Nnd', bQs[i - 1], m0))
        vector = T.einsum('Nnc,Nns->cs', bottom_up, top_down) / batch_size
        # NOTE(review): `condition` is computed but not used below
        condition = T.diagonal(whiten)
        # solve the flattened system, then restore the weight's shape
        update_W = T.linalg.solve(whiten, vector.reshape(-1)).reshape(Ws[i].shape)
        update_Ws[Ws[i]] = update_W
    ############################################################################
    # create the io functions
    ############################################################################
    params = sj.function(outputs = Ws + vs + [var_x])
    # `ll` selects which layer is updated: through the one-hot selector only
    # that layer moves (blended by alpha), the others keep their value
    ll = T.Placeholder((), 'int32')
    selector = T.one_hot(ll, len(vs))
    for i in range(len(vs)):
        update_vs[vs[i]] = ((1 - alpha) * vs[i] + alpha * update_vs[vs[i]])\
            * selector[i] + vs[i] * (1 - selector[i])
    for i in range(len(Ws)):
        update_Ws[Ws[i]] = ((1 - alpha) * Ws[i] + alpha * update_Ws[Ws[i]])\
            * selector[i] + Ws[i] * (1 - selector[i])
    # NOTE(review): 'loss' and 'get_nll' are built from identical arguments
    output = {'train':sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss, updates=adam.updates),
              'update_var':sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss, updates = {var_x: update_varx}),
              'update_vs':sj.function(alpha, ll, SIGNS, X, m0, m1, m2, outputs=mean_loss, updates = update_vs),
              'loss':sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss),
              'update_Ws':sj.function(alpha, ll, SIGNS, X, m0, m1, m2, outputs=mean_loss, updates = update_Ws),
              'signs2Ab': sj.function(signs, outputs=[Aqs[-1][0], bqs[-1][0]]),
              'signs2ineq': sj.function(signs, outputs=q_inequalities),
              'g': sj.function(x, outputs=maps[-1]),
              'input2all': sj.function(x, outputs=[maps[-1], Axs[-1][0], bxs[-1][0], x_inequalities, xsigns]),
              'get_nll': sj.function(SIGNS, X, m0, m1, m2, outputs=mean_loss),
              'assign': sj.function(*Ws_ph, *vs_ph, var_x_ph, updates=dict(zip(Ws + vs + [var_x], Ws_ph + vs_ph + [var_x_ph]))),
              'varx': sj.function(outputs=var_x),
              'prior': sj.function(outputs=prior),
              'varz': sj.function(outputs=var_z),
              'params': params,
              # 'probed' : sj.function(SIGNS, X, m0, m1, m2, outputs=probed),
              'input2signs': sj.function(x, outputs=xsigns),
              'S' : Ds[0], 'D': Ds[-1], 'R': R, 'model': 'EM', 'L':len(Ds)-1,
              'kwargs': {'batch_size': batch_size, 'Ds':Ds, 'seed':seed, 'leakiness':leakiness, 'lr':lr, 'scaler':scaler}}
    def sample(n):
        # `g` takes one latent vector at a time, hence the Python loop
        samples = []
        for i in range(n):
            samples.append(output['g'](np.random.randn(Ds[0])))
        return np.array(samples)
    output['sample'] = sample
    return output
def __init__(
    self,
    state_dim,
    action_dim,
    lr,
    gamma,
    K_epochs,
    eps_clip,
    actor,
    critic,
    batch_size,
    continuous=True,
):
    """Build the policy/value graph and compile ``_act`` and ``learn``.

    Args:
        state_dim: tuple appended to ``(batch_size,)`` for the state
            placeholder.
        action_dim: width of the action placeholder (continuous case only).
        lr: learning rate given to Adam.
        gamma: stored on the instance.
        K_epochs: stored on the instance.
        eps_clip: the probability ratio is clipped from above at
            ``1 + eps_clip``.
        actor: callable mapping the state placeholder to logits.  In the
            continuous case the first half of the columns is the mean
            (after ``tanh``) and the second half the log-std.
        critic: callable mapping the state placeholder to state values.
        batch_size: leading dimension of every placeholder.
        continuous: selects ``MultivariateNormal`` (True) or
            ``Categorical`` (False) as the action distribution.

    Sets:
        self._act: ``state`` -> ``[sample, log_prob(sample)]``.
        self.learn: ``(state, given_action, reward, old_action_logprobs)``
            -> mean loss, applying the updates from ``symjax.get_updates()``.
    """
    self.lr = lr
    self.gamma = gamma
    self.eps_clip = eps_clip
    self.K_epochs = K_epochs
    self.batch_size = batch_size

    state = T.Placeholder((batch_size, ) + state_dim, "float32")
    reward = T.Placeholder((batch_size, ), "float32")
    old_action_logprobs = T.Placeholder((batch_size, ), "float32")

    logits = actor(state)

    if not continuous:
        given_action = T.Placeholder((batch_size, ), "int32")
        dist = Categorical(logits=logits)
    else:
        half = logits.shape[1] // 2
        mean = T.tanh(logits[:, :half])
        std = T.exp(logits[:, half:])
        given_action = T.Placeholder((batch_size, action_dim), "float32")
        dist = MultivariateNormal(mean=mean, diag_std=std)

    sample = dist.sample()
    sample_logprobs = dist.log_prob(sample)

    self._act = symjax.function(state, outputs=[sample, sample_logprobs])

    given_action_logprobs = dist.log_prob(given_action)

    # Finding the ratio (pi_theta / pi_theta__old):
    # It must be evaluated on the action that was actually taken.  Using the
    # log-probability of a fresh sample (as before) left `given_action`
    # disconnected from the loss.
    ratios = T.exp(given_action_logprobs - old_action_logprobs)
    ratios = T.clip(ratios, None, 1 + self.eps_clip)

    state_value = critic(state)
    # NOTE(review): assumes critic(state) has the same shape as `reward`,
    # i.e. (batch_size,) -- confirm against the critic.
    advantages = reward - T.stop_gradient(state_value)

    # the entropy term is kept in the graph with a zero weight
    loss = (-T.mean(ratios * advantages)
            + 0.5 * T.mean((state_value - reward)**2)
            - 0.0 * dist.entropy().mean())

    nn.optimizers.Adam(loss, self.lr)

    self.learn = symjax.function(
        state,
        given_action,
        reward,
        old_action_logprobs,
        outputs=T.mean(loss),
        updates=symjax.get_updates(),
    )
def __init__(
    self,
    env_fn,
    actor,
    critic,
    gamma=0.99,
    tau=0.01,
    lr=1e-3,
    batch_size=32,
    epsilon=0.1,
    epsilon_decay=1 / 1000,
    min_epsilon=0.01,
    reward=None,
):
    """Build actor/critic networks, their target copies and the I/O functions.

    Args:
        env_fn: either a zero-argument callable returning the environment
            or the environment itself.
        actor: callable ``actor(state)``; its output is multiplied by
            ``action_max``.
        critic: callable ``critic(state, action)``.
        gamma: stored on the instance.
        tau: mixing coefficient of the target update
            ``t <- tau * e + (1 - tau) * t``.
        lr: learning rate given to Adam.
        batch_size: leading dimension of the training placeholders.
        epsilon, epsilon_decay, min_epsilon: accepted but not used in this
            constructor.
        reward: stored on the instance.

    Sets ``update``, ``get_gradients``, ``actor_predict``,
    ``actor_predict_target``, ``critic_predict``, ``critic_predict_target``,
    ``update_target`` and ``act`` on the instance.
    """
    # The body used to read an undefined name `env` while `env_fn` was never
    # used; resolve the environment from the argument instead.
    env = env_fn() if callable(env_fn) else env_fn

    # comment out this line if you don't want to record a video of the agent
    # if save_folder is not None:
    # test_env = gym.wrappers.Monitor(test_env)

    # get size of state space and action space
    num_states = env.observation_space.shape[0]
    continuous = isinstance(env.action_space, gym.spaces.box.Box)

    if continuous:
        num_actions = env.action_space.shape[0]
        action_max = env.action_space.high[0]
    else:
        num_actions = env.action_space.n
        action_max = 1

    self.batch_size = batch_size
    self.num_states = num_states
    self.num_actions = num_actions
    self.state_dim = (batch_size, num_states)
    self.action_dim = (batch_size, num_actions)
    self.gamma = gamma
    self.continuous = continuous
    self.observ_min = np.clip(env.observation_space.low, -20, 20)
    self.observ_max = np.clip(env.observation_space.high, -20, 20)
    self.env = env
    self.reward = reward

    # state
    state = T.Placeholder((batch_size, num_states), "float32")
    gradients = T.Placeholder((batch_size, num_actions), "float32")
    action = T.Placeholder((batch_size, num_actions), "float32")
    target = T.Placeholder((batch_size, 1), "float32")

    with symjax.Scope("actor_critic"):
        scaled_out = action_max * actor(state)
        Q = critic(state, action)

    # `gradients` is fed from outside (see get_gradients below)
    a_loss = -T.sum(gradients * scaled_out)
    q_loss = T.mean((Q - target)**2)

    nn.optimizers.Adam(a_loss + q_loss, lr)

    self.update = symjax.function(
        state,
        action,
        target,
        gradients,
        outputs=[a_loss, q_loss],
        updates=symjax.get_updates(),
    )

    g = symjax.gradients(T.mean(Q), [action])[0]
    self.get_gradients = symjax.function(state, action, outputs=g)

    # also create the target variants
    with symjax.Scope("actor_critic_target"):
        scaled_out_target = action_max * actor(state)
        Q_target = critic(state, action)

    self.actor_predict = symjax.function(state, outputs=scaled_out)
    self.actor_predict_target = symjax.function(state,
                                                outputs=scaled_out_target)
    self.critic_predict = symjax.function(state, action, outputs=Q)
    self.critic_predict_target = symjax.function(state, action,
                                                 outputs=Q_target)

    # pair target and online variables in the order the scope queries
    # return them
    t_params = symjax.get_variables(scope="/actor_critic_target/*")
    params = symjax.get_variables(scope="/actor_critic/*")
    replacement = {
        t: tau * e + (1 - tau) * t for t, e in zip(t_params, params)
    }
    self.update_target = symjax.function(updates=replacement)

    single_state = T.Placeholder((1, num_states), "float32")
    act_out = scaled_out
    if not continuous:
        # NOTE(review): this branch used the undefined name `clean_action`;
        # taking the argmax of the actor output is assumed to be the
        # intended discrete action -- confirm.
        act_out = scaled_out.argmax(-1)
    self.act = symjax.function(
        single_state,
        outputs=act_out.clone({state: single_state})[0],
    )
def __init__(
    self,
    state_shape,
    actions_shape,
    batch_size,
    actor,
    critic,
    lr=1e-3,
    K_epochs=80,
    eps_clip=0.2,
    gamma=0.99,
    entropy_beta=0.01,
):
    """Build the clipped-surrogate training graph and compile ``_train``.

    Args:
        state_shape: tuple appended to ``(batch_size,)`` for the states.
        actions_shape: tuple appended to ``(batch_size,)`` for the actions.
        batch_size: leading dimension of every placeholder.
        actor: callable invoked twice as
            ``actor(states, distribution="gaussian")`` to create the target
            and the trained policy.
        critic: callable invoked as ``critic(states)``.
        lr: learning rate given to Adam.
        K_epochs: stored on the instance.
        eps_clip: half-width of the ratio clipping interval around 1.
        gamma: stored on the instance.
        entropy_beta: currently unused; the entropy bonus is commented out
            of the loss.

    Sets:
        self.target_actor, self.actor, self.critic: the built networks.
        self._train: ``(states, actions, rewards, advantages)`` ->
            ``[actor_loss, critic_loss, clipfrac, approx_kl]``, applying the
            updates found under the ``"*optimizer"`` scope.
    """
    self.gamma = gamma
    self.lr = lr
    self.eps_clip = eps_clip
    self.K_epochs = K_epochs
    self.batch_size = batch_size

    states = T.Placeholder((batch_size, ) + state_shape, "float32",
                           name="states")
    # was also named "states" (copy-paste duplicate)
    actions = T.Placeholder((batch_size, ) + actions_shape, "float32",
                            name="actions")
    rewards = T.Placeholder((batch_size, ), "float32",
                            name="discounted_rewards")
    advantages = T.Placeholder((batch_size, ), "float32", name="advantages")

    # self.actor / self.critic hold the BUILT networks, not the callables
    # that were passed in.
    self.target_actor = actor(states, distribution="gaussian")
    self.actor = actor(states, distribution="gaussian")
    self.critic = critic(states)

    # Finding the ratio (pi_theta / pi_theta__old) and
    # surrogate Loss https://arxiv.org/pdf/1707.06347.pdf
    with symjax.Scope("policy_loss"):
        ratios = T.exp(
            self.actor.actions.log_prob(actions)
            - self.target_actor.actions.log_prob(actions))
        ratios = T.clip(ratios, 0, 10)
        clipped_ratios = T.clip(ratios, 1 - self.eps_clip,
                                1 + self.eps_clip)

        surr1 = advantages * ratios
        surr2 = advantages * clipped_ratios

        actor_loss = -(T.minimum(surr1, surr2)).mean()

    with symjax.Scope("monitor"):
        # fraction of ratios that fall outside the clipping interval
        clipfrac = (((ratios > (1 + self.eps_clip)) |
                     (ratios < (1 - self.eps_clip))).astype("float32").mean())
        approx_kl = (self.target_actor.actions.log_prob(actions)
                     - self.actor.actions.log_prob(actions)).mean()

    with symjax.Scope("critic_loss"):
        critic_loss = T.mean((rewards - self.critic.q_values)**2)

    with symjax.Scope("entropy"):
        entropy = self.actor.actions.entropy().mean()

    loss = actor_loss + critic_loss  # - entropy_beta * entropy

    with symjax.Scope("optimizer"):
        # only the trained actor and the critic are optimised; the target
        # actor is left out of `params`
        nn.optimizers.Adam(
            loss,
            lr,
            params=self.actor.params(True) + self.critic.params(True),
        )

    # create the update function
    self._train = symjax.function(
        states,
        actions,
        rewards,
        advantages,
        outputs=[actor_loss, critic_loss, clipfrac, approx_kl],
        updates=symjax.get_updates(scope="*optimizer"),
    )

    # initialize target as current
    self.update_target(1)