def VT(self, M):
    # M is in R^{n x n}, holding the output GP kernel of all pairs of
    # data points in the data set
    A = T.diag(M)
    B = A * A[:, None]
    C = T.sqrt(B)  # in R^{n x n}
    D = M / C  # this is lambda in the ReLU analytical formula
    # clip E between -1 and 1 for numerical stability
    E = T.clip(D, -1, 1)
    F = (1 / (2 * np.pi)) * (E * (np.pi - T.arccos(E)) + T.sqrt(1 - E**2)) * C
    G = (np.pi - T.arccos(E)) / (2 * np.pi)
    return F, G
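# A hedged, standalone sketch (plain numpy; `vt_numpy` is a hypothetical
# helper mirroring VT) checking the closed-form ReLU expectation computed
# above: F[i, j] = E[relu(u) relu(v)] for (u, v) jointly Gaussian with
# covariance [[M[i, i], M[i, j]], [M[i, j], M[j, j]]]. A Monte Carlo
# estimate should approach the analytical value.
import numpy as np

def vt_numpy(M):
    # same computation as VT, written with numpy for a standalone check
    A = np.diag(M)
    C = np.sqrt(A * A[:, None])
    E = np.clip(M / C, -1, 1)
    F = (1 / (2 * np.pi)) * (E * (np.pi - np.arccos(E)) + np.sqrt(1 - E**2)) * C
    G = (np.pi - np.arccos(E)) / (2 * np.pi)
    return F, G

rng = np.random.default_rng(0)
X = rng.normal(size=(3, 5))
M = X @ X.T / 5  # a valid (positive semi-definite) kernel matrix
F, G = vt_numpy(M)
# Monte Carlo estimate of the post-ReLU kernel under a GP with kernel M
Z = rng.multivariate_normal(np.zeros(3), M, size=200_000)
F_mc = np.maximum(Z, 0).T @ np.maximum(Z, 0) / len(Z)
print(np.max(np.abs(F - F_mc)))  # small, on the order of 1e-2 or below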
def alg1_VT_dep(self, M):
    # Variant of VT (the "_dep" suffix suggests it is deprecated): here M
    # is used as the previous little q (N x N, same value for every row)
    A = T.diag(M)  # M is in R^{n x n}, holding the output GP kernel of all data pairs
    B = A * A[:, None]
    C = T.sqrt(B)  # in R^{n x n}
    D = M / C  # this is lambda in the ReLU analytical formula (c in the algorithm)
    # clip E between -1 and 1 for numerical stability
    E = T.clip(D, -1, 1)
    F = (1 / (2 * np.pi)) * (E * (np.pi - T.arccos(E)) + T.sqrt(1 - E**2)) * C
    G = (np.pi - T.arccos(E)) / (2 * np.pi)
    return F, G
def RNTK_relu(x, RNTK_old, GP_old, param, output):
    sw = param["sigmaw"]
    su = param["sigmau"]
    sb = param["sigmab"]
    sv = param["sigmav"]
    # GP_old is in R^{n x n}, holding the output GP kernel of all pairs of
    # data points in the data set
    a = T.diag(GP_old)
    B = a * a[:, None]
    C = T.sqrt(B)  # in R^{n x n}
    D = GP_old / C  # this is lambda in the ReLU analytical formula
    # clip E between -1 and 1 for numerical stability
    E = T.clip(D, -1, 1)
    F = (1 / (2 * np.pi)) * (E * (np.pi - T.arccos(E)) + T.sqrt(1 - E**2)) * C
    G = (np.pi - T.arccos(E)) / (2 * np.pi)
    if output:
        GP_new = sv**2 * F
        RNTK_new = sv**2.0 * RNTK_old * G + GP_new
    else:
        X = x * x[:, None]
        # m (the input dimensionality) is assumed to be defined in the
        # enclosing scope
        GP_new = sw**2 * F + (su**2 / m) * X + sb**2
        RNTK_new = sw**2.0 * RNTK_old * G + GP_new
    return RNTK_new, GP_new
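# A hedged, standalone sketch (plain numpy, toy sizes; `rntk_relu_np` is a
# hypothetical mirror of RNTK_relu) unrolling the recursion above over a
# short sequence: the GP kernel is propagated through the ReLU expectation
# F, while the RNTK accumulates sw^2 * RNTK_old * G + GP_new, with G the
# expected-derivative kernel. The initial condition used here (RNTK = GP at
# the first step) is an assumption for illustration.
import numpy as np

def rntk_relu_np(x, RNTK_old, GP_old, param, output, m):
    sw, su, sb, sv = (param[k] for k in ("sigmaw", "sigmau", "sigmab", "sigmav"))
    a = np.diag(GP_old)
    C = np.sqrt(a * a[:, None])
    E = np.clip(GP_old / C, -1, 1)
    F = (1 / (2 * np.pi)) * (E * (np.pi - np.arccos(E)) + np.sqrt(1 - E**2)) * C
    G = (np.pi - np.arccos(E)) / (2 * np.pi)
    if output:
        GP_new = sv**2 * F
        return sv**2 * RNTK_old * G + GP_new, GP_new
    GP_new = sw**2 * F + (su**2 / m) * (x * x[:, None]) + sb**2
    return sw**2 * RNTK_old * G + GP_new, GP_new

param = {"sigmaw": 1.0, "sigmau": 1.0, "sigmab": 0.1, "sigmav": 1.0}
m, n, steps = 1, 3, 5  # scalar input per time step (m = 1), 3 sequences
seq = np.random.randn(steps, n)
GP = (param["sigmau"]**2 / m) * (seq[0] * seq[0][:, None]) + param["sigmab"]**2
RNTK = GP.copy()
for t in range(1, steps):
    RNTK, GP = rntk_relu_np(seq[t], RNTK, GP, param, output=False, m=m)
RNTK, GP = rntk_relu_np(seq[-1], RNTK, GP, param, output=True, m=m)
print(RNTK.shape, GP.shape)  # (3, 3) (3, 3)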
def __init__(
    self,
    env,
    actor,
    critic,
    lr=1e-4,
    batch_size=32,
    n=1,
    clip_ratio=0.2,
    entcoeff=0.01,
    target_kl=0.01,
    train_pi_iters=4,
    train_v_iters=4,
):
    num_states = env.observation_space.shape[0]
    continuous = type(env.action_space) == gym.spaces.box.Box
    if continuous:
        num_actions = env.action_space.shape[0]
        action_min = env.action_space.low
        action_max = env.action_space.high
    else:
        num_actions = env.action_space.n
        action_max = 1

    self.batch_size = batch_size
    self.num_states = num_states
    self.num_actions = num_actions
    self.state_dim = (batch_size, num_states)
    self.action_dim = (batch_size, num_actions)
    self.continuous = continuous
    self.lr = lr
    self.train_pi_iters = train_pi_iters
    self.train_v_iters = train_v_iters
    self.clip_ratio = clip_ratio
    self.target_kl = target_kl
    self.extras = {"logprob": ()}
    self.entcoeff = entcoeff

    state_ph = T.Placeholder((batch_size, num_states), "float32")
    ret_ph = T.Placeholder((batch_size,), "float32")
    adv_ph = T.Placeholder((batch_size,), "float32")
    act_ph = T.Placeholder((batch_size, num_actions), "float32")

    with symjax.Scope("actor"):
        logits = actor(state_ph)
        if continuous:
            logstd = T.Variable(
                -0.5 * np.ones(num_actions, dtype=np.float32),
                name="logstd",
            )
    with symjax.Scope("old_actor"):
        old_logits = actor(state_ph)
        if continuous:
            old_logstd = T.Variable(
                -0.5 * np.ones(num_actions, dtype=np.float32),
                name="logstd",
            )

    if not continuous:
        pi = Categorical(logits=logits)
    else:
        pi = MultivariateNormal(mean=logits, diag_log_std=logstd)
    actions = T.clip(pi.sample(), -2, 2)

    actor_params = symjax.get_variables(scope="/actor/")
    old_actor_params = symjax.get_variables(scope="/old_actor/")

    self.update_target = symjax.function(
        updates={o: a for o, a in zip(old_actor_params, actor_params)})

    # PPO objectives
    # pi(a|s) / pi_old(a|s)
    pi_log_prob = pi.log_prob(act_ph)
    # substitute the old actor's variables to obtain pi_old's log-prob
    if continuous:
        old_pi_log_prob = pi_log_prob.clone({
            logits: old_logits,
            logstd: old_logstd,
        })
    else:
        old_pi_log_prob = pi_log_prob.clone({logits: old_logits})

    ratio = T.exp(pi_log_prob - old_pi_log_prob)
    surr1 = ratio * adv_ph
    surr2 = adv_ph * T.clip(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)

    pi_loss = -T.minimum(surr1, surr2).mean()
    # ent_loss = pi.entropy().mean() * self.entcoeff

    with symjax.Scope("critic"):
        v = critic(state_ph)

    # critic loss
    v_loss = ((ret_ph - v)**2).mean()

    # Info (useful to watch during learning)

    # a sample estimate for KL
    approx_kl = (old_pi_log_prob - pi_log_prob).mean()
    # a sample estimate for entropy
    # approx_ent = -logprob_given_actions.mean()
    # clipped = T.logical_or(
    #     ratio > (1 + clip_ratio), ratio < (1 - clip_ratio)
    # )
    # clipfrac = clipped.astype("float32").mean()

    with symjax.Scope("update_pi"):
        print(len(actor_params), "actor parameters")
        nn.optimizers.Adam(
            pi_loss,
            self.lr,
            params=actor_params,
        )

    with symjax.Scope("update_v"):
        critic_params = symjax.get_variables(scope="/critic/")
        print(len(critic_params), "critic parameters")
        nn.optimizers.Adam(
            v_loss,
            self.lr,
            params=critic_params,
        )

    self.get_params = symjax.function(outputs=critic_params)

    self.learn_pi = symjax.function(
        state_ph,
        act_ph,
        adv_ph,
        outputs=[pi_loss, approx_kl],
        updates=symjax.get_updates(scope="/update_pi/"),
    )
    self.learn_v = symjax.function(
        state_ph,
        ret_ph,
        outputs=v_loss,
        updates=symjax.get_updates(scope="/update_v/"),
    )

    single_state = T.Placeholder((1, num_states), "float32")
    single_v = v.clone({state_ph: single_state})
    single_sample = actions.clone({state_ph: single_state})

    self._act = symjax.function(single_state, outputs=single_sample)
    self._get_v = symjax.function(single_state, outputs=single_v)

    single_action = T.Placeholder((1, num_actions), "float32")
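# A hedged numpy sketch (standalone, toy values) of the clipped surrogate
# built above: taking min(surr1, surr2) keeps the pessimistic bound, so the
# gradient stops pushing the ratio pi/pi_old past 1 +/- clip_ratio in the
# direction that would inflate the objective.
import numpy as np

clip_ratio = 0.2
ratio = np.array([0.5, 0.9, 1.0, 1.3, 2.0])  # pi(a|s) / pi_old(a|s)
adv = np.array([1.0, -1.0, 0.5, 1.0, -2.0])  # advantage estimates
surr1 = ratio * adv
surr2 = adv * np.clip(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)
pi_loss = -np.minimum(surr1, surr2).mean()
# e.g. for ratio = 2.0 with adv < 0, surr1 < surr2: the unclipped (more
# pessimistic) term survives the minimum and the bad update stays penalized
print(pi_loss)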
def __init__(
    self,
    state_dim,
    action_dim,
    lr,
    gamma,
    K_epochs,
    eps_clip,
    actor,
    critic,
    batch_size,
    continuous=True,
):
    self.lr = lr
    self.gamma = gamma
    self.eps_clip = eps_clip
    self.K_epochs = K_epochs
    self.batch_size = batch_size

    state = T.Placeholder((batch_size,) + state_dim, "float32")
    reward = T.Placeholder((batch_size,), "float32")
    old_action_logprobs = T.Placeholder((batch_size,), "float32")

    logits = actor(state)

    if not continuous:
        given_action = T.Placeholder((batch_size,), "int32")
        dist = Categorical(logits=logits)
    else:
        # first half of the network output parametrizes the mean, the
        # second half the standard deviation
        mean = T.tanh(logits[:, :logits.shape[1] // 2])
        std = T.exp(logits[:, logits.shape[1] // 2:])
        given_action = T.Placeholder((batch_size, action_dim), "float32")
        dist = MultivariateNormal(mean=mean, diag_std=std)

    sample = dist.sample()
    sample_logprobs = dist.log_prob(sample)

    self._act = symjax.function(state, outputs=[sample, sample_logprobs])

    given_action_logprobs = dist.log_prob(given_action)

    # the ratio (pi_theta / pi_theta_old), formed from the log-probs of the
    # stored actions under the current and the old policy
    ratios = T.exp(given_action_logprobs - old_action_logprobs)
    # one-sided clip: only cap the ratio from above
    ratios = T.clip(ratios, None, 1 + self.eps_clip)
    state_value = critic(state)
    advantages = reward - T.stop_gradient(state_value)

    # entropy bonus is currently disabled (coefficient 0)
    loss = (-T.mean(ratios * advantages) + 0.5 * T.mean(
        (state_value - reward)**2) - 0.0 * dist.entropy().mean())

    nn.optimizers.Adam(loss, self.lr)

    self.learn = symjax.function(
        state,
        given_action,
        reward,
        old_action_logprobs,
        outputs=T.mean(loss),
        updates=symjax.get_updates(),
    )
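# A hedged numpy sketch (standalone) contrasting the one-sided clip used
# above, T.clip(ratios, None, 1 + eps_clip), with the standard two-sided
# PPO clip: the one-sided version only caps how much a large ratio can
# amplify the advantage, leaving ratios below 1 - eps_clip untouched.
import numpy as np

eps_clip = 0.2
ratios = np.array([0.3, 0.9, 1.1, 1.8])
one_sided = np.minimum(ratios, 1 + eps_clip)            # clip from above only
two_sided = np.clip(ratios, 1 - eps_clip, 1 + eps_clip)
print(one_sided)  # [0.3 0.9 1.1 1.2]
print(two_sided)  # [0.8 0.9 1.1 1.2]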
def __init__(
    self,
    state_shape,
    actions_shape,
    batch_size,
    actor,
    critic,
    lr=1e-3,
    K_epochs=80,
    eps_clip=0.2,
    gamma=0.99,
    entropy_beta=0.01,
):
    self.gamma = gamma
    self.lr = lr
    self.eps_clip = eps_clip
    self.K_epochs = K_epochs
    self.batch_size = batch_size

    states = T.Placeholder((batch_size,) + state_shape,
                           "float32",
                           name="states")
    actions = T.Placeholder((batch_size,) + actions_shape,
                            "float32",
                            name="actions")
    rewards = T.Placeholder((batch_size,),
                            "float32",
                            name="discounted_rewards")
    advantages = T.Placeholder((batch_size,), "float32", name="advantages")

    self.target_actor = actor(states, distribution="gaussian")
    self.actor = actor(states, distribution="gaussian")
    self.critic = critic(states)

    # Finding the ratio (pi_theta / pi_theta_old) and the surrogate loss,
    # https://arxiv.org/pdf/1707.06347.pdf
    with symjax.Scope("policy_loss"):
        ratios = T.exp(
            self.actor.actions.log_prob(actions) -
            self.target_actor.actions.log_prob(actions))
        ratios = T.clip(ratios, 0, 10)
        clipped_ratios = T.clip(ratios, 1 - self.eps_clip,
                                1 + self.eps_clip)

        surr1 = advantages * ratios
        surr2 = advantages * clipped_ratios

        actor_loss = -(T.minimum(surr1, surr2)).mean()

    with symjax.Scope("monitor"):
        # fraction of ratios that left the clipping interval
        clipfrac = (((ratios > (1 + self.eps_clip)) |
                     (ratios < (1 - self.eps_clip))).astype("float32").mean())
        # sample estimate of the KL divergence between old and new policy
        approx_kl = (self.target_actor.actions.log_prob(actions) -
                     self.actor.actions.log_prob(actions)).mean()

    with symjax.Scope("critic_loss"):
        critic_loss = T.mean((rewards - self.critic.q_values)**2)

    with symjax.Scope("entropy"):
        entropy = self.actor.actions.entropy().mean()

    loss = actor_loss + critic_loss  # - entropy_beta * entropy

    with symjax.Scope("optimizer"):
        nn.optimizers.Adam(
            loss,
            lr,
            params=self.actor.params(True) + self.critic.params(True),
        )

    # create the update function
    self._train = symjax.function(
        states,
        actions,
        rewards,
        advantages,
        outputs=[actor_loss, critic_loss, clipfrac, approx_kl],
        updates=symjax.get_updates(scope="*optimizer"),
    )

    # initialize target as current
    self.update_target(1)
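# A hedged numpy sketch of the kind of update `update_target(1)` presumably
# performs (the method itself is defined elsewhere in the class): a Polyak
# average target <- tau * current + (1 - tau) * target, where tau = 1 copies
# the current actor into the target exactly. `polyak_update` is a
# hypothetical helper for illustration.
import numpy as np

def polyak_update(target_params, params, tau):
    # in-place soft update of each target parameter toward the current one
    for t, p in zip(target_params, params):
        t *= 1.0 - tau
        t += tau * p

target = [np.zeros(3)]
current = [np.ones(3)]
polyak_update(target, current, tau=1.0)  # exact copy, as in update_target(1)
print(target[0])  # [1. 1. 1.]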