def func_p(S, A, is_training):
    dS = hk.Linear(4, w_init=jnp.zeros)
    return S + dS(A)


def func_r(S, A, is_training):
    return jnp.ones(S.shape[0])  # CartPole yields r=1 at every time step (no need to learn)


# function approximators (func_v is a state-value forward pass, defined as in the other examples)
p = coax.TransitionModel(func_p, env)
v = coax.V(func_v, env, observation_preprocessor=p.observation_preprocessor)
r = coax.RewardFunction(func_r, env, observation_preprocessor=p.observation_preprocessor)

# composite objects
q = coax.SuccessorStateQ(v, p, r, gamma=0.9)
pi = coax.EpsilonGreedy(q, epsilon=0.)  # no exploration

# reward tracer
tracer = coax.reward_tracing.NStep(n=1, gamma=q.gamma)

# updaters
adam = optax.chain(optax.apply_every(k=16), optax.adam(1e-4))
simple_td = coax.td_learning.SimpleTD(v, loss_function=coax.value_losses.mse, optimizer=adam)
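# The block above only constructs the objects. Below is a minimal training-loop sketch
# showing how they fit together; it assumes a classic Gym-style `env` (reset returns the
# observation, step returns a 4-tuple), uses illustrative episode/step counts, and omits
# the transition-model update for brevity.
for ep in range(250):
    s = env.reset()
    for t in range(200):
        a = pi(s)  # epsilon=0. => greedy w.r.t. the successor-state q-function
        s_next, r, done, info = env.step(a)

        # trace the transition and update the value function
        tracer.add(s, a, r, done)
        while tracer:
            transition_batch = tracer.pop()
            metrics = simple_td.update(transition_batch)

        if done:
            break

        s = s_next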
def func_v(S, is_training):
    # custom haiku function
    value = hk.Sequential([...])
    return value(S)  # output shape: (batch_size,)


def func_pi(S, is_training):
    # custom haiku function (for discrete actions in this example)
    logits = hk.Sequential([...])
    return {'logits': logits(S)}  # logits shape: (batch_size, num_actions)


# function approximators
v = coax.V(func_v, env)
pi = coax.Policy(func_pi, env)

# slow-moving avg of pi
pi_behavior = pi.copy()

# specify how to update policy and value function
ppo_clip = coax.policy_objectives.PPOClip(pi, optimizer=optax.adam(0.001))
simple_td = coax.td_learning.SimpleTD(v, optimizer=optax.adam(0.001))

# specify how to trace the transitions
tracer = coax.reward_tracing.NStep(n=5, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=256)
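# Minimal PPO training-loop sketch for the objects defined above. It assumes a classic
# Gym-style `env` with `env.spec.max_episode_steps`; the batch size, number of update
# passes, and soft-update rate are illustrative choices, not prescribed by the snippet.
for ep in range(500):
    s = env.reset()
    for t in range(env.spec.max_episode_steps):
        a, logp = pi_behavior(s, return_logp=True)  # act with the behavior policy
        s_next, r, done, info = env.step(a)

        # trace the n-step transition and store it in the replay buffer
        tracer.add(s, a, r, done, logp)
        while tracer:
            buffer.add(tracer.pop())

        # once the buffer is full, run a few passes of updates, then start fresh
        if len(buffer) >= buffer.capacity:
            for _ in range(4 * buffer.capacity // 32):
                transition_batch = buffer.sample(batch_size=32)
                metrics_v, td_error = simple_td.update(transition_batch, return_td_error=True)
                metrics_pi = ppo_clip.update(transition_batch, td_error)

            buffer.clear()
            pi_behavior.soft_update(pi, tau=0.1)  # move the behavior policy towards pi

        if done:
            break

        s = s_next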