    value = hk.Sequential([...])
    X = jax.vmap(jnp.kron)(S, A)  # or jnp.concatenate((S, A), axis=-1) or whatever you like
    return value(X)  # output shape: (batch_size,)

def func_type2(S, is_training):
    # custom haiku function: s -> q(s,.)
    value = hk.Sequential([...])
    return value(S)  # output shape: (batch_size, num_actions)

# function approximator
func = ...  # func_type1 or func_type2
q = coax.Q(func, env)
pi = coax.EpsilonGreedy(q, epsilon=0.1)

# target network
q_targ = q.copy()

# specify how to update q-function
qlearning = coax.td_learning.QLearning(q, q_targ=q_targ, optimizer=optax.adam(0.001))

# specify how to trace the transitions
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=1000000)

for ep in range(100):
    pi.epsilon = ...  # exploration schedule
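    # What follows is a hedged sketch (not part of the original snippet) of how the rest
    # of the episode loop typically looks with coax: act, trace, replay, update. It assumes
    # `env` is a regular gym environment with the old 4-tuple step API, and reuses the
    # objects defined above (pi, tracer, buffer, qlearning, q, q_targ). The batch size (32)
    # and soft-update rate (tau=0.01) are illustrative choices, not prescribed values.
    s = env.reset()
    done = False
    while not done:
        a = pi(s)                             # epsilon-greedy action
        s_next, r, done, info = env.step(a)

        # trace the 1-step transition and move it into the replay buffer
        tracer.add(s, a, r, done)
        while tracer:
            buffer.add(tracer.pop())

        # learn from a random batch once the buffer has enough transitions
        if len(buffer) >= 32:
            transition_batch = buffer.sample(batch_size=32)
            metrics = qlearning.update(transition_batch)

        # let the target network slowly track the online network
        q_targ.soft_update(q, tau=0.01)

        s = s_next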
        hk.Conv2D(16, kernel_shape=8, stride=4), jax.nn.relu,
        hk.Conv2D(32, kernel_shape=4, stride=2), jax.nn.relu,
        hk.Flatten(),
        hk.Linear(256), jax.nn.relu,
        hk.Linear(env.action_space.n, w_init=jnp.zeros),
    ))
    X = jnp.stack(S, axis=-1) / 255.  # stack frames
    return seq(X)

# function approximator
q = coax.Q(func, env)
pi = coax.EpsilonGreedy(q, epsilon=1.)

# target network
q_targ = q.copy()

# updater
qlearning = coax.td_learning.QLearning(q, q_targ=q_targ, optimizer=optax.adam(3e-4))

# reward tracer and replay buffer
tracer = coax.reward_tracing.NStep(n=1, gamma=0.99)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=1000000)

# DQN exploration schedule (stepwise linear annealing)
epsilon = coax.utils.StepwiseLinearFunction((0, 1), (1000000, 0.1), (2000000, 0.01))
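# A hedged sketch (not part of the original snippet) of a DQN-style training loop that
# could follow. It assumes `env` is wrapped in coax.wrappers.TrainMonitor (which provides
# env.T) and uses the old gym 4-tuple step API; the warm-up size, batch size and
# target-sync period are illustrative choices.
while env.T < 3000000:
    s = env.reset()
    pi.epsilon = epsilon(env.T)  # anneal exploration according to the schedule above

    done = False
    while not done:
        a = pi(s)
        s_next, r, done, info = env.step(a)

        # trace and store the transition
        tracer.add(s, a, r, done)
        while tracer:
            buffer.add(tracer.pop())

        # start learning once the replay buffer is warmed up
        if len(buffer) >= 50000:
            transition_batch = buffer.sample(batch_size=32)
            metrics = qlearning.update(transition_batch)

        # periodically copy the online weights into the target network
        if env.T % 10000 == 0:
            q_targ.soft_update(q, tau=1.0)

        s = s_next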
def func_r(S, A, is_training):
    return jnp.ones(S.shape[0])  # CartPole yields r=1 at every time step (no need to learn)

# function approximators
p = coax.TransitionModel(func_p, env)
v = coax.V(func_v, env, observation_preprocessor=p.observation_preprocessor)
r = coax.RewardFunction(func_r, env, observation_preprocessor=p.observation_preprocessor)

# composite objects
q = coax.SuccessorStateQ(v, p, r, gamma=0.9)
pi = coax.EpsilonGreedy(q, epsilon=0.)  # no exploration

# reward tracer
tracer = coax.reward_tracing.NStep(n=1, gamma=q.gamma)

# updaters
adam = optax.chain(optax.apply_every(k=16), optax.adam(1e-4))
simple_td = coax.td_learning.SimpleTD(v, loss_function=mse, optimizer=adam)

sgd = optax.sgd(1e-3, momentum=0.9, nesterov=True)
model_updater = coax.model_updaters.ModelUpdater(p, optimizer=sgd)

while env.T < 100000:
    s = env.reset()
    env.render()
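    # A hedged sketch (not part of the original snippet) of the remainder of the episode:
    # roll out the greedy policy and update both the value function and the transition
    # model from each traced transition. It assumes the old gym 4-tuple step API and
    # reuses the objects defined above (pi, tracer, simple_td, model_updater).
    done = False
    while not done:
        a = pi(s)
        s_next, r, done, info = env.step(a)
        env.render()

        # n-step tracing (n=1 here), then per-transition updates of value function and model
        tracer.add(s, a, r, done)
        while tracer:
            transition_batch = tracer.pop()
            simple_td.update(transition_batch)
            model_updater.update(transition_batch)

        s = s_next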
    value = hk.Sequential([...])
    X = jax.vmap(jnp.kron)(S, A)  # or jnp.concatenate((S, A), axis=-1) or whatever you like
    return value(X)  # output shape: (batch_size,)

def func_type2(S, is_training):
    # custom haiku function: s -> q(s,.)
    value = hk.Sequential([...])
    return value(S)  # output shape: (batch_size, num_actions)

# function approximator
func = ...  # func_type1 or func_type2
q = coax.Q(func, env)
pi = coax.EpsilonGreedy(q, epsilon=1.0)  # epsilon will be updated

# target network
q_targ = q.copy()

# specify how to update q-function
qlearning = coax.td_learning.QLearning(q, q_targ=q_targ, optimizer=optax.adam(0.001))

# specify how to trace the transitions
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)
buffer = coax.experience_replay.PrioritizedReplayBuffer(capacity=1000000, alpha=0.6, beta=0.4)
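# A hedged sketch (not part of the original snippet) of how prioritized replay is
# typically wired up in coax-style code: new transitions are added with their TD error as
# initial priority, and priorities are refreshed after each update. It assumes that
# qlearning.td_error(...), the return_td_error flag of update(...), and
# buffer.update(idx, ...) behave as in coax's prioritized-replay examples; the batch size
# and warm-up threshold are illustrative.
for ep in range(100):
    pi.epsilon = ...  # exploration schedule
    s = env.reset()

    done = False
    while not done:
        a = pi(s)
        s_next, r, done, info = env.step(a)

        # store the transition, using its current TD error as the initial priority
        tracer.add(s, a, r, done)
        while tracer:
            transition = tracer.pop()
            buffer.add(transition, qlearning.td_error(transition))

        # learn from a prioritized batch and refresh its priorities
        if len(buffer) >= 32:
            transition_batch = buffer.sample(batch_size=32)
            metrics, td_error = qlearning.update(transition_batch, return_td_error=True)
            buffer.update(transition_batch.idx, td_error)

        s = s_next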