Example #1
def _dropout_graph(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    # Draw two independent RNG keys from Haiku's per-call key sequence.
    node_key, edge_key = hk.next_rng_keys(2)
    # Apply dropout to the node features.
    nodes = hk.dropout(node_key, self._dropout_rate, graph.nodes)
    edges = graph.edges
    if not self._disable_edge_updates:
        # Also apply dropout to the edge features, unless edge updates are disabled.
        edges = hk.dropout(edge_key, self._dropout_rate, edges)
    return graph._replace(nodes=nodes, edges=edges)
Example #2
def q(S, A, is_training):
    # One RNG key per dropout layer, drawn from Haiku's per-call key sequence.
    rng1, rng2, rng3 = hk.next_rng_keys(3)
    # Scale the dropout rate by is_training so dropout is a no-op at evaluation time.
    rate = hparams.dropout_critic * is_training
    seq = hk.Sequential((
        hk.Linear(hparams.h1_critic),
        jax.nn.relu,
        partial(hk.dropout, rng1, rate),
        hk.Linear(hparams.h2_critic),
        jax.nn.relu,
        partial(hk.dropout, rng2, rate),
        hk.Linear(hparams.h3_critic),
        jax.nn.relu,
        partial(hk.dropout, rng3, rate),
        hk.Linear(1),
        jnp.ravel,
    ))
    flatten = hk.Flatten()
    # Concatenate the flattened state with the squashed action to form the critic input.
    X_sa = jnp.concatenate([flatten(S), jnp.tanh(flatten(A))], axis=1)
    return seq(X_sa)
Example #3
def pi(S, is_training):
    # One RNG key per dropout layer.
    rng1, rng2, rng3 = hk.next_rng_keys(3)
    shape = env.action_space.shape
    # Scale the dropout rate by is_training so dropout is a no-op at evaluation time.
    rate = hparams.dropout_actor * is_training
    seq = hk.Sequential((
        hk.Linear(hparams.h1_actor),
        jax.nn.relu,
        partial(hk.dropout, rng1, rate),
        hk.Linear(hparams.h2_actor),
        jax.nn.relu,
        partial(hk.dropout, rng2, rate),
        hk.Linear(hparams.h3_actor),
        jax.nn.relu,
        partial(hk.dropout, rng3, rate),
        hk.Linear(onp.prod(shape)),
        hk.Reshape(shape),
        # lambda x: low + (high - low) * jax.nn.sigmoid(x),  # disable: BoxActionsToReals
    ))
    return seq(S)  # batch of actions
Example #4
def func(S, is_training):
    # One RNG key per dropout layer.
    rng1, rng2, rng3 = hk.next_rng_keys(3)
    # Dropout is only active during training.
    rate = 0.25 if is_training else 0.
    # BatchNorm without learnable scale/offset, with a decay rate of 0.99.
    batch_norm = hk.BatchNorm(False, False, 0.99)
    seq = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8),
        jax.nn.relu,
        partial(hk.dropout, rng1, rate),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8),
        jax.nn.relu,
        partial(hk.dropout, rng2, rate),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8),
        jax.nn.relu,
        partial(hk.dropout, rng3, rate),
        partial(batch_norm, is_training=is_training),
        hk.Linear(1, w_init=jnp.zeros),
        jnp.ravel,
    ))
    return seq(S)
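
Note: all of the snippets above call hk.next_rng_keys, which only works while the function is being run inside a Haiku transform. Below is a minimal sketch, not taken from any of the sources above, of how such a function is typically wired up; the function name, layer sizes, and input shape are assumptions. Example #4 additionally uses hk.BatchNorm, which carries state, so it would need hk.transform_with_state instead of hk.transform.

import jax
import jax.numpy as jnp
import haiku as hk
from functools import partial

def f(x, is_training):
    # Two per-call RNG keys, one for each dropout layer.
    rng1, rng2 = hk.next_rng_keys(2)
    # Disable dropout at evaluation time.
    rate = 0.5 if is_training else 0.
    seq = hk.Sequential((
        hk.Linear(16),
        jax.nn.relu,
        partial(hk.dropout, rng1, rate),
        hk.Linear(16),
        jax.nn.relu,
        partial(hk.dropout, rng2, rate),
        hk.Linear(1),
        jnp.ravel,
    ))
    return seq(x)

f = hk.transform(f)                   # hk.next_rng_keys requires a transform context
rng = jax.random.PRNGKey(0)
x = jnp.ones((8, 4))                  # hypothetical batch of 8 inputs with 4 features
params = f.init(rng, x, is_training=True)
y = f.apply(params, rng, x, is_training=True)  # an rng must be passed so dropout keys can be drawn
print(y.shape)                        # (8,)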