Example 1
    def __init__(self, name, param_store=None, tensorboard_dir=None):
        env = make_env(name, tensorboard_dir)

        # function approximator
        self.q = coax.Q(forward_pass, env)
        self.q_targ = self.q.copy()

        # updater (the tracer and buffer are passed to the base class below)
        self.q_updater = coax.td_learning.QLearning(self.q,
                                                    q_targ=self.q_targ,
                                                    optimizer=optax.adam(3e-4))

        # schedule for beta parameter used in PrioritizedReplayBuffer
        self.buffer_beta = coax.utils.StepwiseLinearFunction((0, 0.4),
                                                             (1000000, 1))

        super().__init__(
            env=env,
            param_store=param_store,
            pi=coax.BoltzmannPolicy(self.q, temperature=0.015),
            tracer=coax.reward_tracing.NStep(n=1, gamma=0.99),
            buffer=(coax.experience_replay.PrioritizedReplayBuffer(
                capacity=1000000, alpha=0.6) if param_store is None else None),
            buffer_warmup=50000,
            name=name)
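One detail worth spelling out: the StepwiseLinearFunction above is only a schedule object; it still has to be evaluated and written back to the buffer during training. coax's PrioritizedReplayBuffer exposes its importance-sampling exponent as a settable beta attribute, so the hook-up could look roughly like the sketch below. The current_step counter, and the assumption that the base class stores the buffer as self.buffer, are hypotheses for illustration, not part of the snippet.

        # sketch: anneal the importance-sampling exponent from 0.4 to 1.0 over 1M steps
        # (current_step is a hypothetical counter, e.g. total env steps so far)
        if self.buffer is not None:
            self.buffer.beta = self.buffer_beta(current_step)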
Example 2
def func_type1(S, A, is_training):
    # custom haiku function: s,a -> q(s,a)
    value = hk.Sequential([...])
    X = jax.vmap(jnp.kron)(S, A)  # or jnp.concatenate((S, A), axis=-1) or whatever you like
    return value(X)  # output shape: (batch_size,)


def func_type2(S, is_training):
    # custom haiku function: s -> q(s,.)
    value = hk.Sequential([...])
    return value(S)  # output shape: (batch_size, num_actions)


# function approximator
func = ...  # func_type1 or func_type2
q = coax.Q(func, env)
pi = coax.EpsilonGreedy(q, epsilon=0.1)

# target network
q_targ = q.copy()

# specify how to update q-function
qlearning = coax.td_learning.QLearning(q,
                                       q_targ=q_targ,
                                       optimizer=optax.adam(0.001))

# specify how to trace the transitions
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=1000000)

for ep in range(100):
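    # NOTE: the snippet is cut off here; the loop body below is a sketch of the
    # standard coax interaction/update pattern, assuming a classic Gym env whose
    # step() returns (s_next, r, done, info), a batch size of 32, and a soft
    # target-network sync with tau=0.001 (none of these come from the snippet).
    s = env.reset()

    while True:
        a = pi(s)
        s_next, r, done, info = env.step(a)

        # trace the n-step transition and store it in the replay buffer
        tracer.add(s, a, r, done)
        while tracer:
            buffer.add(tracer.pop())

        # update the q-function once enough transitions have been collected
        if len(buffer) >= 32:
            transition_batch = buffer.sample(batch_size=32)
            qlearning.update(transition_batch)

        # let the target network slowly track the online network
        q_targ.soft_update(q, tau=0.001)

        if done:
            break

        s = s_next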
Example 3
def func_pi(S, is_training):
    logits = hk.Linear(env.action_space.n, w_init=jnp.zeros)
    X = shared(S, is_training)
    return {'logits': logits(X)}


def func_q(S, A, is_training):
    value = hk.Sequential((
        hk.Linear(256), jax.nn.relu, hk.Linear(1, w_init=jnp.zeros), jnp.ravel))
    X = shared(S, is_training)
    assert A.ndim == 2 and A.shape[1] == env.action_space.n, "actions must be one-hot encoded"
    return value(jax.vmap(jnp.kron)(X, A))


# function approximators
pi = coax.Policy(func_pi, env)
q = coax.Q(func_q, env)

# target networks
pi_targ = pi.copy()
q_targ = q.copy()

# policy regularizer (avoid premature exploitation)
kl_div = coax.regularizers.KLDivRegularizer(pi, beta=0.001)

# updaters
qlearning = coax.td_learning.QLearning(q, q_targ=q_targ, optimizer=adam(3e-4))
determ_pg = coax.policy_objectives.DeterministicPG(
    pi, q, regularizer=kl_div, optimizer=adam(3e-4))
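The example ends with the updaters. As a rough sketch of how they are typically driven once transitions are available, the lines below show one DDPG-style learning step; the replay buffer, the batch size of 32, and the soft-update rate tau=0.001 are assumptions, not part of the original example.

# sketch: one DDPG-style learning step; the buffer is assumed to have been
# filled via a reward tracer, as in Example 2
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=100000)  # assumed capacity
...  # collect transitions into the buffer here

transition_batch = buffer.sample(batch_size=32)
qlearning.update(transition_batch)   # TD step on the q-function
determ_pg.update(transition_batch)   # deterministic policy-gradient step

# let the target networks slowly track the online networks
q_targ.soft_update(q, tau=0.001)
pi_targ.soft_update(pi, tau=0.001)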
Example 4
def func_pi(S, is_training):
    # custom haiku function (for continuous actions in this example)
    mu = hk.Sequential([...])(S)  # mu.shape: (batch_size, *action_space.shape)
    return {'mu': mu, 'logvar': jnp.full_like(mu, -10)}  # (almost) deterministic policy


def func_q(S, A, is_training):
    # custom haiku function
    value = hk.Sequential([...])
    X = jnp.concatenate((S, A), axis=-1)  # combine state and action inputs
    return value(X)  # output shape: (batch_size,)


# define function approximator
pi = coax.Policy(func_pi, env)
q1 = coax.Q(func_q, env, action_preprocessor=pi.proba_dist.preprocess_variate)
q2 = coax.Q(func_q, env, action_preprocessor=pi.proba_dist.preprocess_variate)


# target networks
pi_targ = pi.copy()
q1_targ = q1.copy()
q2_targ = q2.copy()


# specify how to update policy and value function
determ_pg = coax.policy_objectives.DeterministicPG(pi, q1, optimizer=optax.adam(0.001))
qlearning1 = coax.td_learning.ClippedDoubleQLearning(
    q1, q_targ_list=[q1_targ, q2_targ], optimizer=optax.adam(0.001))
qlearning2 = coax.td_learning.ClippedDoubleQLearning(
    q2, q_targ_list=[q1_targ, q2_targ], optimizer=optax.adam(0.001))
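The snippet stops after the updaters are constructed. Below is a rough sketch of the TD3-style update schedule they usually feed into; the replay buffer, the policy delay of 2, the batch size of 32, and tau=0.005 are assumptions rather than part of the snippet.

# sketch: TD3-style learning iterations (buffer contents and all hyperparameters
# below are assumed, not taken from the snippet)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=100000)
...  # collect transitions into the buffer here

for iteration in range(1000):
    transition_batch = buffer.sample(batch_size=32)

    # update both q-functions every iteration
    qlearning1.update(transition_batch)
    qlearning2.update(transition_batch)

    # delayed policy update: only every other iteration
    if iteration % 2 == 0:
        determ_pg.update(transition_batch)

    # soft-sync the target networks
    q1_targ.soft_update(q1, tau=0.005)
    q2_targ.soft_update(q2, tau=0.005)
    pi_targ.soft_update(pi, tau=0.005)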
Example 5

def func_pi(S, is_training):
    logits = hk.Linear(env.action_space.n, w_init=jnp.zeros)
    return {'logits': logits(S)}


def func_q(S, A, is_training):
    value = hk.Sequential((hk.Flatten(), hk.Linear(1, w_init=jnp.zeros), jnp.ravel))
    X = jax.vmap(jnp.kron)(S, A)  # S and A are one-hot encoded
    return value(X)


# function approximators
pi = coax.Policy(func_pi, env)
q1 = coax.Q(func_q, env)
q2 = coax.Q(func_q, env)


# target networks
q1_targ = q1.copy()
q2_targ = q2.copy()
pi_targ = pi.copy()


# experience tracer
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=128)


# updaters
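# NOTE: the original snippet is cut off after this comment. A plausible
# continuation, mirroring the TD3-style setup of Example 4 (the choice of
# updaters and the learning rates here are assumptions, not from the source):
determ_pg = coax.policy_objectives.DeterministicPG(pi, q1, optimizer=optax.adam(0.01))
qlearning1 = coax.td_learning.ClippedDoubleQLearning(
    q1, q_targ_list=[q1_targ, q2_targ], optimizer=optax.adam(0.01))
qlearning2 = coax.td_learning.ClippedDoubleQLearning(
    q2, q_targ_list=[q1_targ, q2_targ], optimizer=optax.adam(0.01))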