# Example No. 1

def func_p(S, A, is_training):
    """Deterministic transition model: predict the next observation as S + dS(A).

    The state increment ``dS(A)`` is a linear function of the action input.
    Its weights are initialised to zero (``w_init=jnp.zeros``), so an
    untrained model starts out predicting little or no change to ``S``.

    Args:
        S: Batch of (preprocessed) observations. Assumed to be rank-2,
            ``(batch, obs_dim)``; TODO confirm against the observation
            preprocessor.
        A: Batch of action inputs. Presumably one-hot for a discrete action
            space; verify against the caller.
        is_training: Unused here. Kept for the function-approximator call
            signature.

    Returns:
        Predicted next observations, with the same shape as ``S``.
    """
    # Size the output from the observation instead of hard-coding CartPole's
    # 4 features. The sum below requires dS(A) to match S's last dimension.
    dS = hk.Linear(S.shape[-1], w_init=jnp.zeros)
    return S + dS(A)


def func_r(S, A, is_training):
    """Constant reward model.

    CartPole yields r=1 at every time step, so nothing needs to be learned.
    ``A`` and ``is_training`` are ignored.

    Returns:
        A vector of ones with one entry per row of ``S``.
    """
    num_transitions = S.shape[0]
    return jnp.ones((num_transitions,))


# function approximators
# NOTE(review): `env`, `func_v` and `mse` are not defined in this snippet. They
# are assumed to come from elsewhere (a CartPole env, a haiku value function,
# and a coax value loss, respectively) — confirm.
p = coax.TransitionModel(func_p, env)
# v and r reuse p's observation preprocessor, presumably so that all three
# approximators see observations in the same representation.
v = coax.V(func_v, env, observation_preprocessor=p.observation_preprocessor)
r = coax.RewardFunction(func_r,
                        env,
                        observation_preprocessor=p.observation_preprocessor)

# composite objects
# q is assembled from the value function v, the transition model p and the
# reward model r, with discount gamma=0.9. Presumably
# q(s, a) = r(s, a) + gamma * v(p(s, a)); see coax's SuccessorStateQ docs.
q = coax.SuccessorStateQ(v, p, r, gamma=0.9)
pi = coax.EpsilonGreedy(q, epsilon=0.)  # no exploration (always greedy w.r.t. q)

# reward tracer
# One-step tracer that reuses q's discount factor so both agree on gamma.
tracer = coax.reward_tracing.NStep(n=1, gamma=q.gamma)

# updaters
# NOTE(review): optax.apply_every(k=16) passes zeros on k-1 of every k steps and
# the accumulated gradient on the k-th. adam comes after it in the chain, so its
# moment estimates are still updated on every step and it may emit non-zero
# parameter updates in between. If strict 16-step gradient accumulation is
# intended, optax.MultiSteps is the documented alternative — confirm intent.
adam = optax.chain(optax.apply_every(k=16), optax.adam(1e-4))
# Only v gets a TD updater here; no updater for p or r appears in this snippet.
simple_td = coax.td_learning.SimpleTD(v, loss_function=mse, optimizer=adam)
# Example No. 2

def func_v(S, is_training):
    """State-value function for ``coax.V``, written as a haiku forward pass.

    NOTE: ``[...]`` is a placeholder for the real list of haiku layers. It
    must be filled in before this function can run, because an Ellipsis is
    not a callable layer.

    Returns the network applied to ``S``.
    """
    return hk.Sequential([...])(S)  # output shape: (batch_size,)


def func_pi(S, is_training):
    """Policy network for ``coax.Policy`` (discrete actions in this example).

    NOTE: ``[...]`` is a placeholder for the real list of haiku layers.

    Returns:
        A dict with a single ``'logits'`` entry, of shape
        ``(batch_size, num_actions)`` per the original example.
    """
    logits_net = hk.Sequential([...])
    return dict(logits=logits_net(S))


# function approximators
# NOTE(review): `env` is not defined in this snippet; it is assumed to be
# created elsewhere — confirm.
v = coax.V(func_v, env)
pi = coax.Policy(func_pi, env)


# slow-moving avg of pi
# At this point pi_behavior is only a copy of pi. Any averaging towards pi must
# happen elsewhere (e.g. a periodic soft update, which is not shown here).
pi_behavior = pi.copy()


# specify how to update policy and value function
# Each updater gets its own Adam optimizer instance with learning rate 0.001.
ppo_clip = coax.policy_objectives.PPOClip(pi, optimizer=optax.adam(0.001))
simple_td = coax.td_learning.SimpleTD(v, optimizer=optax.adam(0.001))


# specify how to trace the transitions
# n=5-step tracing with discount gamma=0.9. Traced transitions presumably go
# into a replay buffer of capacity 256.
tracer = coax.reward_tracing.NStep(n=5, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=256)