def __init__(self, name, param_store=None, tensorboard_dir=None):
    env = make_env(name, tensorboard_dir)

    # function approximator
    self.q = coax.Q(forward_pass, env)
    self.q_targ = self.q.copy()

    # updater
    self.q_updater = coax.td_learning.QLearning(
        self.q, q_targ=self.q_targ, optimizer=optax.adam(3e-4))

    # schedule for the beta parameter used in PrioritizedReplayBuffer
    self.buffer_beta = coax.utils.StepwiseLinearFunction((0, 0.4), (1000000, 1))

    super().__init__(
        env=env,
        param_store=param_store,
        pi=coax.BoltzmannPolicy(self.q, temperature=0.015),
        tracer=coax.reward_tracing.NStep(n=1, gamma=0.99),
        buffer=(
            coax.experience_replay.PrioritizedReplayBuffer(capacity=1000000, alpha=0.6)
            if param_store is None else None),
        buffer_warmup=50000,
        name=name)
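For reference, the beta schedule built above is a plain callable mapping a timestep to the current beta. A hypothetical usage sketch, assuming the base class stores self.env and self.buffer and that the env wrapper exposes a global step counter T (neither is shown in the snippet above):

# anneal the PER importance-sampling exponent from 0.4 to 1.0 over 1M steps
self.buffer.beta = self.buffer_beta(self.env.T)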
def func_type1(S, A, is_training):
    # custom haiku function: s,a -> q(s,a)
    value = hk.Sequential([...])
    X = jax.vmap(jnp.kron)(S, A)  # or jnp.concatenate((S, A), axis=-1) or whatever you like
    return value(X)  # output shape: (batch_size,)


def func_type2(S, is_training):
    # custom haiku function: s -> q(s,.)
    value = hk.Sequential([...])
    return value(S)  # output shape: (batch_size, num_actions)


# function approximator
func = ...  # func_type1 or func_type2
q = coax.Q(func, env)
pi = coax.EpsilonGreedy(q, epsilon=0.1)

# target network
q_targ = q.copy()

# specify how to update q-function
qlearning = coax.td_learning.QLearning(q, q_targ=q_targ, optimizer=optax.adam(0.001))

# specify how to trace the transitions
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=1000000)


for ep in range(100):
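    # --- hedged sketch of the rollout loop; not part of the original stub ---
    # assumes the classic gym step API and a registered env (env.spec is set)
    s = env.reset()

    for t in range(env.spec.max_episode_steps):
        a = pi(s)
        s_next, r, done, info = env.step(a)

        # trace rewards and add the resulting transitions to the replay buffer
        tracer.add(s, a, r, done)
        while tracer:
            buffer.add(tracer.pop())

        # update the q-function and slowly sync the target network
        if len(buffer) >= 32:
            transition_batch = buffer.sample(batch_size=32)
            qlearning.update(transition_batch)
            q_targ.soft_update(q, tau=0.01)

        if done:
            break

        s = s_next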
def func_pi(S, is_training):
    logits = hk.Linear(env.action_space.n, w_init=jnp.zeros)
    X = shared(S, is_training)
    return {'logits': logits(X)}


def func_q(S, A, is_training):
    value = hk.Sequential((
        hk.Linear(256), jax.nn.relu,
        hk.Linear(1, w_init=jnp.zeros), jnp.ravel))
    X = shared(S, is_training)
    assert A.ndim == 2 and A.shape[1] == env.action_space.n, "actions must be one-hot encoded"
    return value(jax.vmap(jnp.kron)(X, A))


# function approximators
pi = coax.Policy(func_pi, env)
q = coax.Q(func_q, env)

# target networks
pi_targ = pi.copy()
q_targ = q.copy()

# policy regularizer (avoid premature exploitation)
kl_div = coax.regularizers.KLDivRegularizer(pi, beta=0.001)

# updaters
qlearning = coax.td_learning.QLearning(q, q_targ=q_targ, optimizer=adam(3e-4))
determ_pg = coax.policy_objectives.DeterministicPG(pi, q, regularizer=kl_div, optimizer=adam(3e-4))
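A hedged sketch of how these updaters could be used for one training step; the replay buffer, batch size, and tau value below are illustrative assumptions, not part of the snippet above:

transition_batch = buffer.sample(batch_size=32)
metrics_q = qlearning.update(transition_batch)
metrics_pi = determ_pg.update(transition_batch)

# slowly sync the target networks towards the online networks
q_targ.soft_update(q, tau=0.001)
pi_targ.soft_update(pi, tau=0.001)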
def func_pi(S, is_training):
    # custom haiku function (for continuous actions in this example)
    mu = hk.Sequential([...])(S)  # mu.shape: (batch_size, *action_space.shape)
    return {'mu': mu, 'logvar': jnp.full_like(mu, -10)}  # (effectively) deterministic policy


def func_q(S, A, is_training):
    # custom haiku function
    value = hk.Sequential([...])
    X = jnp.concatenate((S, A), axis=-1)  # combine state and action inputs
    return value(X)  # output shape: (batch_size,)


# define function approximators
pi = coax.Policy(func_pi, env)
q1 = coax.Q(func_q, env, action_preprocessor=pi.proba_dist.preprocess_variate)
q2 = coax.Q(func_q, env, action_preprocessor=pi.proba_dist.preprocess_variate)

# target networks
pi_targ = pi.copy()
q1_targ = q1.copy()
q2_targ = q2.copy()

# specify how to update the policy and q-functions
determ_pg = coax.policy_objectives.DeterministicPG(pi, q1, optimizer=optax.adam(0.001))
qlearning1 = coax.td_learning.ClippedDoubleQLearning(
    q1, q_targ_list=[q1_targ, q2_targ], optimizer=optax.adam(0.001))
qlearning2 = coax.td_learning.ClippedDoubleQLearning(
    q2, q_targ_list=[q1_targ, q2_targ], optimizer=optax.adam(0.001))
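To show how the twin critics and the delayed policy update typically interact, here is a hedged TD3-style update sketch; the replay buffer, step counter t, batch size, policy delay of 2, and tau value are illustrative assumptions, not taken from the snippet above:

transition_batch = buffer.sample(batch_size=32)

# update both critics on every step
qlearning1.update(transition_batch)
qlearning2.update(transition_batch)

# update the policy and target networks less frequently (delayed updates)
if t % 2 == 0:
    determ_pg.update(transition_batch)
    q1_targ.soft_update(q1, tau=0.005)
    q2_targ.soft_update(q2, tau=0.005)
    pi_targ.soft_update(pi, tau=0.005)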
def func_pi(S, is_training):
    logits = hk.Linear(env.action_space.n, w_init=jnp.zeros)
    return {'logits': logits(S)}


def func_q(S, A, is_training):
    value = hk.Sequential((hk.Flatten(), hk.Linear(1, w_init=jnp.zeros), jnp.ravel))
    X = jax.vmap(jnp.kron)(S, A)  # S and A are one-hot encoded
    return value(X)


# function approximators
pi = coax.Policy(func_pi, env)
q1 = coax.Q(func_q, env)
q2 = coax.Q(func_q, env)

# target networks
q1_targ = q1.copy()
q2_targ = q2.copy()
pi_targ = pi.copy()

# experience tracer
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=128)

# updaters
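The snippet breaks off after the "# updaters" comment; one plausible continuation, mirroring the twin-critic setup in the previous snippet (a sketch, not the original code):

determ_pg = coax.policy_objectives.DeterministicPG(pi, q1, optimizer=optax.adam(0.001))
qlearning1 = coax.td_learning.ClippedDoubleQLearning(
    q1, q_targ_list=[q1_targ, q2_targ], optimizer=optax.adam(0.001))
qlearning2 = coax.td_learning.ClippedDoubleQLearning(
    q2, q_targ_list=[q1_targ, q2_targ], optimizer=optax.adam(0.001))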