Example 1
    def _actor_fn(obs):
        # `actor_hidden_layer_sizes`, `num_dimensions` and the initializers are
        # closed over from the enclosing network factory.
        network = hk.Sequential([
            hk.nets.MLP(
                list(actor_hidden_layer_sizes),
                w_init=w_init,
                b_init=b_init,
                activation=jax.nn.relu,
                activate_final=True),
            networks_lib.NormalTanhDistribution(
                num_dimensions,
                w_init=dist_w_init,
                b_init=dist_b_init,
                min_scale=1e-2,
            ),
        ])
        return network(obs)
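In each of these snippets `_actor_fn` closes over names defined by an enclosing network factory (`actor_hidden_layer_sizes`, `num_dimensions`, `w_init`, `b_init`, `dist_w_init`, `dist_b_init`). Below is a minimal sketch of that surrounding pattern; the factory name `make_actor_network`, its default values, and the shapes in the usage lines are illustrative assumptions, not taken from the example.

import haiku as hk
import jax
import jax.numpy as jnp
from acme.jax import networks as networks_lib


def make_actor_network(num_dimensions,
                       actor_hidden_layer_sizes=(256, 256),
                       w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
                       b_init=jnp.zeros,
                       dist_w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
                       dist_b_init=jnp.zeros):
    """Binds the hyperparameters that _actor_fn closes over, then transforms it."""

    def _actor_fn(obs):
        network = hk.Sequential([
            hk.nets.MLP(
                list(actor_hidden_layer_sizes),
                w_init=w_init,
                b_init=b_init,
                activation=jax.nn.relu,
                activate_final=True),
            networks_lib.NormalTanhDistribution(
                num_dimensions,
                w_init=dist_w_init,
                b_init=dist_b_init),
        ])
        return network(obs)

    # hk.transform turns the closure into pure init/apply functions;
    # without_apply_rng drops the rng argument from apply, since the forward
    # pass is deterministic (randomness only enters when sampling from the
    # returned distribution).
    return hk.without_apply_rng(hk.transform(_actor_fn))


# Usage (observation/action sizes are hypothetical):
actor = make_actor_network(num_dimensions=6)
dummy_obs = jnp.zeros((1, 17))
params = actor.init(jax.random.PRNGKey(0), dummy_obs)
dist = actor.apply(params, dummy_obs)               # tanh-squashed Normal
action = dist.sample(seed=jax.random.PRNGKey(1))    # each coordinate lies in (-1, 1)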
Example 2
 def _actor_fn(obs):
     network = hk.Sequential([
         hk.nets.MLP(list(hidden_layer_sizes),
                     w_init=hk.initializers.VarianceScaling(
                         1.0, 'fan_in', 'uniform'),
                     activation=jax.nn.relu,
                     activate_final=True),
         networks_lib.NormalTanhDistribution(num_dimensions),
     ])
     return network(obs)
Example 3
 def _policy_fn(obs: jnp.ndarray):
   # Returns a tanh-squashed Normal distribution over actions (the output of
   # NormalTanhDistribution), not a raw array.
   network = hk.Sequential([
       hk.nets.MLP(
           list(policy_layer_sizes),
           w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
           activation=activation,
           activate_final=True),
       networks_lib.NormalTanhDistribution(num_actions),
   ])
   return network(obs)
Example 4
 def _actor_fn(obs):
     # Hand-rolled equivalent of the hk.nets.MLP actor: a stack of Linear+ReLU
     # layers followed by a tanh-squashed Normal head.  The orthogonal
     # initializers below (kept for matching Ilya's codebase) are one option
     # for the closed-over w_init / dist_w_init:
     # relu_orthogonal = hk.initializers.Orthogonal(scale=2.0**0.5)
     # near_zero_orthogonal = hk.initializers.Orthogonal(scale=1e-2)
     x = obs
     for hid_dim in actor_hidden_layer_sizes:
         x = hk.Linear(hid_dim, w_init=w_init, b_init=b_init)(x)
         x = jax.nn.relu(x)
     dist = networks_lib.NormalTanhDistribution(num_dimensions,
                                                w_init=dist_w_init,
                                                b_init=dist_b_init)(x)
     return dist
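A hedged sketch of binding the closed-over initializer names to those orthogonal choices; the exact values are an assumption, not something the example fixes.

import haiku as hk
import jax.numpy as jnp

# Orthogonal weights with gain sqrt(2) suit the ReLU hidden layers; a
# near-zero orthogonal init keeps the policy head close to zero at the start.
relu_orthogonal = hk.initializers.Orthogonal(scale=2.0 ** 0.5)
near_zero_orthogonal = hk.initializers.Orthogonal(scale=1e-2)

w_init = relu_orthogonal             # hidden-layer weights
b_init = jnp.zeros                   # hidden-layer biases
dist_w_init = near_zero_orthogonal   # distribution-head weights
dist_b_init = jnp.zeros              # distribution-head biases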
Example 5
 def _actor_fn(obs, is_training=False, key=None):
     # is_training and key allow defining train/test-dependent modules,
     # such as dropout; both are unused here.
     del is_training
     del key
     if discrete_actions:
         network = hk.nets.MLP([64, 64, final_layer_size])
     else:
         network = hk.Sequential([
             networks_lib.LayerNormMLP([64, 64], activate_final=True),
             networks_lib.NormalTanhDistribution(final_layer_size),
         ])
     return network(obs)
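This variant switches between a plain logit MLP for discrete actions and the usual tanh-squashed Gaussian for continuous ones. A sketch of how `discrete_actions` and `final_layer_size` might be derived from an Acme environment spec; the helper name is hypothetical.

import numpy as np
from acme import specs


def action_head_size(environment_spec: specs.EnvironmentSpec):
    """Hypothetical helper returning (discrete_actions, final_layer_size)."""
    action_spec = environment_spec.actions
    if isinstance(action_spec, specs.DiscreteArray):
        # Discrete control: one logit per action.
        return True, action_spec.num_values
    # Continuous control: one distribution dimension per action coordinate.
    return False, int(np.prod(action_spec.shape))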
Example 6
    def _actor_fn(obs):
        # `actor_hidden_layer_sizes`, `num_dimensions` and the initializers are
        # closed over from the enclosing network factory.

        # PAPER VERSION
        network = hk.Sequential([
            hk.nets.MLP(
                list(actor_hidden_layer_sizes),
                w_init=w_init,
                b_init=b_init,
                activation=jax.nn.relu,
                activate_final=True),
            networks_lib.NormalTanhDistribution(
                num_dimensions,
                w_init=dist_w_init,
                b_init=dist_b_init,
                min_scale=1e-2,
            ),
            # Alternative heads (commented out):
            # networks_lib.MultivariateNormalDiagHead(
            #     num_dimensions, w_init=w_init, b_init=b_init),
            # networks_lib.GaussianMixture(
            #     num_dimensions, num_components=5, multivariate=True),
        ])
        return network(obs)
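The commented alternatives are other policy heads that slot into the same Sequential. As an illustration, swapping in the diagonal-Gaussian head (mirroring the keyword arguments shown in the comments) changes only the final module, and the resulting actions are no longer bounded to (-1, 1). A sketch of a drop-in replacement for the body of _actor_fn, under the same closed-over names:

        # Same MLP body, different head: an unsquashed diagonal Gaussian.
        network = hk.Sequential([
            hk.nets.MLP(
                list(actor_hidden_layer_sizes),
                w_init=w_init,
                b_init=b_init,
                activation=jax.nn.relu,
                activate_final=True),
            networks_lib.MultivariateNormalDiagHead(
                num_dimensions,
                w_init=w_init,
                b_init=b_init),
        ])
        return network(obs)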
Example 7
  def _actor_fn(obs):
    w_init = hk.initializers.VarianceScaling(1.0, "fan_avg", "truncated_normal")
    b_init = jnp.zeros
    dist_w_init = hk.initializers.VarianceScaling(1.0, "fan_avg", "truncated_normal")
    dist_b_init = jnp.zeros

    network = hk.Sequential([
        hk.nets.MLP(
            list(actor_hidden_layer_sizes),
            w_init=w_init,
            b_init=b_init,
            activation=jax.nn.relu,
            activate_final=True),
        networks_lib.NormalTanhDistribution(
            num_dimensions,
            w_init=dist_w_init,
            b_init=dist_b_init),
    ])
    return network(obs)
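Unlike Examples 2 and 3, which use fan-in uniform scaling (LeCun-style), this example fixes a fan-avg truncated-normal initializer (Glorot-style) for every layer, including the distribution head. The two conventions side by side, for reference:

lecun_uniform = hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform')            # Examples 2-3
glorot_style = hk.initializers.VarianceScaling(1.0, 'fan_avg', 'truncated_normal')   # Example 7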