Example #1
class ApexWorker(coax.Worker):  # assumed enclosing class: a distributed (Ape-X style) worker
    def __init__(self, name, param_store=None, tensorboard_dir=None):
        env = make_env(name, tensorboard_dir)

        # function approximator
        self.q = coax.Q(forward_pass, env)
        self.q_targ = self.q.copy()

        # updater
        self.q_updater = coax.td_learning.QLearning(self.q,
                                                    q_targ=self.q_targ,
                                                    optimizer=optax.adam(3e-4))

        # schedule for beta parameter used in PrioritizedReplayBuffer
        self.buffer_beta = coax.utils.StepwiseLinearFunction((0, 0.4),
                                                             (1000000, 1))

        super().__init__(
            env=env,
            param_store=param_store,
            pi=coax.BoltzmannPolicy(self.q, temperature=0.015),
            tracer=coax.reward_tracing.NStep(n=1, gamma=0.99),
            buffer=(coax.experience_replay.PrioritizedReplayBuffer(
                capacity=1000000, alpha=0.6) if param_store is None else None),
            buffer_warmup=50000,
            name=name)
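
A rough, hypothetical sketch of how these pieces could fit together in a single update step (not part of the example above): the buffer_beta schedule anneals the importance-sampling exponent of the prioritized buffer, and the TD errors returned by the updater are fed back as fresh priorities. The method name learn_step, the batch size, the return_td_error flag, and the assumption that env and buffer are stored as attributes by the base class are all illustrative.

    def learn_step(self, batch_size=32):
        # anneal the importance-sampling exponent beta with the global step count
        self.buffer.beta = self.buffer_beta(self.env.T)

        # sample a prioritized batch, apply the Q-learning update, refresh priorities
        transition_batch = self.buffer.sample(batch_size=batch_size)
        metrics, td_error = self.q_updater.update(transition_batch, return_td_error=True)
        self.buffer.update(transition_batch.idx, td_error)
        return metrics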
Example #2
def func_type1(S, A, is_training):
    # custom haiku function: s,a -> q(s,a)
    value = hk.Sequential([...])
    X = jax.vmap(jnp.kron)(S, A)  # or jnp.concatenate((S, A), axis=-1) or whatever you like
    return value(X)  # output shape: (batch_size,)


def func_type2(S, is_training):
    # custom haiku function: s -> q(s,.)
    value = hk.Sequential([...])
    return value(S)  # output shape: (batch_size, num_actions)


# function approximator
func = ...  # func_type1 or func_type2
q = coax.Q(func, env)
pi = coax.BoltzmannPolicy(q, temperature=0.1)


# specify how to update q-function
qlearning = coax.td_learning.SoftQLearning(q, optimizer=adam(0.001), temperature=pi.temperature)


# specify how to trace the transitions
cache = coax.reward_tracing.NStep(n=1, gamma=0.9)


for ep in range(100):
    pi.temperature = ...  # exploration schedule (BoltzmannPolicy is parametrized by a temperature)
    s = env.reset()

    for t in range(env.spec.max_episode_steps):
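        # NOTE: the rest of this loop body is an illustrative sketch; the
        # update-on-every-transition pattern below is an assumption based on the
        # objects defined above (this stub has no replay buffer)
        a = pi(s)
        s_next, r, done, info = env.step(a)

        # trace rewards for n-step bootstrapping and update on the traced transitions
        cache.add(s, a, r, done)
        while cache:
            transition_batch = cache.pop()
            metrics = qlearning.update(transition_batch)

        if done:
            break

        s = s_next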
Example #3
def func(S, A, is_training):
    # type-1 q-function: (s,a) -> q(s,a), with a convolutional torso for stacked frames
    body = hk.Sequential((
        hk.Conv2D(16, kernel_shape=8, stride=4),
        jax.nn.relu,
        hk.Conv2D(32, kernel_shape=4, stride=2),
        jax.nn.relu,
        hk.Flatten(),
    ))
    head = hk.Sequential((
        hk.Linear(256), jax.nn.relu,
        hk.Linear(1, w_init=jnp.zeros), jnp.ravel,
    ))
    X = jnp.stack(S, axis=-1) / 255.  # stack frames
    return head(jax.vmap(jnp.kron)(body(X), A))


# function approximator
q = coax.Q(func, env)
pi = coax.BoltzmannPolicy(q, temperature=0.015)  # <--- different from standard DQN

# target network
q_targ = q.copy()

# updater
qlearning = coax.td_learning.QLearning(q, q_targ=q_targ, optimizer=adam(3e-4))

# reward tracer and replay buffer
tracer = coax.reward_tracing.NStep(n=1, gamma=0.99)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=1000000)

while env.T < 3000000:
    s = env.reset()

    for t in range(env.spec.max_episode_steps):
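        # NOTE: the rest of this loop body is an illustrative sketch; the batch size,
        # buffer warm-up threshold and target-sync period below are assumptions
        a = pi(s)
        s_next, r, done, info = env.step(a)

        # trace rewards and add the traced transitions to the replay buffer
        tracer.add(s, a, r, done)
        while tracer:
            buffer.add(tracer.pop())

        # sample a minibatch and apply a Q-learning update once the buffer has warmed up
        if len(buffer) >= 50000:
            metrics = qlearning.update(buffer.sample(batch_size=32))

        # periodically sync the target network
        if env.T % 10000 == 0:
            q_targ.soft_update(q, tau=1.0)

        if done:
            break

        s = s_next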
Example #4
def func(S, A, is_training):
    # type-1 quantile q-function: (s,a) -> quantile values; the `encoder` and
    # `quantile_net` modules used below are defined earlier in the full example
    quantile_fractions = coax.utils.quantiles_uniform(  # random quantile fractions, IQN-style
        rng=hk.next_rng_key(),
        batch_size=jax.tree_leaves(S)[0].shape[0],
        num_quantiles=num_quantiles)
    X = jax.vmap(jnp.kron)(S, A)
    x = encoder(X)
    quantile_x = quantile_net(x, quantile_fractions=quantile_fractions)
    quantile_values = hk.Linear(1, w_init=jnp.zeros)(quantile_x)
    return {
        'values': quantile_values.squeeze(axis=-1),
        'quantile_fractions': quantile_fractions
    }


# quantile value function and its derived policy
q = coax.StochasticQ(func, env, num_bins=num_quantiles, value_range=None)
pi = coax.BoltzmannPolicy(q)

# target network
q_targ = q.copy()

# experience tracer
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=100000)

# updater
qlearning = coax.td_learning.QLearning(q, q_targ=q_targ, optimizer=adam(0.001))

# train
for ep in range(1000):
    s = env.reset()
    # pi.temperature = max(0.01, pi.temperature * 0.95)  # optional temperature schedule
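    # NOTE: the rest of this training loop is an illustrative sketch; the episode-step
    # loop, warm-up threshold and soft target updates below are assumptions
    for t in range(env.spec.max_episode_steps):
        a = pi(s)
        s_next, r, done, info = env.step(a)

        # trace rewards and add the traced transitions to the replay buffer
        tracer.add(s, a, r, done)
        while tracer:
            buffer.add(tracer.pop())

        # sample a minibatch, update the quantile q-function, softly sync the target
        if len(buffer) >= 1000:
            metrics = qlearning.update(buffer.sample(batch_size=32))
            q_targ.soft_update(q, tau=0.01)

        if done:
            break

        s = s_next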