Example #1
  def __init__(self):
    super().__init__()
    # Definition of the modules.
    self.reshape_mod = hk.Flatten()

    self.lin_block = hk.Sequential([
        hk.Linear(20), jax.nn.relu,
    ])

    self.final_linear = hk.Linear(10)
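Only the constructor is shown in Example #1. A minimal sketch of how the rest of such a module and its use could look; the class name, the __call__ body, and the hk.transform wrapping below are assumptions, not part of the original snippet:

import haiku as hk
import jax
import jax.numpy as jnp

class MyModel(hk.Module):  # hypothetical name
  def __init__(self):
    super().__init__()
    # Definition of the modules (as in Example #1).
    self.reshape_mod = hk.Flatten()
    self.lin_block = hk.Sequential([hk.Linear(20), jax.nn.relu])
    self.final_linear = hk.Linear(10)

  def __call__(self, x):
    x = self.reshape_mod(x)      # flatten everything except the batch dimension
    x = self.lin_block(x)        # hidden layer + ReLU
    return self.final_linear(x)  # 10-way logits

def forward(x):
  return MyModel()(x)

model = hk.transform(forward)
x = jnp.ones([8, 28, 28, 1])                   # dummy image batch
params = model.init(jax.random.PRNGKey(0), x)
logits = model.apply(params, None, x)          # no rng needed at apply time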
Example #2
 def func(S, is_training):
     flatten = hk.Flatten()
     batch_norm = hk.BatchNorm(create_scale=True,
                               create_offset=True,
                               decay_rate=0.95)
     batch_norm = partial(batch_norm, is_training=is_training)
     seq = hk.Sequential(
         (hk.Linear(7), batch_norm, jnp.tanh, hk.Linear(3), jnp.tanh,
          hk.Linear(1), jnp.ravel))
     return seq(flatten(S))
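Because hk.BatchNorm keeps moving statistics, a function like the one above must be wrapped with hk.transform_with_state rather than plain hk.transform. A minimal sketch of that wrapping, with a made-up input shape:

import haiku as hk
import jax
import jax.numpy as jnp
from functools import partial  # used by func above

model = hk.transform_with_state(func)
S = jnp.ones([4, 6])                           # dummy batch of observations
params, state = model.init(jax.random.PRNGKey(0), S, is_training=True)
out, state = model.apply(params, state, jax.random.PRNGKey(1), S, is_training=True)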
Example #3
def net_fn(x):
    """Haiku module for our network."""
    layers = []
    for layer_size in FLAGS.hidden_layer_sizes:
        layers.append(hk.Linear(int(layer_size)))
        layers.append(jax.nn.relu)
    layers.append(hk.Linear(NUM_ACTIONS))
    layers.append(jax.nn.log_softmax)
    net = hk.Sequential(layers)
    return net(x)
Example #4
def mlp(features, phi, x):
    for feat in features:
        d = hk.Linear(feat,
                      with_bias=False,
                      w_init=hk.initializers.RandomNormal())
        x = phi(d(x) / x.shape[-1]**0.5)

    d = hk.Linear(1, with_bias=False, w_init=hk.initializers.RandomNormal())
    x = d(x) / x.shape[-1]
    return x[..., 0]
Example #5
 def forward_pass(batch):
   network = hk.Sequential([
       hk.Flatten(),
       hk.Linear(hidden_units),
       jax.nn.relu,
       hk.Linear(hidden_units),
       jax.nn.relu,
       hk.Linear(num_classes),
   ])
   return network(batch['x'])
Example #6
    def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        x = hk.Flatten()(x)
        x = hk.Linear(self._hidden_size)(x)
        x = jax.nn.relu(x)

        mean = hk.Linear(self._latent_size)(x)
        log_stddev = hk.Linear(self._latent_size)(x)
        stddev = jnp.exp(log_stddev)

        return mean, stddev
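The encoder above returns the mean and standard deviation of a diagonal Gaussian, from which a latent sample is typically drawn with the reparameterization trick. A short sketch of that step (not part of the original module):

mean, stddev = encoder(x)                                # the __call__ shown above
eps = jax.random.normal(hk.next_rng_key(), mean.shape)   # needs an apply-time rng
z = mean + stddev * eps                                  # differentiable sample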
Example #7
def _single_trunk_model(x):
    # input (64, 64, 13)
    num_classes = 10
    return hk.Sequential(
        [conv(32), gelu,                      # (32, 32, 32)
         conv(64), gelu,                      # (16, 16, 64)
         conv(128), gelu,                     # (8, 8, 128)
         global_spatial_mean_pooling,         # (128)
         hk.Linear(32), gelu,                 # (32)
         hk.Linear(num_classes)])(x)          # (10)
Example #8
 def __init__(self,
              latent_dim,
              layers,
              units):
     super(GenDynamics, self).__init__()
     self.latent_dim = latent_dim
     self.model = hk.Sequential([unit for _ in range(layers + 1) for unit in
                                 [jnp.tanh, hk.Linear(units)]] +
                                [jnp.tanh, hk.Linear(latent_dim)]
                                )
Example #9
def net_fn(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Graph net function."""
    # Add a global parameter for graph classification.
    graph = graph._replace(globals=jnp.zeros([graph.n_node.shape[0], 1]))
    embedder = jraph.GraphMapFeatures(hk.Linear(128), hk.Linear(128),
                                      hk.Linear(128))
    net = jraph.GraphNetwork(update_node_fn=node_update_fn,
                             update_edge_fn=edge_update_fn,
                             update_global_fn=update_global_fn)
    return net(embedder(graph))
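The node, edge, and global update functions referenced above are not defined in the snippet; in jraph-based code they are typically small Haiku MLPs over the concatenated incoming features. A hedged sketch of what they might look like (the layer sizes are guesses):

@jraph.concatenated_args
def edge_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
    # MLP over concatenated edge, sender, receiver and global features.
    return hk.nets.MLP([128, 128])(feats)

@jraph.concatenated_args
def node_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
    # MLP over concatenated node, aggregated edge and global features.
    return hk.nets.MLP([128, 128])(feats)

@jraph.concatenated_args
def update_global_fn(feats: jnp.ndarray) -> jnp.ndarray:
    # MLP over aggregated node, aggregated edge and global features.
    return hk.nets.MLP([128, 128])(feats)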
Example #10
def net_fn(batch: Batch) -> jnp.ndarray:
  """Standard LeNet-300-100 MLP network."""
  x = batch["image"].astype(jnp.float32) / 255.
  mlp = hk.Sequential([
      hk.Flatten(),
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(10),
  ])
  return mlp(x)
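A sketch of how a net_fn like the one above is commonly wrapped and turned into a loss; the batch contents and the cross-entropy formulation here are assumptions:

import haiku as hk
import jax
import jax.numpy as jnp

net = hk.without_apply_rng(hk.transform(net_fn))
dummy = {"image": jnp.zeros([1, 28, 28, 1]), "label": jnp.zeros([1], dtype=jnp.int32)}
params = net.init(jax.random.PRNGKey(42), dummy)

def loss_fn(params, batch):
  logits = net.apply(params, batch)
  labels = jax.nn.one_hot(batch["label"], 10)
  # softmax cross-entropy, averaged over the batch
  return -jnp.mean(jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))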
Example #11
 def __call__(self, x):
     x_input = x
     for i, unit in enumerate(self.hidden_units):
         x = hk.Linear(unit, **self.hidden_kwargs)(x)
         x = self.hidden_activation(x)
         if self.d2rl and i + 1 != len(self.hidden_units):
             x = jnp.concatenate([x, x_input], axis=1)
     x = hk.Linear(self.output_dim, **self.output_kwargs)(x)
     if self.output_activation is not None:
         x = self.output_activation(x)
     return x
Example #12
 def __init__(self, is_training=True):
     super().__init__()
     self.is_training = is_training
     self.encoder = TokenEncoder(FLAGS.vocab_size, FLAGS.duration_lstm_dim,
                                 FLAGS.duration_embed_dropout_rate,
                                 is_training)
     self.projection = hk.Sequential([
         hk.Linear(FLAGS.duration_lstm_dim),
         jax.nn.gelu,
         hk.Linear(1),
     ])
Example #13
  def __init__(self, num_dimensions: int, min_scale: float = 1e-6):
    """Initialization.

    Args:
      num_dimensions: Number of dimensions of MVN distribution.
      min_scale: Minimum standard deviation.
    """
    super().__init__(name='MultivariateNormalDiagHead')
    self._min_scale = min_scale
    self._loc_layer = hk.Linear(num_dimensions)
    self._scale_layer = hk.Linear(num_dimensions)
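The constructor above only creates the two projection layers; a plausible __call__ for such a head maps features to a location and a strictly positive scale, using min_scale as a floor. This is a sketch, not necessarily the project's exact implementation:

  def __call__(self, inputs: jnp.ndarray):
    loc = self._loc_layer(inputs)
    # softplus keeps the scale positive; min_scale avoids a degenerate distribution
    scale = jax.nn.softplus(self._scale_layer(inputs)) + self._min_scale
    return loc, scale  # parameters of a diagonal multivariate normal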
Example #14
def two_layers_net(width: int = 30, output_dim: int = 1) -> hk.Module:
    '''
    A basic two layer network with ReLU activations
    '''
    network = hk.Sequential([
        hk.Linear(width), jax.nn.relu,
        hk.Linear(width), jax.nn.relu,
        hk.Linear(output_dim)
    ])

    return network
Example #15
def net_fn(x):
    k = 1024
    mlp = hk.Sequential([
        hk.Linear(k), jax.nn.swish,
        hk.Linear(k), jax.nn.swish,
        hk.Linear(k), jax.nn.swish,
        hk.Linear(k), jax.nn.swish,
        hk.Linear(16)
    ])
    E = jnp.eye(4) + .05 * mlp(x).reshape(4, 4)
    return Eto_interpedg(E, x)
Example #16
 def func(S, is_training):
     flatten = hk.Flatten()
     batch_norm = hk.BatchNorm(create_scale=True,
                               create_offset=True,
                               decay_rate=0.95)
     batch_norm = partial(batch_norm, is_training=is_training)
     seq = hk.Sequential(
         (hk.Linear(7), batch_norm, jnp.tanh, hk.Linear(3), jnp.tanh,
          hk.Linear(self.env_discrete.action_space.n * 51),
          hk.Reshape((self.env_discrete.action_space.n, 51))))
     return seq(flatten(S))
Example #17
def model(x, dense_kernel_size=64, max_conv_size=256, num_classes=10):
    layers = []
    for c in [32, 64, 128, 256]:
        layers.append(hk.Conv2D(output_channels=min(c, max_conv_size),
                                kernel_shape=3, stride=2))
        layers.append(gelu)
    layers += [global_spatial_mean_pooling,
               hk.Linear(dense_kernel_size),
               gelu,
               hk.Linear(num_classes)]
    return hk.Sequential(layers)(x)
Example #18
 def _first_derivative(self, x: Array, t: float) -> Array:
     intermediate_dims = 3 * self.output_size * x.shape[-1]
     mlp = hk.Sequential([
         hk.Flatten(),
         hk.Linear(intermediate_dims), jax.nn.sigmoid,
         hk.Linear(intermediate_dims), jax.nn.sigmoid,
         hk.Linear(self.output_size * x.shape[-1],
                   w_init=hk.initializers.Constant(0.),
                   b_init=hk.initializers.Constant(0.))
     ])
     return mlp(jnp.append(x, t)).reshape(self.output_size, x.shape[-1])
Example #19
    def __call__(
        self,
        X: jnp.ndarray,
    ) -> jnp.ndarray:
        layers = []
        for d_o in self.hidden_dims:
            layers.append(hk.Linear(d_o))
            layers.append(jnp.tanh)
        layers.append(hk.Linear(self.output_dim))

        return hk.Sequential(layers)(X)
Example #20
File: sac.py Project: coax-dev/coax
def func_pi(S, is_training):
    seq = hk.Sequential((
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(prod(env.action_space.shape) * 2, w_init=jnp.zeros),
        hk.Reshape((*env.action_space.shape, 2)),
    ))
    x = seq(S)
    mu, logvar = x[..., 0], x[..., 1]
    return {'mu': mu, 'logvar': logvar}
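The policy head above returns per-dimension mu and logvar; inside a Haiku-transformed function an action can then be drawn with the reparameterization trick. A minimal sketch (the sampling code is an assumption, not part of the coax example):

dist_params = func_pi(S, is_training=False)
mu, logvar = dist_params['mu'], dist_params['logvar']
eps = jax.random.normal(hk.next_rng_key(), mu.shape)
action = mu + jnp.exp(0.5 * logvar) * eps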
Example #21
 def __call__(self, x: Array, t: float) -> Array:
     intermediate_dims = 3 * self.output_size
     mlp = hk.Sequential([
         hk.Flatten(),
         hk.Linear(intermediate_dims), jax.nn.sigmoid,
         hk.Linear(intermediate_dims), jax.nn.sigmoid,
         hk.Linear(self.output_size,
                   w_init=hk.initializers.Constant(0.),
                   b_init=hk.initializers.Constant(0.))
     ])
     return mlp(jnp.append(x, t))
Example #22
 def __call__(self, x: dm_env.TimeStep, state):
     torso_net = hk.Sequential([
         hk.Flatten(),
         hk.Linear(128), jax.nn.relu,
         hk.Linear(64), jax.nn.relu
     ])
     torso_output = torso_net(x.observation)
     policy_logits = hk.Linear(self._num_actions)(torso_output)
     value = hk.Linear(1)(torso_output)
     value = jnp.squeeze(value, axis=-1)
     return NetOutput(policy_logits=policy_logits, value=value), state
Example #23
def func(S, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    logits = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8), jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8), jnp.tanh,
        hk.Linear(num_bins),
    ))
    return {'logits': logits(S)}
Example #24
 def _critic_fn(obs):
   preds = []
   for _ in range(num_critics):
     layers = [
         hk.Linear(256),
         jax.nn.relu,
         hk.Linear(256),
         jax.nn.relu,
         hk.Linear(num_dimensions),
     ]
     preds.append(hk.Sequential(layers)(obs))
   return jnp.stack(preds, axis=-1)
Example #25
def func_type2(S, is_training):
    seq = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8),
        jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(hk.BatchNorm(False, False, 0.99), is_training=is_training),
        hk.Linear(8),
        jax.nn.relu,
        hk.Linear(env_discrete.action_space.n),
    ))
    return seq(S)
Example #26
def func_boxspace_type2(S, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    seq = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8), jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8), jnp.tanh,
        hk.Linear(onp.prod(boxspace.shape) * discrete.n),
        hk.Reshape((discrete.n, *boxspace.shape)),
    ))
    return seq(S)
Example #27
 def __init__(self, latent_size: int, hidden_size: int):
     super().__init__(name="variational")
     self.encoder = hk.Sequential([
         hk.Flatten(),
         hk.Linear(hidden_size),
         jax.nn.relu,
         hk.Linear(hidden_size),
         jax.nn.relu,
         hk.Linear(latent_size * 3, w_init=jnp.zeros, b_init=jnp.zeros),
     ])
     self.first_block = InverseAutoregressiveFlow(latent_size, hidden_size)
     self.second_block = InverseAutoregressiveFlow(latent_size, hidden_size)
Example #28
 def net_fn(batch: Batch) -> jnp.ndarray:
     """Standard MLP network."""
     x = batch["image"].astype(jnp.float32) / 255.0
     mlp = hk.Sequential([
         hk.Flatten(),
         hk.Linear(n_units_l1),
         jax.nn.relu,
         hk.Linear(n_units_l2),
         jax.nn.relu,
         hk.Linear(10),
     ])
     return mlp(x)
Example #29
 def __init__(self, latent_size: int, hidden_size: int):
     super().__init__(name="variational")
     self._latent_size = latent_size
     self._hidden_size = hidden_size
     self.inference_network = hk.Sequential([
         hk.Flatten(),
         hk.Linear(self._hidden_size),
         jax.nn.relu,
         hk.Linear(self._hidden_size),
         jax.nn.relu,
         hk.Linear(self._latent_size * 2),
     ])
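The inference network above ends in a layer of size latent_size * 2, which is typically split into the mean and log standard deviation of the approximate posterior. A sketch of that split (an assumption about how the output is consumed downstream):

out = self.inference_network(x)                 # shape [..., latent_size * 2]
mean, log_stddev = jnp.split(out, 2, axis=-1)
stddev = jnp.exp(log_stddev)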
Example #30
 def __call__(self, x: Array, t: float) -> Array:
     intermediate_dims = 3 * self.output_size**2
     mlp = hk.Sequential([
         hk.Flatten(),
         hk.Linear(intermediate_dims), jax.nn.sigmoid,
         hk.Linear(intermediate_dims), jax.nn.sigmoid,
         hk.Linear(self.output_size**2,
                   w_init=hk.initializers.Constant(1.),
                   b_init=hk.initializers.Constant(1.))
     ])
     output = jnp.reshape(mlp(x), (self.output_size, self.output_size))
     return aux_math.matrix_diag_transform(output, jax.nn.softplus)