def _dropout_graph(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Apply dropout to a graph's node features, and to its edge features
    unless edge updates are disabled.

    Two RNG keys are always drawn so the key stream is identical whether or
    not edge dropout is applied.
    """
    node_key, edge_key = hk.next_rng_keys(2)
    new_nodes = hk.dropout(node_key, self._dropout_rate, graph.nodes)
    if self._disable_edge_updates:
        new_edges = graph.edges
    else:
        new_edges = hk.dropout(edge_key, self._dropout_rate, graph.edges)
    return graph._replace(nodes=new_nodes, edges=new_edges)
def q(S, A, is_training):
    """Critic network: estimate the state-action value Q(s, a).

    The state and (tanh-squashed) action are flattened and concatenated
    before being fed through a 3-hidden-layer MLP with dropout.
    """
    key1, key2, key3 = hk.next_rng_keys(3)
    # Dropout is only active during training: the rate scales to 0 otherwise.
    drop_rate = hparams.dropout_critic * is_training
    layers = (
        hk.Linear(hparams.h1_critic), jax.nn.relu,
        partial(hk.dropout, key1, drop_rate),
        hk.Linear(hparams.h2_critic), jax.nn.relu,
        partial(hk.dropout, key2, drop_rate),
        hk.Linear(hparams.h3_critic), jax.nn.relu,
        partial(hk.dropout, key3, drop_rate),
        hk.Linear(1),
        jnp.ravel,
    )
    net = hk.Sequential(layers)
    flat = hk.Flatten()
    X_sa = jnp.concatenate([flat(S), jnp.tanh(flat(A))], axis=1)
    return net(X_sa)
def pi(S, is_training):
    """Actor network: map a batch of states to a batch of actions.

    A 3-hidden-layer MLP with dropout, whose output is reshaped to the
    environment's action shape.
    """
    key1, key2, key3 = hk.next_rng_keys(3)
    action_shape = env.action_space.shape
    # Dropout is only active during training: the rate scales to 0 otherwise.
    drop_rate = hparams.dropout_actor * is_training
    net = hk.Sequential((
        hk.Linear(hparams.h1_actor), jax.nn.relu,
        partial(hk.dropout, key1, drop_rate),
        hk.Linear(hparams.h2_actor), jax.nn.relu,
        partial(hk.dropout, key2, drop_rate),
        hk.Linear(hparams.h3_actor), jax.nn.relu,
        partial(hk.dropout, key3, drop_rate),
        hk.Linear(onp.prod(action_shape)),
        hk.Reshape(action_shape),
        # lambda x: low + (high - low) * jax.nn.sigmoid(x),  # disable: BoxActionsToReals
    ))
    return net(S)  # batch of actions
def func(S, is_training):
    """Forward pass of a small 3-hidden-layer MLP (8 units each) that maps a
    batch of inputs to a flat vector of scalar outputs.

    Each hidden layer is relu -> dropout(0.25, training only) -> batch norm.
    The final Linear is zero-initialized, so the initial output is all zeros.
    """
    rng1, rng2, rng3 = hk.next_rng_keys(3)
    rate = 0.25 if is_training else 0.
    # Bug fix: the original code created ONE hk.BatchNorm instance and applied
    # it at all three depths. In Haiku, reusing a module instance shares its
    # state, so the three layers shared a single set of running statistics
    # (moving mean/var), mixing activation distributions across layers.
    # Each normalization point now owns its own BatchNorm state.
    batch_norms = [hk.BatchNorm(False, False, 0.99) for _ in range(3)]
    seq = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8), jax.nn.relu,
        partial(hk.dropout, rng1, rate),
        partial(batch_norms[0], is_training=is_training),
        hk.Linear(8), jax.nn.relu,
        partial(hk.dropout, rng2, rate),
        partial(batch_norms[1], is_training=is_training),
        hk.Linear(8), jax.nn.relu,
        partial(hk.dropout, rng3, rate),
        partial(batch_norms[2], is_training=is_training),
        hk.Linear(1, w_init=jnp.zeros),
        jnp.ravel,
    ))
    return seq(S)