def __init__(self, name, param_store=None, tensorboard_dir=None):
    env = make_env(name, tensorboard_dir)

    # function approximator
    self.q = coax.Q(forward_pass, env)
    self.q_targ = self.q.copy()

    # tracer and updater
    self.q_updater = coax.td_learning.QLearning(
        self.q, q_targ=self.q_targ, optimizer=optax.adam(3e-4))

    # schedule for the beta parameter used in PrioritizedReplayBuffer
    self.buffer_beta = coax.utils.StepwiseLinearFunction((0, 0.4), (1000000, 1))

    super().__init__(
        env=env,
        param_store=param_store,
        pi=coax.BoltzmannPolicy(self.q, temperature=0.015),
        tracer=coax.reward_tracing.NStep(n=1, gamma=0.99),
        buffer=(
            coax.experience_replay.PrioritizedReplayBuffer(capacity=1000000, alpha=0.6)
            if param_store is None else None),
        buffer_warmup=50000,
        name=name)
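Note that `self.buffer_beta` only defines the annealing curve; it still has to be applied to the buffer at update time. A minimal sketch of how that hook might look, assuming the base class stores the buffer and monitored env as `self.buffer` and `self.env` (the method name `learn` is illustrative, not part of the original):

def learn(self, transition_batch):
    # anneal PER's importance-sampling exponent: 0.4 -> 1 over the first 1M steps
    self.buffer.beta = self.buffer_beta(self.env.T)
    metrics = self.q_updater.update(transition_batch)
    self.env.record_metrics(metrics)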
import coax
import haiku as hk
import jax
import jax.numpy as jnp
from optax import adam


def func_type1(S, A, is_training):
    # custom haiku function: s,a -> q(s,a)
    value = hk.Sequential([...])
    X = jax.vmap(jnp.kron)(S, A)  # or jnp.concatenate((S, A), axis=-1) or whatever you like
    return value(X)  # output shape: (batch_size,)


def func_type2(S, is_training):
    # custom haiku function: s -> q(s,.)
    value = hk.Sequential([...])
    return value(S)  # output shape: (batch_size, num_actions)


# function approximator
func = ...  # func_type1 or func_type2
q = coax.Q(func, env)
pi = coax.BoltzmannPolicy(q, temperature=0.1)

# specify how to update the q-function
qlearning = coax.td_learning.SoftQLearning(
    q, optimizer=adam(0.001), temperature=pi.temperature)

# specify how to trace the transitions
cache = coax.reward_tracing.NStep(n=1, gamma=0.9)

for ep in range(100):
    pi.temperature = ...  # exploration schedule
    s = env.reset()

    for t in range(env.spec.max_episode_steps):
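        # The episode loop body is missing from the excerpt; what follows is a
        # minimal sketch of the usual coax interaction/update pattern, not the
        # original code.
        a = pi(s)
        s_next, r, done, info = env.step(a)

        # trace rewards to create training data
        cache.add(s, a, r, done)

        # update the q-function as soon as transitions become available
        while cache:
            transition_batch = cache.pop()
            qlearning.update(transition_batch)

        if done:
            break

        s = s_next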
def func(S, A, is_training):
    # type-1 q-function: (s,a) -> q(s,a)
    body = hk.Sequential((
        hk.Conv2D(16, kernel_shape=8, stride=4), jax.nn.relu,
        hk.Conv2D(32, kernel_shape=4, stride=2), jax.nn.relu,
        hk.Flatten(),
    ))
    head = hk.Sequential((
        hk.Linear(256), jax.nn.relu,
        hk.Linear(1, w_init=jnp.zeros), jnp.ravel,
    ))
    X = jnp.stack(S, axis=-1) / 255.  # stack frames
    return head(jax.vmap(jnp.kron)(body(X), A))


# function approximator
q = coax.Q(func, env)
pi = coax.BoltzmannPolicy(q, temperature=0.015)  # <--- different from standard DQN

# target network
q_targ = q.copy()

# updater
qlearning = coax.td_learning.QLearning(q, q_targ=q_targ, optimizer=adam(3e-4))

# reward tracer and replay buffer
tracer = coax.reward_tracing.NStep(n=1, gamma=0.99)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=1000000)

while env.T < 3000000:
    s = env.reset()

    for t in range(env.spec.max_episode_steps):
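        # Loop body missing from the excerpt; a minimal sketch of the standard
        # coax DQN loop. The batch size, warmup threshold, and soft target-net
        # sync are assumptions, not the original hyperparameters.
        a = pi(s)
        s_next, r, done, info = env.step(a)

        # trace rewards and add transition to replay buffer
        tracer.add(s, a, r, done)
        while tracer:
            buffer.add(tracer.pop())

        # learn and keep the target network in sync
        if len(buffer) >= 50000:
            metrics = qlearning.update(buffer.sample(batch_size=32))
            env.record_metrics(metrics)
            q_targ.soft_update(q, tau=0.001)

        if done:
            break

        s = s_next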
def func(S, A, is_training):
    # type-1 quantile q-function: (s,a) -> q(s,a)
    encoder = hk.Sequential([...])
    quantile_fractions = coax.utils.quantiles_uniform(
        rng=hk.next_rng_key(),
        batch_size=jax.tree_leaves(S)[0].shape[0],
        num_quantiles=num_quantiles)
    X = jax.vmap(jnp.kron)(S, A)
    x = encoder(X)
    # quantile_net: quantile-embedding helper defined elsewhere in the example
    quantile_x = quantile_net(x, quantile_fractions=quantile_fractions)
    quantile_values = hk.Linear(1, w_init=jnp.zeros)(quantile_x)
    return {
        'values': quantile_values.squeeze(axis=-1),
        'quantile_fractions': quantile_fractions,
    }


# quantile value function and its derived policy
q = coax.StochasticQ(func, env, num_bins=num_quantiles, value_range=None)
pi = coax.BoltzmannPolicy(q)

# target network
q_targ = q.copy()

# experience tracer and replay buffer
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=100000)

# updater
qlearning = coax.td_learning.QLearning(q, q_targ=q_targ, optimizer=adam(0.001))

# train
for ep in range(1000):
    s = env.reset()
    # pi.epsilon = max(0.01, pi.epsilon * 0.95)
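    # The rest of the episode loop is missing from the excerpt; a minimal
    # sketch of the usual coax pattern. The batch size, warmup threshold, and
    # soft target-net sync are assumptions.
    for t in range(env.spec.max_episode_steps):
        a = pi(s)
        s_next, r, done, info = env.step(a)

        # trace rewards and add transition to replay buffer
        tracer.add(s, a, r, done)
        while tracer:
            buffer.add(tracer.pop())

        # learn and keep the target network in sync
        if len(buffer) >= 1000:
            qlearning.update(buffer.sample(batch_size=32))
            q_targ.soft_update(q, tau=0.01)

        if done:
            break

        s = s_next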