Example #1
def func_boxspace(S, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    mu = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8),
        jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8),
        jnp.tanh,
        hk.Linear(onp.prod(boxspace.shape)),
        hk.Reshape(boxspace.shape),
    ))
    logvar = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8),
        jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8),
        jnp.tanh,
        hk.Linear(onp.prod(boxspace.shape)),
        hk.Reshape(boxspace.shape),
    ))
    return {'mu': mu(S), 'logvar': logvar(S)}
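These test helpers rely on module-level globals (boxspace, onp, partial) and on Haiku state for BatchNorm. Below is a minimal, self-contained sketch, with a dummy output shape standing in for boxspace.shape, of how such a forward pass is typically wrapped and called; it is an illustration of the pattern, not the original test harness.

import jax
import jax.numpy as jnp
import numpy as onp
import haiku as hk

out_shape = (3, 2)  # hypothetical stand-in for boxspace.shape

def forward(S, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    seq = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8), jax.nn.relu,
        lambda x: batch_norm(x, is_training=is_training),
        hk.Linear(onp.prod(out_shape)),
        hk.Reshape(out_shape),
    ))
    return seq(S)

# BatchNorm keeps moving averages as Haiku state, so transform_with_state is used;
# the rng passed to init/apply is what hk.next_rng_key() would draw from when
# dropout is present, as in the snippet above.
model = hk.transform_with_state(forward)
rng = jax.random.PRNGKey(0)
S = jnp.ones((4, 5))  # dummy batch of flattened observations
params, state = model.init(rng, S, is_training=True)
out, state = model.apply(params, state, rng, S, is_training=False)
print(out.shape)  # (4, 3, 2)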
Example #2
 def func(S, is_training):
     flatten = hk.Flatten()
     batch_norm_m = hk.BatchNorm(create_scale=True,
                                 create_offset=True,
                                 decay_rate=0.95)
     batch_norm_v = hk.BatchNorm(create_scale=True,
                                 create_offset=True,
                                 decay_rate=0.95)
     batch_norm_m = partial(batch_norm_m, is_training=is_training)
     batch_norm_v = partial(batch_norm_v, is_training=is_training)
     mu = hk.Sequential((
         hk.Linear(7),
         batch_norm_m,
         jnp.tanh,
         hk.Linear(3),
         jnp.tanh,
         hk.Linear(onp.prod(self.env_boxspace.action_space.shape)),
         hk.Reshape(self.env_boxspace.action_space.shape),
     ))
     logvar = hk.Sequential((
         hk.Linear(7),
         batch_norm_v,
         jnp.tanh,
         hk.Linear(3),
         jnp.tanh,
         hk.Linear(onp.prod(self.env_boxspace.action_space.shape)),
         hk.Reshape(self.env_boxspace.action_space.shape),
     ))
     return {'mu': mu(flatten(S)), 'logvar': logvar(flatten(S))}
Example #3
 def func(S, is_training):
     env = self.env_discrete
     output_shape = (env.action_space.n, *env.observation_space.shape)
     flatten = hk.Flatten()
     batch_norm_m = hk.BatchNorm(create_scale=True,
                                 create_offset=True,
                                 decay_rate=0.95)
     batch_norm_v = hk.BatchNorm(create_scale=True,
                                 create_offset=True,
                                 decay_rate=0.95)
     batch_norm_m = partial(batch_norm_m, is_training=is_training)
     batch_norm_v = partial(batch_norm_v, is_training=is_training)
     mu = hk.Sequential((
         hk.Linear(7),
         batch_norm_m,
         jnp.tanh,
         hk.Linear(3),
         jnp.tanh,
         hk.Linear(onp.prod(output_shape)),
         hk.Reshape(output_shape),
     ))
     logvar = hk.Sequential((
         hk.Linear(7),
         batch_norm_v,
         jnp.tanh,
         hk.Linear(3),
         jnp.tanh,
         hk.Linear(onp.prod(output_shape)),
         hk.Reshape(output_shape),
     ))
     X = flatten(S)
     return {'mu': mu(X), 'logvar': logvar(X)}
Example #4
    def __init__(self, C, position_enc_fn, name=None):
        super().__init__(name=name)
        he_init = hk.initializers.VarianceScaling(scale=2.0)

        channels = C['encoder_cnn_channels']
        kernels  = C['encoder_cnn_kernels']
        strides  = C['encoder_cnn_strides']

        hidden_size = channels[-1]
        self.cnn_layers = hk.Sequential([
            hk.Conv2D(channels[0], kernels[0], stride=strides[0], padding='SAME', w_init=he_init, with_bias=True), jax.nn.relu,
            hk.Conv2D(channels[1], kernels[1], stride=strides[1], padding='SAME', w_init=he_init, with_bias=True), jax.nn.relu,
            hk.Conv2D(channels[2], kernels[2], stride=strides[2], padding='SAME', w_init=he_init, with_bias=True), jax.nn.relu,
            hk.Conv2D(hidden_size, kernels[3], stride=strides[3], padding='SAME', w_init=he_init, with_bias=True), jax.nn.relu,
        ])

        self.pos_embed = SoftPositionEmbed(hidden_size, C['hidden_res'], position_enc_fn)

        self.linears = hk.Sequential([ # i.e. 1x1 convolution (shared 32 neurons across all locations)
            hk.Reshape((-1, hidden_size)), # Flatten spatial dim (works with batch)
            hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
            hk.Linear(32, w_init=he_init), jax.nn.relu,
            hk.Linear(32, w_init=he_init),
        ])
Example #5
 def q_net(obs):
     layers_ = tuple(layers) + (onp.prod(output_shape), )
     if use_noisy_network:
         network = NoisyMLP(layers_, factorized_noise=use_factorized_noise)
     else:
         network = hk.nets.MLP(layers_)
     return hk.Reshape(output_shape=output_shape)(network(obs))
Example #6
def func_discrete_type2(S, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    seq = hk.Sequential(
        (hk.Flatten(), hk.Linear(8), jax.nn.relu,
         partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
         partial(batch_norm, is_training=is_training), hk.Linear(8), jnp.tanh,
         hk.Linear(discrete.n * discrete.n),
         hk.Reshape((discrete.n, discrete.n)), jax.nn.softmax))
    return seq(S)
Example #7
 def func(S, is_training):
     flatten = hk.Flatten()
     batch_norm = hk.BatchNorm(create_scale=True,
                               create_offset=True,
                               decay_rate=0.95)
     batch_norm = partial(batch_norm, is_training=is_training)
     seq = hk.Sequential(
         (hk.Linear(7), batch_norm, jnp.tanh, hk.Linear(3), jnp.tanh,
          hk.Linear(self.env_discrete.action_space.n * 51),
          hk.Reshape((self.env_discrete.action_space.n, 51))))
     return seq(flatten(S))
Example #8
File: sac.py  Project: coax-dev/coax
def func_pi(S, is_training):
    seq = hk.Sequential((
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(prod(env.action_space.shape) * 2, w_init=jnp.zeros),
        hk.Reshape((*env.action_space.shape, 2)),
    ))
    x = seq(S)
    mu, logvar = x[..., 0], x[..., 1]
    return {'mu': mu, 'logvar': logvar}
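In the coax SAC example, a forward pass like this is wrapped into a policy object. The rough usage sketch below assumes func_pi and its imports from the snippet above, plus a continuous-control Gym environment; the exact env id and Gym API vary by version.

import gym
import coax

env = gym.make('Pendulum-v1')  # hypothetical env; func_pi reads env.action_space

# func_pi as defined above
pi = coax.Policy(func_pi, env)              # wraps the Haiku forward pass
a = pi(env.observation_space.sample())      # sample an action from the Gaussian policy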
Example #9
def make_conditioner(event_shape: Sequence[int], hidden_sizes: Sequence[int],
                     num_bijector_params: int) -> hk.Sequential:
    """Creates an MLP conditioner for each layer of the flow."""
    return hk.Sequential([
        hk.Flatten(),
        hk.nets.MLP(hidden_sizes, activate_final=True),
        # We initialize this linear layer to zero so that the flow is initialized
        # to the identity function.
        hk.Linear(np.prod(event_shape) * num_bijector_params,
                  w_init=jnp.zeros,
                  b_init=jnp.zeros),
        hk.Reshape(tuple(event_shape) + (num_bijector_params, )),
    ])
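A quick shape check with hypothetical sizes: the final hk.Reshape turns the flat linear output into one vector of bijector parameters per event dimension. The sketch assumes make_conditioner and its imports from the snippet above.

import jax
import jax.numpy as jnp
import haiku as hk

def forward(x):
    conditioner = make_conditioner(event_shape=(4,), hidden_sizes=[16, 16],
                                   num_bijector_params=3)
    return conditioner(x)

model = hk.transform(forward)
x = jnp.ones((2, 4))  # batch of 2 events with 4 dimensions each
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, None, x)
print(y.shape)  # (2, 4, 3): num_bijector_params per event dimension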
Example #10
def func_pi(S, is_training):
    shared = hk.Sequential((
        hk.Linear(8),
        jax.nn.relu,
        hk.Linear(8),
        jax.nn.relu,
    ))
    mu = hk.Sequential((
        shared,
        hk.Linear(8),
        jax.nn.relu,
        hk.Linear(prod(env.action_space.shape), w_init=jnp.zeros),
        hk.Reshape(env.action_space.shape),
    ))
    logvar = hk.Sequential((
        shared,
        hk.Linear(8),
        jax.nn.relu,
        hk.Linear(prod(env.action_space.shape), w_init=jnp.zeros),
        hk.Reshape(env.action_space.shape),
    ))
    return {'mu': mu(S), 'logvar': logvar(S)}
Example #11
def func_boxspace_type1(S, A, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    seq = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8), jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8), jnp.tanh,
        hk.Linear(onp.prod(boxspace.shape)),
        hk.Reshape(boxspace.shape),
    ))
    X = jax.vmap(jnp.kron)(S, A)
    return seq(X)
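The jax.vmap(jnp.kron) line merges each state-action pair into a single feature vector of pairwise products before it enters the network. A small shape illustration with made-up sizes:

import jax
import jax.numpy as jnp

S = jnp.ones((4, 3))  # hypothetical batch of 4 flattened states
A = jnp.ones((4, 2))  # matching batch of 4 action vectors
X = jax.vmap(jnp.kron)(S, A)
print(X.shape)  # (4, 6): one flattened outer product per sample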
Example #12
def func_type2(S, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    logits = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8),
        jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8),
        jnp.tanh,
        hk.Linear(discrete.n * num_bins),
        hk.Reshape((discrete.n, num_bins)),
    ))
    return {'logits': logits(S)}
Example #13
def func_pi(S, is_training):
    seq = hk.Sequential((
        hk.Linear(8),
        jax.nn.relu,
        hk.Linear(8),
        jax.nn.relu,
        hk.Linear(8),
        jax.nn.relu,
        hk.Linear(prod(env.action_space.shape), w_init=jnp.zeros),
        hk.Reshape(env.action_space.shape),
    ))
    mu = seq(S)
    return {
        'mu': mu,
        'logvar': jnp.full_like(mu, -10)
    }  # (almost) deterministic
Example #14
def func_quantile_type2(S, is_training):
    """ type-1 q-function: (s,a) -> q(s,a) """
    encoder = hk.Sequential((hk.Flatten(), hk.Linear(8), jax.nn.relu))
    quantile_fractions = quantiles_uniform(
        rng=hk.next_rng_key(),
        batch_size=jax.tree_leaves(S)[0].shape[0],
        num_quantiles=num_bins)
    x = encoder(S)
    quantile_x = quantile_net(x, quantile_fractions=quantile_fractions)
    quantile_values = hk.Sequential((
        hk.Linear(discrete.n),
        hk.Reshape((discrete.n, num_bins)),
    ))(quantile_x)
    return {
        'values': quantile_values,
        'quantile_fractions': quantile_fractions[:, None, :].tile([1, discrete.n, 1]),
    }
Example #15
 def __init__(
     self,
     latent_size: int,
     hidden_size: int,
     output_shape: Sequence[int] = MNIST_IMAGE_SHAPE,
 ):
     super().__init__(name="model")
     self._latent_size = latent_size
     self._hidden_size = hidden_size
     self._output_shape = output_shape
     self.generative_network = hk.Sequential([
         hk.Linear(self._hidden_size),
         jax.nn.relu,
         hk.Linear(self._hidden_size),
         jax.nn.relu,
         hk.Linear(np.prod(self._output_shape)),
         hk.Reshape(self._output_shape, preserve_dims=2),
     ])
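Here preserve_dims=2 tells hk.Reshape to leave the two leading axes (e.g. latent samples and batch) untouched and only fold the trailing feature axis into the image shape. A small sketch, assuming the MNIST shape (28, 28, 1) used by the surrounding example:

import jax
import jax.numpy as jnp
import haiku as hk

def decode(z):
    return hk.Reshape((28, 28, 1), preserve_dims=2)(z)

model = hk.transform(decode)
z = jnp.ones((10, 32, 784))  # [num_samples, batch, features]
params = model.init(jax.random.PRNGKey(0), z)  # Reshape has no parameters
print(model.apply(params, None, z).shape)  # (10, 32, 28, 28, 1)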
Example #16
def pi(S, is_training):
    rng1, rng2, rng3 = hk.next_rng_keys(3)
    shape = env.action_space.shape
    rate = hparams.dropout_actor * is_training
    seq = hk.Sequential((
        hk.Linear(hparams.h1_actor),
        jax.nn.relu,
        partial(hk.dropout, rng1, rate),
        hk.Linear(hparams.h2_actor),
        jax.nn.relu,
        partial(hk.dropout, rng2, rate),
        hk.Linear(hparams.h3_actor),
        jax.nn.relu,
        partial(hk.dropout, rng3, rate),
        hk.Linear(onp.prod(shape)),
        hk.Reshape(shape),
        # lambda x: low + (high - low) * jax.nn.sigmoid(x),  # disable: BoxActionsToReals
    ))
    return seq(S)  # batch of actions
Example #17
 def __call__(self, inputs: jnp.ndarray) -> tfd.Distribution:
     logits = self._linear(inputs)
     if not isinstance(self._logit_shape, int):
         logits = hk.Reshape(self._logit_shape)(logits)
     return tfd.Categorical(logits=logits, dtype=self._dtype)
Example #18
 def initial_state(batch_size: Optional[int] = None):
   network = hk.DeepRNN([hk.Reshape([-1], preserve_dims=1),
                         hk.LSTM(output_size)])
   return network.initial_state(batch_size)
Example #19
 def network(inputs: jnp.ndarray, state: hk.LSTMState):
   return hk.DeepRNN([hk.Reshape([-1], preserve_dims=1),
                      hk.LSTM(output_size)])(inputs, state)
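The two snippets above construct the same kind of hk.DeepRNN core, once for its initial state and once for a single step. Below is a sketch of unrolling such a core over a whole sequence with hk.dynamic_unroll, using made-up sizes; it illustrates the shape behaviour rather than the original training setup.

import jax
import jax.numpy as jnp
import haiku as hk

output_size = 16  # hypothetical LSTM size

def unroll(seq):
    core = hk.DeepRNN([hk.Reshape([-1], preserve_dims=1),
                       hk.LSTM(output_size)])
    state = core.initial_state(seq.shape[1])  # batch size
    outs, _ = hk.dynamic_unroll(core, seq, state)
    return outs

model = hk.transform(unroll)
seq = jnp.ones((5, 2, 3, 4))  # [time, batch, ...]; each step is flattened to (2, 12)
params = model.init(jax.random.PRNGKey(0), seq)
print(model.apply(params, None, seq).shape)  # (5, 2, 16)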
Example #20
 def q_net(obs):
     layers_ = tuple(layers) + (onp.prod(output_shape), )
     network = NoisyMLP(layers_)
     return hk.Reshape(output_shape=output_shape)(network(obs))