Example 1
import coax
import haiku as hk
import jax
import jax.numpy as jnp
import optax


def func_type1(S, A, is_training):
    # custom haiku function: (s,a) -> q(s,a)
    value = hk.Sequential([...])
    X = jax.vmap(jnp.kron)(
        S, A)  # or jnp.concatenate((S, A), axis=-1) or whatever you like
    return value(X)  # output shape: (batch_size,)


def func_type2(S, is_training):
    # custom haiku function: s -> q(s,.)
    value = hk.Sequential([...])
    return value(S)  # output shape: (batch_size, num_actions)


# function approximator
func = ...  # func_type1 or func_type2
q = coax.Q(func, env)
pi = coax.EpsilonGreedy(q, epsilon=0.1)

# target network
q_targ = q.copy()

# specify how to update q-function
qlearning = coax.td_learning.QLearning(q,
                                       q_targ=q_targ,
                                       optimizer=optax.adam(0.001))

# specify how to trace the transitions
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=1000000)

for ep in range(100):
    pi.epsilon = ...  # exploration schedule
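    # A minimal sketch of the rest of the loop (not part of the original
    # snippet): the usual coax act / trace / replay / update cycle. The batch
    # size, warm-up threshold and target-sync period are illustrative, and
    # `env` is assumed to be wrapped in coax.wrappers.TrainMonitor (for env.T).
    s = env.reset()

    for t in range(env.spec.max_episode_steps):
        a = pi(s)
        s_next, r, done, info = env.step(a)

        # trace the n-step transition and store it in the replay buffer
        tracer.add(s, a, r, done)
        while tracer:
            buffer.add(tracer.pop())

        # learn from a random batch once the buffer has warmed up
        if len(buffer) >= 100:
            transition_batch = buffer.sample(batch_size=32)
            qlearning.update(transition_batch)

        # periodically sync the target network
        if env.T % 10 == 0:
            q_targ.soft_update(q, tau=0.1)

        if done:
            break

        s = s_next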
Example 2
def func(S, is_training):
    # type-2 q-function: s -> q(s,.)
    seq = hk.Sequential((
        hk.Conv2D(16, kernel_shape=8, stride=4),
        jax.nn.relu,
        hk.Conv2D(32, kernel_shape=4, stride=2),
        jax.nn.relu,
        hk.Flatten(),
        hk.Linear(256),
        jax.nn.relu,
        hk.Linear(env.action_space.n, w_init=jnp.zeros),
    ))
    X = jnp.stack(S, axis=-1) / 255.  # stack frames
    return seq(X)


# function approximator
q = coax.Q(func, env)
pi = coax.EpsilonGreedy(q, epsilon=1.)

# target network
q_targ = q.copy()

# updater
qlearning = coax.td_learning.QLearning(q, q_targ=q_targ, optimizer=optax.adam(3e-4))

# reward tracer and replay buffer
tracer = coax.reward_tracing.NStep(n=1, gamma=0.99)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=1000000)

# DQN exploration schedule (stepwise linear annealing)
epsilon = coax.utils.StepwiseLinearFunction((0, 1), (1000000, 0.1),
                                            (2000000, 0.01))
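
# A minimal sketch of the training loop that would follow (not part of the
# original snippet). It assumes `env` is wrapped in coax.wrappers.TrainMonitor
# so that env.T counts total environment steps; the step budget, warm-up
# threshold, batch size and target-sync period are illustrative.
while env.T < 3000000:
    s = env.reset()

    for t in range(env.spec.max_episode_steps):
        pi.epsilon = epsilon(env.T)  # anneal exploration with the schedule above
        a = pi(s)
        s_next, r, done, info = env.step(a)

        # trace the transition and add it to the replay buffer
        tracer.add(s, a, r, done)
        while tracer:
            buffer.add(tracer.pop())

        # learn from a random batch once the buffer has warmed up
        if len(buffer) >= 50000:
            transition_batch = buffer.sample(batch_size=32)
            qlearning.update(transition_batch)

        # periodically sync the target network
        if env.T % 10000 == 0:
            q_targ.soft_update(q, tau=1.0)

        if done:
            break

        s = s_next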
Example 3
def func_r(S, A, is_training):
    # CartPole yields r=1 at every time step (no need to learn)
    return jnp.ones(S.shape[0])


# function approximators
p = coax.TransitionModel(func_p, env)
v = coax.V(func_v, env, observation_preprocessor=p.observation_preprocessor)
r = coax.RewardFunction(func_r,
                        env,
                        observation_preprocessor=p.observation_preprocessor)

# composite objects
q = coax.SuccessorStateQ(v, p, r, gamma=0.9)
pi = coax.EpsilonGreedy(q, epsilon=0.)  # no exploration

# reward tracer
tracer = coax.reward_tracing.NStep(n=1, gamma=q.gamma)

# updaters
adam = optax.chain(optax.apply_every(k=16), optax.adam(1e-4))
simple_td = coax.td_learning.SimpleTD(v,
                                      loss_function=coax.value_losses.mse,
                                      optimizer=adam)

sgd = optax.sgd(1e-3, momentum=0.9, nesterov=True)
model_updater = coax.model_updaters.ModelUpdater(p, optimizer=sgd)

while env.T < 100000:
    s = env.reset()
    env.render()
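    # A minimal sketch of the rest of the loop (not part of the original
    # snippet): act greedily w.r.t. the successor-state q-function, trace the
    # transition, then update the value function and the transition model.
    for t in range(env.spec.max_episode_steps):
        a = pi(s)
        s_next, r, done, info = env.step(a)
        env.render()

        # trace the transition and update on it directly (no replay buffer here)
        tracer.add(s, a, r, done)
        while tracer:
            transition_batch = tracer.pop()
            simple_td.update(transition_batch)
            model_updater.update(transition_batch)

        if done:
            break

        s = s_next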
Example 4
def func_type1(S, A, is_training):
    # custom haiku function: (s,a) -> q(s,a)
    value = hk.Sequential([...])
    X = jax.vmap(jnp.kron)(
        S, A)  # or jnp.concatenate((S, A), axis=-1) or whatever you like
    return value(X)  # output shape: (batch_size,)


def func_type2(S, is_training):
    # custom haiku function: s -> q(s,.)
    value = hk.Sequential([...])
    return value(S)  # output shape: (batch_size, num_actions)


# function approximator
func = ...  # func_type1 or func_type2
q = coax.Q(func, env)
pi = coax.EpsilonGreedy(q, epsilon=1.0)  # epsilon will be updated

# target network
q_targ = q.copy()

# specify how to update q-function
qlearning = coax.td_learning.QLearning(q,
                                       q_targ=q_targ,
                                       optimizer=optax.adam(0.001))

# specify how to trace the transitions
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)
buffer = coax.experience_replay.PrioritizedReplayBuffer(capacity=1000000,
                                                        alpha=0.6,
                                                        beta=0.4)
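
# A minimal sketch of the training loop that would follow (not part of the
# original snippet). With a prioritized buffer the TD error serves as the
# priority: transitions are added with qlearning.td_error(...) and the sampled
# batch's priorities are refreshed after each update. Episode count, warm-up
# threshold and batch size are illustrative.
for ep in range(1000):
    pi.epsilon = ...  # exploration schedule
    s = env.reset()

    for t in range(env.spec.max_episode_steps):
        a = pi(s)
        s_next, r, done, info = env.step(a)

        # trace the transition and add it with its TD error as priority
        tracer.add(s, a, r, done)
        while tracer:
            transition = tracer.pop()
            buffer.add(transition, qlearning.td_error(transition))

        # learn from a prioritized batch, then refresh its priorities
        if len(buffer) >= 100:
            transition_batch = buffer.sample(batch_size=32)
            metrics, td_error = qlearning.update(transition_batch,
                                                 return_td_error=True)
            buffer.update(transition_batch.idx, td_error)

        if done:
            break

        s = s_next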