def __init__(self, *args, name=None, **kwargs):
    if name is None:
        name = self.__NAME__

    with symjax.Scope(name):
        self.create_updates(*args, **kwargs)
        self._scope_name = symjax.current_graph().scope_name
def update_target(self, tau=None):
    # Build (and cache) the soft-update function on first use.
    if not hasattr(self, "_update_target"):
        with symjax.Scope("update_target"):
            targets = []
            currents = []
            if hasattr(self, "target_actor"):
                targets += self.target_actor.params(True)
                currents += self.actor.params(True)
            if hasattr(self, "target_critic"):
                targets += self.target_critic.params(True)
                currents += self.critic.params(True)

            # Polyak averaging: target <- (1 - tau) * target + tau * current
            _tau = T.Placeholder((), "float32")
            updates = {
                t: t * (1 - _tau) + a * _tau
                for t, a in zip(targets, currents)
            }

        self._update_target = symjax.function(_tau, updates=updates)

    if tau is None:
        if not hasattr(self, "tau"):
            raise RuntimeError("tau must be specified")
        tau = self.tau

    self._update_target(tau)
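# Illustrative sketch (plain numpy, not part of the library): the `updates`
# dictionary above is a Polyak (soft) target update,
#     target <- (1 - tau) * target + tau * current.
# With tau=1 the targets are copied outright (as in `update_target(1)`),
# while a small tau (e.g. 0.01) makes the targets trail the online networks.
import numpy as np

def soft_update(target_weights, current_weights, tau):
    # element-wise interpolation between the two parameter lists
    return [(1 - tau) * t + tau * c for t, c in zip(target_weights, current_weights)]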
def __init__(self, *args, name=None, **kwargs):
    if name is None:
        name = self.__NAME__

    with symjax.Scope(name):
        output = self.forward(*args, **kwargs)
        super().__init__(output, 0, _jax_function=jax.numpy.add)
def __init__(self, states, actions=None):
    self.state_shape = states.shape[1:]
    state = T.Placeholder(
        (1,) + states.shape[1:], "float32", name="critic_state"
    )

    if actions is not None:
        self.action_shape = actions.shape[1:]
        action = T.Placeholder(
            (1,) + actions.shape[1:], "float32", name="critic_action"
        )
        action_shape = action.shape[1:]

        with symjax.Scope("critic"):
            q_values = self.create_network(states, actions)
            if q_values.ndim == 2:
                assert q_values.shape[1] == 1
                q_values = q_values[:, 0]
            q_value = q_values.clone({states: state, actions: action})
            self._params = symjax.get_variables(
                trainable=None, scope=symjax.current_graph().scope_name
            )

        inputs = [states, actions]
        input = [state, action]

        self.actions = actions
        self.action = action

    else:
        with symjax.Scope("critic"):
            q_values = self.create_network(states)
            if q_values.ndim == 2:
                assert q_values.shape[1] == 1
                q_values = q_values[:, 0]
            q_value = q_values.clone({states: state})
            self._params = symjax.get_variables(
                trainable=None, scope=symjax.current_graph().scope_name
            )

        inputs = [states]
        input = [state]

    self.q_values = q_values
    self.state = state
    self.states = states

    self._get_q_values = symjax.function(*inputs, outputs=q_values)
    self._get_q_value = symjax.function(*input, outputs=q_value[0])
def build_net(self, Q):
    # ------------------ all inputs ------------------------
    state = T.Placeholder(
        [self.batch_size, self.n_states], "float32", name="s"
    )
    next_state = T.Placeholder(
        [self.batch_size, self.n_states], "float32", name="s_"
    )
    reward = T.Placeholder([self.batch_size], "float32", name="r")  # input reward
    action = T.Placeholder([self.batch_size], "int32", name="a")  # input action

    with symjax.Scope("eval_net"):
        q_eval = Q(state, self.n_actions)
    with symjax.Scope("test_set"):
        q_next = Q(next_state, self.n_actions)

    # Bellman target r + gamma * max_a' Q(s', a'), treated as a constant
    q_target = reward + self.reward_decay * q_next.max(1)
    q_target = T.stop_gradient(q_target)

    a_indices = T.stack([T.range(self.batch_size), action], axis=1)  # unused
    q_eval_wrt_a = T.take_along_axis(
        q_eval, action.reshape((-1, 1)), 1
    ).squeeze(1)
    loss = T.mean((q_target - q_eval_wrt_a) ** 2)
    nn.optimizers.Adam(loss, self.lr)

    self.train = symjax.function(
        state, action, reward, next_state, updates=symjax.get_updates()
    )
    self.q_eval = symjax.function(state, outputs=q_eval)
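# Numeric sketch (plain numpy, illustration only): the loss assembled above is
# the Q-learning regression of the taken action's value toward the
# (stop-gradient) Bellman target r + reward_decay * max_a' Q(s', a').
import numpy as np

def dqn_loss(q_eval, q_next, rewards, actions, reward_decay):
    # q_eval, q_next: (batch, n_actions); rewards: (batch,); actions: (batch,) int
    q_target = rewards + reward_decay * q_next.max(axis=1)
    q_taken = np.take_along_axis(q_eval, actions.reshape(-1, 1), axis=1).squeeze(1)
    return np.mean((q_target - q_taken) ** 2)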
def __new__(cls, *args, name=None, **kwargs):
    if name is None:
        name = cls.__NAME__

    with symjax.Scope(name):
        output = cls.__init__(cls, *args, **kwargs)

    return output
def test_clone_0():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    with sj.Scope("placing"):
        u = T.Placeholder((), "float32", name="u")
    value = 2 * w * u
    c = value.clone({w: u})
    f = sj.function(u, outputs=value)
    g = sj.function(u, outputs=c)
    assert np.array_equal([f(1), g(1), f(2), g(2)], [2, 2, 4, 8])
def __call__(self, action, episode):
    with symjax.Scope("OUProcess"):
        self.episode = T.Variable(
            1, "float32", name="episode", trainable=False
        )

    self.noise_scale = self.initial_noise_scale * self.noise_decay ** episode

    x = (
        self.process
        + self.theta * (self.mean - self.process) * self.dt
        + self.std_dev * np.sqrt(self.dt) * np.random.normal(size=action.shape)
    )

    # Store x into process
    # Makes next noise dependent on current one
    self.process = x

    return action + self.noise_scale * self.process
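# Numeric sketch (plain numpy, illustration only): the update above is a
# discretized Ornstein-Uhlenbeck process,
#     x <- x + theta * (mean - x) * dt + std_dev * sqrt(dt) * N(0, 1),
# i.e. mean-reverting, temporally correlated noise that is then scaled by
# `noise_scale` and added to the action. The coefficient defaults below are
# the usual DDPG choices, not values read from this class.
import numpy as np

def ou_step(x, mean=0.0, theta=0.15, std_dev=0.2, dt=1e-2):
    return x + theta * (mean - x) * dt + std_dev * np.sqrt(dt) * np.random.normal(size=np.shape(x))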
def __init__(
    self,
    state_shape,
    actions_shape,
    n_episodes,
    episode_length,
    actor,
    lr=1e-3,
    gamma=0.99,
):
    self.actor = actor
    self.gamma = gamma
    self.lr = lr
    self.episode_length = episode_length
    self.n_episodes = n_episodes
    self.batch_size = episode_length * n_episodes

    states = T.Placeholder((self.batch_size,) + state_shape, "float32")
    actions = T.Placeholder((self.batch_size,) + actions_shape, "float32")
    discounted_rewards = T.Placeholder((self.batch_size,), "float32")

    self.actor = actor(states, distribution="gaussian")

    logprobs = self.actor.actions.log_prob(actions)
    actor_loss = -(logprobs * discounted_rewards).sum() / n_episodes

    with symjax.Scope("REINFORCE_optimizer"):
        nn.optimizers.Adam(
            actor_loss,
            lr,
            params=self.actor.params(True),
        )

    # create the update function
    self._train = symjax.function(
        states,
        actions,
        discounted_rewards,
        outputs=actor_loss,
        updates=symjax.get_updates(scope="*/REINFORCE_optimizer"),
    )
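# Sketch of the quantity the `discounted_rewards` placeholder expects (an
# assumption based on the stored `gamma`; the driver code that fills it is not
# shown here): per-episode discounted returns G_t = sum_k gamma^k * r_{t+k},
# computed before the episodes are concatenated into one batch.
import numpy as np

def discounted_returns(rewards, gamma=0.99):
    returns = np.zeros(len(rewards), dtype="float32")
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns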
def __init__(self, states, actions_distribution=None, name="actor"):
    self.state_shape = states.shape[1:]
    state = T.Placeholder((1,) + states.shape[1:], "float32")
    self.actions_distribution = actions_distribution

    with symjax.Scope(name):
        if actions_distribution == symjax.probabilities.Normal:
            means, covs = self.create_network(states)
            actions = actions_distribution(means, cov=covs)
            samples = actions.sample()
            samples_log_prob = actions.log_prob(samples)

            action = symjax.probabilities.MultivariateNormal(
                means.clone({states: state}),
                cov=covs.clone({states: state}),
            )
            # sample from the per-state distribution (only assigned to
            # self.action at the end of __init__)
            sample = action.sample()
            sample_log_prob = action.log_prob(sample)

            self._get_actions = symjax.function(
                states, outputs=[samples, samples_log_prob]
            )
            self._get_action = symjax.function(
                state,
                outputs=[sample[0], sample_log_prob[0]],
            )

        elif actions_distribution is None:
            actions = self.create_network(states)
            action = actions.clone({states: state})
            self._get_actions = symjax.function(states, outputs=actions)
            self._get_action = symjax.function(state, outputs=action[0])

        self._params = symjax.get_variables(
            trainable=None, scope=symjax.current_graph().scope_name
        )

    self.actions = actions
    self.state = state
    self.action = action
def __init__(
    self,
    env_fn,
    actor,
    critic,
    gamma=0.99,
    tau=0.01,
    lr=1e-3,
    batch_size=32,
    epsilon=0.1,
    epsilon_decay=1 / 1000,
    min_epsilon=0.01,
    reward=None,
):
    # build the environment from the provided constructor
    env = env_fn()

    # comment out this line if you don't want to record a video of the agent
    # if save_folder is not None:
    #     test_env = gym.wrappers.Monitor(test_env)

    # get size of state space and action space
    num_states = env.observation_space.shape[0]
    continuous = isinstance(env.action_space, gym.spaces.box.Box)

    if continuous:
        num_actions = env.action_space.shape[0]
        action_max = env.action_space.high[0]
    else:
        num_actions = env.action_space.n
        action_max = 1

    self.batch_size = batch_size
    self.num_states = num_states
    self.num_actions = num_actions
    self.state_dim = (batch_size, num_states)
    self.action_dim = (batch_size, num_actions)
    self.gamma = gamma
    self.continuous = continuous
    self.observ_min = np.clip(env.observation_space.low, -20, 20)
    self.observ_max = np.clip(env.observation_space.high, -20, 20)
    self.env = env
    self.reward = reward

    # placeholders for the training graph
    state = T.Placeholder((batch_size, num_states), "float32")
    gradients = T.Placeholder((batch_size, num_actions), "float32")
    action = T.Placeholder((batch_size, num_actions), "float32")
    target = T.Placeholder((batch_size, 1), "float32")

    with symjax.Scope("actor_critic"):
        scaled_out = action_max * actor(state)
        Q = critic(state, action)

    a_loss = -T.sum(gradients * scaled_out)
    q_loss = T.mean((Q - target) ** 2)

    nn.optimizers.Adam(a_loss + q_loss, lr)

    self.update = symjax.function(
        state,
        action,
        target,
        gradients,
        outputs=[a_loss, q_loss],
        updates=symjax.get_updates(),
    )
    g = symjax.gradients(T.mean(Q), [action])[0]
    self.get_gradients = symjax.function(state, action, outputs=g)

    # also create the target variants
    with symjax.Scope("actor_critic_target"):
        scaled_out_target = action_max * actor(state)
        Q_target = critic(state, action)

    self.actor_predict = symjax.function(state, outputs=scaled_out)
    self.actor_predict_target = symjax.function(state, outputs=scaled_out_target)

    self.critic_predict = symjax.function(state, action, outputs=Q)
    self.critic_predict_target = symjax.function(state, action, outputs=Q_target)

    t_params = symjax.get_variables(scope="/actor_critic_target/*")
    params = symjax.get_variables(scope="/actor_critic/*")

    # soft (Polyak) update of the target parameters
    replacement = {
        t: tau * e + (1 - tau) * t for t, e in zip(t_params, params)
    }
    self.update_target = symjax.function(updates=replacement)

    single_state = T.Placeholder((1, num_states), "float32")
    if not continuous:
        # for discrete action spaces, act greedily w.r.t. the actor output
        scaled_out = scaled_out.argmax(-1)

    self.act = symjax.function(
        single_state,
        outputs=scaled_out.clone({state: single_state})[0],
    )
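# Target-computation sketch (hypothetical driver code, not from the library):
# the `target` placeholder fed to `self.update` is the usual DDPG bootstrap
# value built from the target networks created above. The batch arrays are
# assumed to come from a replay buffer and `dones` marks terminal transitions.
import numpy as np

def ddpg_targets(agent, rewards, next_states, dones):
    next_actions = agent.actor_predict_target(next_states)
    next_q = agent.critic_predict_target(next_states, next_actions)  # (batch, 1)
    rewards = np.asarray(rewards, dtype="float32")[:, None]
    dones = np.asarray(dones, dtype="float32")[:, None]
    return rewards + agent.gamma * (1.0 - dones) * next_q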
def __init__(
    self,
    env,
    actor,
    critic,
    lr=1e-4,
    batch_size=32,
    train_pi_iters=80,
    train_v_iters=80,
):
    num_states = env.observation_space.shape[0]
    continuous = isinstance(env.action_space, gym.spaces.box.Box)

    if continuous:
        num_actions = env.action_space.shape[0]
        action_min = env.action_space.low
        action_max = env.action_space.high
    else:
        num_actions = env.action_space.n
        action_max = 1

    self.batch_size = batch_size
    self.num_states = num_states
    self.num_actions = num_actions
    self.state_dim = (batch_size, num_states)
    self.action_dim = (batch_size, num_actions)
    self.continuous = continuous
    self.lr = lr
    self.train_pi_iters = train_pi_iters
    self.train_v_iters = train_v_iters
    self.extras = {}

    state_ph = T.Placeholder((batch_size, num_states), "float32")
    rew_ph = T.Placeholder((batch_size,), "float32")

    with symjax.Scope("actor"):
        logits = actor(state_ph)
        if not continuous:
            pi = Categorical(logits=logits)
        else:
            logstd = T.Variable(
                -0.5 * np.ones(num_actions, dtype=np.float32),
                name="logstd",
            )
            pi = MultivariateNormal(mean=logits, diag_log_std=logstd)
        actions = pi.sample()  # pi
        actions_log_prob = pi.log_prob(actions)  # logp

    with symjax.Scope("critic"):
        critic_value = critic(state_ph)

    # AC objectives
    diff = rew_ph - critic_value
    actor_loss = -(actions_log_prob * diff).mean()
    critic_loss = nn.losses.squared_differences(rew_ph, critic_value).mean()

    with symjax.Scope("update_pi"):
        nn.optimizers.Adam(
            actor_loss,
            self.lr,
            params=symjax.get_variables(scope="/actor/"),
        )
    with symjax.Scope("update_v"):
        nn.optimizers.Adam(
            critic_loss,
            self.lr,
            params=symjax.get_variables(scope="/critic/"),
        )

    self.learn_pi = symjax.function(
        state_ph,
        rew_ph,
        outputs=actor_loss,
        updates=symjax.get_updates(scope="/update_pi/"),
    )
    self.learn_v = symjax.function(
        state_ph,
        rew_ph,
        outputs=critic_loss,
        updates=symjax.get_updates(scope="/update_v/*"),
    )

    single_state = T.Placeholder((1, num_states), "float32")
    single_action = actions.clone({state_ph: single_state})[0]
    single_v = critic_value.clone({state_ph: single_state})

    self._act = symjax.function(
        single_state,
        outputs=[single_action, single_v],
    )
def __init__(
    self,
    state_shape,
    actions_shape,
    batch_size,
    actor,
    critic,
    lr=1e-3,
    K_epochs=80,
    eps_clip=0.2,
    gamma=0.99,
    entropy_beta=0.01,
):
    self.actor = actor
    self.critic = critic
    self.gamma = gamma
    self.lr = lr
    self.eps_clip = eps_clip
    self.K_epochs = K_epochs
    self.batch_size = batch_size

    states = T.Placeholder(
        (batch_size,) + state_shape, "float32", name="states"
    )
    actions = T.Placeholder(
        (batch_size,) + actions_shape, "float32", name="actions"
    )
    rewards = T.Placeholder(
        (batch_size,), "float32", name="discounted_rewards"
    )
    advantages = T.Placeholder((batch_size,), "float32", name="advantages")

    self.target_actor = actor(states, distribution="gaussian")
    self.actor = actor(states, distribution="gaussian")
    self.critic = critic(states)

    # Finding the ratio (pi_theta / pi_theta_old) and the clipped
    # surrogate loss, https://arxiv.org/pdf/1707.06347.pdf
    with symjax.Scope("policy_loss"):
        ratios = T.exp(
            self.actor.actions.log_prob(actions)
            - self.target_actor.actions.log_prob(actions)
        )
        ratios = T.clip(ratios, 0, 10)
        clipped_ratios = T.clip(ratios, 1 - self.eps_clip, 1 + self.eps_clip)

        surr1 = advantages * ratios
        surr2 = advantages * clipped_ratios

        actor_loss = -(T.minimum(surr1, surr2)).mean()

    with symjax.Scope("monitor"):
        clipfrac = (
            ((ratios > (1 + self.eps_clip)) | (ratios < (1 - self.eps_clip)))
            .astype("float32")
            .mean()
        )
        approx_kl = (
            self.target_actor.actions.log_prob(actions)
            - self.actor.actions.log_prob(actions)
        ).mean()

    with symjax.Scope("critic_loss"):
        critic_loss = T.mean((rewards - self.critic.q_values) ** 2)

    with symjax.Scope("entropy"):
        entropy = self.actor.actions.entropy().mean()

    loss = actor_loss + critic_loss  # - entropy_beta * entropy

    with symjax.Scope("optimizer"):
        nn.optimizers.Adam(
            loss,
            lr,
            params=self.actor.params(True) + self.critic.params(True),
        )

    # create the update function
    self._train = symjax.function(
        states,
        actions,
        rewards,
        advantages,
        outputs=[actor_loss, critic_loss, clipfrac, approx_kl],
        updates=symjax.get_updates(scope="*optimizer"),
    )

    # initialize target as current
    self.update_target(1)
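# Numeric sketch (plain numpy, illustration only): the clipped surrogate built
# above, -mean(min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)), which
# removes any incentive to push the probability ratio outside
# [1 - eps_clip, 1 + eps_clip].
import numpy as np

def clipped_surrogate(ratios, advantages, eps_clip=0.2):
    clipped = np.clip(ratios, 1 - eps_clip, 1 + eps_clip)
    return -np.mean(np.minimum(ratios * advantages, clipped * advantages))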
def __init__(
    self,
    state_shape,
    actions_shape,
    n_episodes,
    episode_length,
    actor,
    critic,
    lr=1e-3,
    gamma=0.99,
    train_v_iters=10,
):
    self.actor = actor
    self.critic = critic
    self.gamma = gamma
    self.lr = lr
    self.episode_length = episode_length
    self.n_episodes = n_episodes
    self.batch_size = episode_length * n_episodes
    self.train_v_iters = train_v_iters

    states = T.Placeholder((self.batch_size,) + state_shape, "float32")
    actions = T.Placeholder((self.batch_size,) + actions_shape, "float32")
    discounted_rewards = T.Placeholder((self.batch_size,), "float32")
    advantages = T.Placeholder((self.batch_size,), "float32")

    self.actor = actor(states, distribution="gaussian")
    self.critic = critic(states)

    logprobs = self.actor.actions.log_prob(actions)
    actor_loss = -(logprobs * advantages).sum() / n_episodes
    critic_loss = 0.5 * (
        (discounted_rewards - self.critic.q_values) ** 2
    ).mean()

    with symjax.Scope("actor_optimizer"):
        nn.optimizers.Adam(
            actor_loss,
            lr,
            params=self.actor.params(True),
        )
    with symjax.Scope("critic_optimizer"):
        nn.optimizers.Adam(
            critic_loss,
            lr,
            params=self.critic.params(True),
        )

    # create the actor update function
    self._train_actor = symjax.function(
        states,
        actions,
        advantages,
        outputs=actor_loss,
        updates=symjax.get_updates(scope="*/actor_optimizer"),
    )

    # create the critic update function
    self._train_critic = symjax.function(
        states,
        discounted_rewards,
        outputs=critic_loss,
        updates=symjax.get_updates(scope="*/critic_optimizer"),
    )
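# Sketch of one common way to fill the `advantages` placeholder (an
# assumption; the driver code is not shown here): the discounted return minus
# the critic's value baseline, optionally standardized.
import numpy as np

def simple_advantages(discounted_rewards, values, normalize=True):
    adv = np.asarray(discounted_rewards) - np.asarray(values)
    if normalize:
        adv = (adv - adv.mean()) / (adv.std() + 1e-8)
    return adv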
def __init__(
    self,
    state_shape,
    actions_shape,
    batch_size,
    actor,
    critic,
    lr=1e-3,
    gamma=0.99,
    tau=0.01,
):
    self.gamma = gamma
    self.tau = tau
    self.lr = lr
    self.batch_size = batch_size

    states = T.Placeholder((batch_size,) + state_shape, "float32")
    actions = T.Placeholder((batch_size,) + actions_shape, "float32")

    self.critic = critic(states, actions)
    self.target_critic = critic(states, actions)

    # create critic loss
    targets = T.Placeholder(self.critic.q_values.shape, "float32")
    critic_loss = ((self.critic.q_values - targets) ** 2).mean()

    # create optimizer
    with symjax.Scope("critic_optimizer"):
        nn.optimizers.Adam(critic_loss, lr, params=self.critic.params(True))

    # create the update function
    self._train_critic = symjax.function(
        states,
        actions,
        targets,
        outputs=critic_loss,
        updates=symjax.get_updates(scope="*/critic_optimizer"),
    )

    # now create utility function to get the gradients
    grad = symjax.gradients(self.critic.q_values.sum(), actions)
    self._get_critic_gradients = symjax.function(states, actions, outputs=grad)

    # create actor loss
    self.actor = actor(states)
    self.target_actor = actor(states)
    gradients = T.Placeholder(actions.shape, "float32")
    actor_loss = -(self.actor.actions * gradients).mean()

    # create optimizer
    with symjax.Scope("actor_optimizer"):
        nn.optimizers.Adam(actor_loss, lr, params=self.actor.params(True))

    # create the update function
    self._train_actor = symjax.function(
        states,
        gradients,
        outputs=actor_loss,
        updates=symjax.get_updates(scope="*/actor_optimizer"),
    )

    # initialize both networks as the same
    self.update_target(1)
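# Training-step sketch (hypothetical driver code, not from the library): the
# helpers created above suggest the usual DDPG order of operations. `agent` is
# assumed to be an instance of the class whose __init__ appears above, and the
# batch arrays are assumed to come from a replay buffer.
def ddpg_step(agent, states, actions, targets):
    # 1) regress Q(s, a) toward the bootstrapped targets
    critic_loss = agent._train_critic(states, actions, targets)
    # 2) ask the critic how Q changes with respect to the actions ...
    action_grads = agent._get_critic_gradients(states, actions)
    if isinstance(action_grads, list):  # may come back as a one-element list
        action_grads = action_grads[0]
    # 3) ... and move the actor's actions along that direction
    actor_loss = agent._train_actor(states, action_grads)
    # 4) let the target networks slowly track the online ones
    agent.update_target()
    return critic_loss, actor_loss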
def __init__(
    self,
    env,
    actor,
    critic,
    lr=1e-4,
    batch_size=32,
    n=1,
    clip_ratio=0.2,
    entcoeff=0.01,
    target_kl=0.01,
    train_pi_iters=4,
    train_v_iters=4,
):
    num_states = env.observation_space.shape[0]
    continuous = isinstance(env.action_space, gym.spaces.box.Box)

    if continuous:
        num_actions = env.action_space.shape[0]
        action_min = env.action_space.low
        action_max = env.action_space.high
    else:
        num_actions = env.action_space.n
        action_max = 1

    self.batch_size = batch_size
    self.num_states = num_states
    self.num_actions = num_actions
    self.state_dim = (batch_size, num_states)
    self.action_dim = (batch_size, num_actions)
    self.continuous = continuous
    self.lr = lr
    self.train_pi_iters = train_pi_iters
    self.train_v_iters = train_v_iters
    self.clip_ratio = clip_ratio
    self.target_kl = target_kl
    self.extras = {"logprob": ()}
    self.entcoeff = entcoeff

    state_ph = T.Placeholder((batch_size, num_states), "float32")
    ret_ph = T.Placeholder((batch_size,), "float32")
    adv_ph = T.Placeholder((batch_size,), "float32")
    act_ph = T.Placeholder((batch_size, num_actions), "float32")

    with symjax.Scope("actor"):
        logits = actor(state_ph)
        if continuous:
            logstd = T.Variable(
                -0.5 * np.ones(num_actions, dtype=np.float32),
                name="logstd",
            )
    with symjax.Scope("old_actor"):
        old_logits = actor(state_ph)
        if continuous:
            old_logstd = T.Variable(
                -0.5 * np.ones(num_actions, dtype=np.float32),
                name="logstd",
            )

    if not continuous:
        pi = Categorical(logits=logits)
    else:
        pi = MultivariateNormal(mean=logits, diag_log_std=logstd)

    actions = T.clip(pi.sample(), -2, 2)  # pi

    actor_params = symjax.get_variables(scope="/actor/")
    old_actor_params = symjax.get_variables(scope="/old_actor/")

    self.update_target = symjax.function(
        updates={o: a for o, a in zip(old_actor_params, actor_params)}
    )

    # PPO objectives
    # pi(a|s) / pi_old(a|s)
    pi_log_prob = pi.log_prob(act_ph)
    # logstd only exists in the continuous case, so only swap it in then
    replacements = {logits: old_logits}
    if continuous:
        replacements[logstd] = old_logstd
    old_pi_log_prob = pi_log_prob.clone(replacements)

    ratio = T.exp(pi_log_prob - old_pi_log_prob)
    surr1 = ratio * adv_ph
    surr2 = adv_ph * T.clip(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)

    pi_loss = -T.minimum(surr1, surr2).mean()
    # ent_loss = pi.entropy().mean() * self.entcoeff

    with symjax.Scope("critic"):
        v = critic(state_ph)

    # critic loss
    v_loss = ((ret_ph - v) ** 2).mean()

    # Info (useful to watch during learning)

    # a sample estimate for KL
    approx_kl = (old_pi_log_prob - pi_log_prob).mean()

    # a sample estimate for entropy
    # approx_ent = -logprob_given_actions.mean()
    # clipped = T.logical_or(
    #     ratio > (1 + clip_ratio), ratio < (1 - clip_ratio)
    # )
    # clipfrac = clipped.astype("float32").mean()

    with symjax.Scope("update_pi"):
        print(len(actor_params), "actor parameters")
        nn.optimizers.Adam(
            pi_loss,
            self.lr,
            params=actor_params,
        )
    with symjax.Scope("update_v"):
        critic_params = symjax.get_variables(scope="/critic/")
        print(len(critic_params), "critic parameters")
        nn.optimizers.Adam(
            v_loss,
            self.lr,
            params=critic_params,
        )

    self.get_params = symjax.function(outputs=critic_params)

    self.learn_pi = symjax.function(
        state_ph,
        act_ph,
        adv_ph,
        outputs=[pi_loss, approx_kl],
        updates=symjax.get_updates(scope="/update_pi/"),
    )
    self.learn_v = symjax.function(
        state_ph,
        ret_ph,
        outputs=v_loss,
        updates=symjax.get_updates(scope="/update_v/"),
    )

    single_state = T.Placeholder((1, num_states), "float32")
    single_v = v.clone({state_ph: single_state})
    single_sample = actions.clone({state_ph: single_state})

    self._act = symjax.function(single_state, outputs=single_sample)
    self._get_v = symjax.function(single_state, outputs=single_v)

    single_action = T.Placeholder((1, num_actions), "float32")
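# Update-loop sketch (hypothetical driver code, not from the library): the
# `train_pi_iters`, `train_v_iters` and `target_kl` attributes stored above
# suggest the usual PPO schedule of repeated policy steps with early stopping
# on the approximate KL (1.5 * target_kl is a common heuristic threshold),
# followed by repeated value steps. `agent` is assumed to be an instance of
# the class whose __init__ appears above.
def ppo_update(agent, states, actions, advantages, returns):
    agent.update_target()  # snapshot the current policy into old_actor
    for _ in range(agent.train_pi_iters):
        pi_loss, approx_kl = agent.learn_pi(states, actions, advantages)
        if approx_kl > 1.5 * agent.target_kl:
            break  # the policy moved too far from the snapshot
    for _ in range(agent.train_v_iters):
        v_loss = agent.learn_v(states, returns)
    return pi_loss, v_loss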
def __init__(self, *args, name=None, **kwargs):
    if name is None:
        name = self.__NAME__

    with symjax.Scope(name):
        self.create_updates(*args, **kwargs)