Example No. 1
 def _actor_fn(obs: types.NestedArray) -> types.NestedArray:
     network = hk.Sequential([
         networks_lib.LayerNormMLP(hidden_layer_sizes, activate_final=True),
         networks_lib.NearZeroInitializedLinear(num_dimensions),
         networks_lib.TanhToSpec(spec.actions),
     ])
     return network(obs)
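
The _actor_fn above is a Haiku module-building function; names such as hidden_layer_sizes, num_dimensions and spec are closed over from an enclosing network-factory function. A minimal sketch (not from the original source) of how such a function is typically turned into a pure init/apply pair with hk.transform, using a plain MLP as a stand-in for the Acme modules and an assumed observation shape:

 import jax
 import jax.numpy as jnp
 import haiku as hk

 def _actor_fn(obs):
     # Stand-in body: a plain MLP instead of the LayerNormMLP/TanhToSpec stack.
     return hk.nets.MLP([256, 256, 4])(obs)

 policy = hk.without_apply_rng(hk.transform(_actor_fn))

 dummy_obs = jnp.zeros((1, 10))                           # batched dummy observation
 params = policy.init(jax.random.PRNGKey(42), dummy_obs)  # build parameters
 actions = policy.apply(params, dummy_obs)                # shape (1, 4)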
Example No. 2
 def _critic_fn(obs: types.NestedArray,
                action: types.NestedArray) -> types.NestedArray:
     network1 = hk.Sequential([
         networks_lib.LayerNormMLP(list(hidden_layer_sizes) + [1]),
     ])
     input_ = jnp.concatenate([obs, action], axis=-1)
     value = network1(input_)
     return jnp.squeeze(value)  # Drop the trailing size-1 dimension from the final linear layer.
Example No. 3
 def _critic_fn(obs, action):
     network = hk.Sequential([
         utils.batch_concat,
         networks_lib.LayerNormMLP(
             layer_sizes=[*critic_layer_sizes, num_atoms]),
     ])
     value = network([obs, action])
     # critic_atoms (the fixed support of the value distribution) is defined
     # in the enclosing factory function, not inside this snippet.
     return value, critic_atoms
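
In the distributional-critic pattern above, critic_atoms is normally a fixed support of num_atoms values built once in the enclosing factory. A hedged sketch of how that support is commonly constructed (the vmin/vmax bounds here are assumptions, not values from the source):

 import jax.numpy as jnp

 num_atoms = 51
 vmin, vmax = -150.0, 150.0                            # assumed value-distribution bounds
 critic_atoms = jnp.linspace(vmin, vmax, num_atoms)    # fixed support; the network outputs logits over it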
Example No. 4
 def _actor_fn(obs):
     network = hk.Sequential([
         utils.batch_concat,
         networks_lib.LayerNormMLP(policy_layer_sizes, activate_final=True),
         networks_lib.NearZeroInitializedLinear(num_dimensions),
         networks_lib.TanhToSpec(action_spec),
     ])
     return network(obs)
Example No. 5
 def _actor_fn(obs, is_training=False, key=None):
     # is_training and key allow defining train/test-dependent modules
     # such as dropout.
     del is_training
     del key
     if discrete_actions:
         network = hk.nets.MLP([64, 64, final_layer_size])
     else:
         network = hk.Sequential([
             networks_lib.LayerNormMLP([64, 64], activate_final=True),
             networks_lib.NormalTanhDistribution(final_layer_size),
         ])
     return network(obs)
Example No. 6
 def _rnd_fn(obs, act):
     # RND does not use the action but other variants like RED do.
     del act
     network = networks_lib.LayerNormMLP(list(layer_sizes))
     return network(obs)
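
The _rnd_fn in Example No. 6 is the embedding network used for Random Network Distillation. A minimal sketch of the usual pattern (with a stand-in MLP body and assumed shapes, not the original source): the same function is transformed twice, once for a frozen target and once for a trained predictor, and the per-example prediction error acts as the intrinsic reward:

 import jax
 import jax.numpy as jnp
 import haiku as hk

 def _rnd_fn(obs, act):
     del act  # RND ignores the action; variants like RED use it.
     return hk.nets.MLP([64, 64, 32])(obs)  # stand-in for LayerNormMLP

 rnd = hk.without_apply_rng(hk.transform(_rnd_fn))

 dummy_obs = jnp.zeros((1, 8))
 dummy_act = jnp.zeros((1, 2))
 target_params = rnd.init(jax.random.PRNGKey(0), dummy_obs, dummy_act)     # kept frozen
 predictor_params = rnd.init(jax.random.PRNGKey(1), dummy_obs, dummy_act)  # trained

 def intrinsic_reward(predictor_params, target_params, obs, act):
     # Squared prediction error between predictor and frozen target,
     # averaged over the embedding dimension.
     target_out = rnd.apply(target_params, obs, act)
     predictor_out = rnd.apply(predictor_params, obs, act)
     return jnp.mean(jnp.square(predictor_out - target_out), axis=-1)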