def net_fn(x):
  """Haiku module for our network."""
  net = hk.Sequential([
      hk.Linear(1024),
      jax.nn.relu,
      hk.Linear(1024),
      jax.nn.relu,
      hk.Linear(1024),
      jax.nn.relu,
      hk.Linear(1024),
      jax.nn.relu,
      hk.Linear(NUM_ACTIONS),
      jax.nn.log_softmax,
  ])
  return net(x)
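The function above is only the network definition; a minimal usage sketch (not part of the original example, assuming NUM_ACTIONS and a dummy observation batch) wraps it with hk.transform to obtain an init/apply pair:

import jax
import jax.numpy as jnp
import haiku as hk

net = hk.without_apply_rng(hk.transform(net_fn))
dummy_x = jnp.zeros((1, 4))                         # hypothetical observation batch
params = net.init(jax.random.PRNGKey(42), dummy_x)  # create the parameters
log_probs = net.apply(params, dummy_x)              # shape: (1, NUM_ACTIONS)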
Example 2
  def __init__(self):
    super().__init__()
    bn_config = {'create_scale': True,
                 'create_offset': True,
                 'decay_rate': 0.999}

    # Definition of the modules.
    self.conv_block = hk.Sequential([
        hk.Conv2D(1, (3, 3), stride=3, rate=1), jax.nn.relu,
        hk.Conv2D(1, (3, 3), stride=3, rate=1), jax.nn.relu,
    ])

    self.conv_res_block = hk.Sequential([
        hk.Conv2D(1, (1, 1), stride=1, rate=1), jax.nn.relu,
        hk.Conv2D(1, (1, 1), stride=1, rate=1), jax.nn.relu,
    ])

    self.reshape_mod = hk.Flatten()

    self.lin_res_block = [
        (hk.Linear(16), hk.BatchNorm(name='lin_batchnorm_0', **bn_config))
    ]

    self.final_linear = hk.Linear(10)
Example 3
def func(S, is_training):
    """ type-2 q-function: s -> q(s,.) """
    seq = hk.Sequential((
        coax.utils.diff_transform,
        hk.Conv2D(16, kernel_shape=8, stride=4),
        jax.nn.relu,
        hk.Conv2D(32, kernel_shape=4, stride=2),
        jax.nn.relu,
        hk.Flatten(),
        hk.Linear(256),
        jax.nn.relu,
        hk.Linear(env.action_space.n, w_init=jnp.zeros),
    ))
    X = jnp.stack(S, axis=-1) / 255.  # stack frames
    return seq(X)
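In coax, a function with this (S, is_training) signature is typically wrapped into a value object; a hedged usage sketch (assuming the surrounding example defines env and imports coax) could look like:

q = coax.Q(func, env)                    # type-2 q-function object
pi = coax.EpsilonGreedy(q, epsilon=0.1)  # epsilon-greedy policy derived from q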
Example 4
def func_type1(S, A, is_training):
    seq = hk.Sequential((
        hk.Linear(8),
        jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(hk.BatchNorm(False, False, 0.99), is_training=is_training),
        hk.Linear(8),
        jax.nn.relu,
        hk.Linear(1),
        jnp.ravel,
    ))
    S = hk.Flatten()(S)
    A = hk.Flatten()(A)
    X = jnp.concatenate((S, A), axis=-1)
    return seq(X)
Example 5
def func(S, A, is_training):
    """ type-1 q-function: (s,a) -> q(s,a) """
    encoder = hk.Sequential((hk.Flatten(), hk.Linear(layer_size), jax.nn.relu))
    quantile_fractions = coax.utils.quantiles_uniform(
        rng=hk.next_rng_key(),
        batch_size=jax.tree_leaves(S)[0].shape[0],
        num_quantiles=num_quantiles)
    X = jax.vmap(jnp.kron)(S, A)
    x = encoder(X)
    quantile_x = quantile_net(x, quantile_fractions=quantile_fractions)
    quantile_values = hk.Linear(1, w_init=jnp.zeros)(quantile_x)
    return {
        'values': quantile_values.squeeze(axis=-1),
        'quantile_fractions': quantile_fractions
    }
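quantile_net, layer_size and num_quantiles come from elsewhere in that project. A rough sketch of what an IQN-style quantile_net might look like (an assumption, not the original helper): embed the quantile fractions with a cosine basis, project them to the feature size, and multiply them into the state encoding:

def quantile_net(x, quantile_fractions, embedding_dim=64):
    # Hypothetical IQN-style helper: x is (batch, features),
    # quantile_fractions is (batch, num_quantiles).
    ks = jnp.arange(1, embedding_dim + 1, dtype=x.dtype)
    cos_emb = jnp.cos(jnp.pi * ks * quantile_fractions[..., None])  # (batch, quantiles, embedding_dim)
    phi = jax.nn.relu(hk.Linear(x.shape[-1])(cos_emb))              # project to feature size
    return x[:, None, :] * phi                                      # (batch, quantiles, features)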
Example 6
 def __call__(
     self,
     timestep: dm_env.TimeStep,
 ) -> Tuple[jnp.ndarray, jnp.ndarray]:
     """Process a batch of observations."""
     torso = hk.Sequential([
         hk.Flatten(),
         hk.Linear(128), jax.nn.relu,
         hk.Linear(64), jax.nn.relu
     ])
     hidden = torso(timestep.observation)
     policy_logits = hk.Linear(self._num_actions)(hidden)
     baseline = hk.Linear(1)(hidden)
     baseline = jnp.squeeze(baseline, axis=-1)
     return policy_logits, baseline
Example 7
 def func(S, A, is_training):
     flatten = hk.Flatten()
     batch_norm = hk.BatchNorm(create_scale=True,
                               create_offset=True,
                               decay_rate=0.95)
     batch_norm = partial(batch_norm, is_training=is_training)
     seq = hk.Sequential((
         hk.Linear(7),
         batch_norm,
         jnp.tanh,
         hk.Linear(51),
     ))
     print(S.shape, A.shape)
     X = jnp.concatenate((flatten(S), flatten(A)), axis=-1)
     return {'logits': seq(X)}
Example 8
 def func(S, is_training):
     flatten = hk.Flatten()
     batch_norm = hk.BatchNorm(create_scale=True,
                               create_offset=True,
                               decay_rate=0.95)
     batch_norm = partial(batch_norm, is_training=is_training)
     seq = hk.Sequential((
         hk.Linear(7),
         batch_norm,
         jnp.tanh,
         hk.Linear(3),
         jnp.tanh,
         hk.Linear(self.env_discrete.action_space.n),
     ))
     return seq(flatten(S))
Example 9
    def __init__(self, C, position_enc_fn, name='SlotAttDecoder'):
        super().__init__(name=name)
        self.C = C
        self.num_slots = C['slots']
        channels, kernels, strides = C['decoder_cnn_channels'], C['decoder_cnn_kernels'], C['decoder_cnn_strides']

        deconv_layers = [
            hk.Conv2DTranspose(channels[0], kernels[0], stride=strides[0], padding='SAME'), jax.nn.relu,
            hk.Conv2DTranspose(channels[1], kernels[1], stride=strides[1], padding='SAME'), jax.nn.relu,
            hk.Conv2DTranspose(channels[2], kernels[2], stride=strides[2], padding='SAME'), jax.nn.relu,
            hk.Conv2DTranspose(4, kernels[3], stride=strides[3]),
        ]

        self.deconvolutions = hk.Sequential(deconv_layers)

        self.pos_embed = SoftPositionEmbed(C['slot_size'], C['spatial_broadcast_dims'], position_enc_fn)
Example 10
 def __call__(self, x):
     # cnn hyperparameters are described in the METHODS of
     # Mnih et al., Human-level control through deep reinforcement learning, 2015
     # https://www.nature.com/articles/nature14236
     cnn = hk.Sequential([
         hk.Conv2D(32, 8, 4),
         jax.nn.relu,
         hk.Conv2D(64, 4, 2),
         jax.nn.relu,
         hk.Conv2D(64, 3, 1),
          jax.nn.relu,
          hk.Flatten(),  # flatten conv features so the dense head yields (batch, n_actions)
          hk.Linear(self.hidden_size),
          jax.nn.relu,
          hk.Linear(self.n_actions),
      ])
     return cnn(x)
Example 11
def func_pi(S, is_training):
    seq = hk.Sequential((
        hk.Linear(8),
        jax.nn.relu,
        hk.Linear(8),
        jax.nn.relu,
        hk.Linear(8),
        jax.nn.relu,
        hk.Linear(prod(env.action_space.shape), w_init=jnp.zeros),
        hk.Reshape(env.action_space.shape),
    ))
    mu = seq(S)
    return {
        'mu': mu,
        'logvar': jnp.full_like(mu, -10)
    }  # (almost) deterministic
Example 12
def _build_mlp(
    name: str,
    output_sizes: Sequence[int],
    use_layer_norm=False,
    activation=jax.nn.relu,
):
    """Builds an MLP, optionally with layernorm."""
    net = hk.nets.MLP(output_sizes=output_sizes,
                      name=name + "_mlp",
                      activation=activation)
    if use_layer_norm:
        layer_norm = hk.LayerNorm(axis=-1,
                                  create_scale=True,
                                  create_offset=True,
                                  name=name + "_layer_norm")
        net = hk.Sequential([net, layer_norm])
    return jraph.concatenated_args(net)
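jraph.concatenated_args wraps the network so that the feature arrays passed to an update function are concatenated along the last axis before the MLP sees them. A hedged usage sketch, inside a Haiku-transformed jraph.GraphNetwork update function (the names below are illustrative, not from the original project):

edge_fn = _build_mlp("edge_update", output_sizes=[64, 64], use_layer_norm=True)
# The edge update receives several feature arrays; concatenated_args joins
# them on the last axis before applying the MLP.
updated_edges = edge_fn(edge_features, sender_features, receiver_features, global_features)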
Example 13
def haiku_model(x, dense_kernel_size=64, max_conv_size=256, num_classes=10):
    layers = []
    for i, c in enumerate([32, 64, 128, 256]):
        c = min(c, max_conv_size)
        layers.append(
            hk.Conv2D(output_channels=c,
                      kernel_shape=3,
                      stride=2,
                      name="conv%d_%d" % (i, c)))
        layers.append(jax.nn.gelu)
    layers += [
        global_spatial_mean_pooling,
        hk.Linear(dense_kernel_size, name="dense_%d" % dense_kernel_size),
        jax.nn.gelu,
        hk.Linear(num_classes, name='logits')
    ]
    return hk.Sequential(layers)(x)
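global_spatial_mean_pooling is defined elsewhere in that project; one plausible definition (an assumption, for NHWC conv activations) is a mean over the spatial axes:

def global_spatial_mean_pooling(x):
    # Hypothetical helper: average over height and width, keeping (batch, channels).
    return jnp.mean(x, axis=(1, 2))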
Example 14
def func(S, is_training):
    """ type-2 q-function: s -> q(s,.) """
    seq = hk.Sequential((
        coax.utils.diff_transform,
        hk.Conv2D(16, kernel_shape=8, stride=4),
        jax.nn.relu,
        hk.Conv2D(32, kernel_shape=4, stride=2),
        jax.nn.relu,
        hk.Flatten(),
        hk.Linear(256),
        jax.nn.relu,
        hk.Linear(env.action_space.n, w_init=jnp.zeros),
    ))
    X = jnp.moveaxis(
        S / 255., 1,
        -1)  # shape: (batch, frames, h, w) --> (batch, h, w, frames)
    return seq(X)
Example 15
 def func(S, A, is_training):
     flatten = hk.Flatten()
     batch_norm = hk.BatchNorm(create_scale=True,
                               create_offset=True,
                               decay_rate=0.95)
     batch_norm = partial(batch_norm, is_training=is_training)
     seq = hk.Sequential((
         hk.Linear(7),
         batch_norm,
         jnp.tanh,
         hk.Linear(3),
         jnp.tanh,
         hk.Linear(1),
         jnp.ravel,
     ))
     X = jnp.concatenate((flatten(S), flatten(A)), axis=-1)
     return seq(X)
Example 16
 def __init__(
     self,
     latent_size: int,
     hidden_size: int,
     output_shape: Sequence[int] = MNIST_IMAGE_SHAPE,
 ):
     super().__init__(name="model")
     self._latent_size = latent_size
     self._hidden_size = hidden_size
     self._output_shape = output_shape
     self.generative_network = hk.Sequential([
         hk.Linear(self._hidden_size),
         jax.nn.relu,
         hk.Linear(self._hidden_size),
         jax.nn.relu,
         hk.Linear(np.prod(self._output_shape)),
         hk.Reshape(self._output_shape, preserve_dims=2),
     ])
Example 17
def q(S, A, is_training):
    rng1, rng2, rng3 = hk.next_rng_keys(3)
    rate = hparams.dropout_critic * is_training
    seq = hk.Sequential((
        hk.Linear(hparams.h1_critic),
        jax.nn.relu,
        partial(hk.dropout, rng1, rate),
        hk.Linear(hparams.h2_critic),
        jax.nn.relu,
        partial(hk.dropout, rng2, rate),
        hk.Linear(hparams.h3_critic),
        jax.nn.relu,
        partial(hk.dropout, rng3, rate),
        hk.Linear(1),
        jnp.ravel,
    ))
    flatten = hk.Flatten()
    X_sa = jnp.concatenate([flatten(S), jnp.tanh(flatten(A))], axis=1)
    return seq(X_sa)
Example 18
    def net_fn(inputs):
        """
        Function representing dense torso for a DQN Q-network.
        """
        network = hk.Sequential([
            # Standardization
            z_norm(),
            hk.Flatten(),

            # Latent space construction
            linear(512, with_bias=True),
            jax.nn.relu,
            linear(512, with_bias=True),
            jax.nn.relu,
            linear(256, with_bias=True),
            jax.nn.relu,
            hk.Flatten(),
        ])
        return network(inputs)
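z_norm() and linear(...) are project-specific helpers that are not shown here; a rough sketch of what they might be (assumptions, not the original definitions):

def z_norm():
    # Hypothetical standardization layer: zero mean, unit variance per example.
    return lambda x: (x - jnp.mean(x, axis=-1, keepdims=True)) / (
        jnp.std(x, axis=-1, keepdims=True) + 1e-8)

def linear(units, with_bias=True):
    # Hypothetical thin wrapper around hk.Linear with a fixed initializer.
    return hk.Linear(units, with_bias=with_bias,
                     w_init=hk.initializers.VarianceScaling(scale=2.0))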
Example 19
    def __call__(self, image, debug=False): 
        """
        if debug, then print activation shapes
        """
        # TODO: output should have length self.n_classes
#        conv_layers = self.depth * [hk.Conv2D(self.n_channels,
#                                              kernel_shape=3,
#                                              w_init=self.initializer,
#                                              b_init=self.initializer,
#                                              stride=2),
#                                    jax.nn.relu]
#        convnet = hk.Sequential(conv_layers + [hk.Flatten()])

        with_bias = False
        strides = [1, 2, 1, 2, 1, 2]
        names = ['misc'] + ['conv'] * 5

        conv_layers = [
            [
                hk.Conv2D(self.n_channels,
                        kernel_shape=3,
                        w_init=self.initializer,
                        b_init=self.initializer,
                        with_bias=with_bias,
                        stride=stride,
                        name=name),
                jax.nn.relu,
                debug_layer(debug),
            ]
            for stride, name in zip(strides, names)
        ]

        conv_layers = [l for layer in conv_layers for l in layer]
        convnet = hk.Sequential(conv_layers + [
            hk.Flatten(),
            hk.Linear(self.n_classes,
                      w_init=self.initializer,
                      b_init=self.initializer,
                      name='misc'),
            debug_layer(debug),
        ])

        return convnet(image)
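debug_layer is not shown in this snippet; given the docstring ("if debug, then print activation shapes"), a plausible sketch is:

def debug_layer(debug):
    # Hypothetical helper: print the activation shape when debug=True, else a no-op.
    def layer(x):
        if debug:
            print('activation shape:', x.shape)
        return x
    return layer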
Example 20
def pi(S, is_training):
    rng1, rng2, rng3 = hk.next_rng_keys(3)
    shape = env.action_space.shape
    rate = hparams.dropout_actor * is_training
    seq = hk.Sequential((
        hk.Linear(hparams.h1_actor),
        jax.nn.relu,
        partial(hk.dropout, rng1, rate),
        hk.Linear(hparams.h2_actor),
        jax.nn.relu,
        partial(hk.dropout, rng2, rate),
        hk.Linear(hparams.h3_actor),
        jax.nn.relu,
        partial(hk.dropout, rng3, rate),
        hk.Linear(onp.prod(shape)),
        hk.Reshape(shape),
        # lambda x: low + (high - low) * jax.nn.sigmoid(x),  # disable: BoxActionsToReals
    ))
    return seq(S)  # batch of actions
Example 21
  def _actor_fn(obs):
    w_init = hk.initializers.VarianceScaling(1.0, "fan_avg", "truncated_normal")
    b_init = jnp.zeros
    dist_w_init = hk.initializers.VarianceScaling(1.0, "fan_avg", "truncated_normal")
    dist_b_init = jnp.zeros

    network = hk.Sequential([
        hk.nets.MLP(
            list(actor_hidden_layer_sizes),
            w_init=w_init,
            b_init=b_init,
            activation=jax.nn.relu,
            activate_final=True),
        networks_lib.NormalTanhDistribution(
            num_dimensions,
            w_init=dist_w_init,
            b_init=dist_b_init),
    ])
    return network(obs)
Example 22
  def __init__(self, core, memory_size, capacity, hidden_layers, alpha, beta,
               loss_func=(lambda x, y: 0.5 * jnp.square(x - y)),
               apply_core_to_input=False, name="synthetic_returns_wrapper"):
    """Constructor.

    Args:
      core: hk.RNNCore. The recurrent core of the agent. E.g. an LSTM.
      memory_size: Integer. The size of the vectors to be stored in the episodic
          memory.
      capacity: Integer. The maximum number of memories to store before it
          becomes necessary to overwrite old memories.
      hidden_layers: Tuple or list of integers, indicating the size of the
          hidden layers of the MLPs used to produce synthetic returns, current
          state bias, and gate.
      alpha: The multiplier of the synthetic returns term in the augmented
          return.
      beta: The multiplier of the environment returns term in the augmented
          return.
      loss_func: A function of two arguments (predictions and targets) to
          compute the SR loss.
      apply_core_to_input: Boolean. Whether to apply the core on the inputs. If
          true, the synthetic returns will be computed from the outputs of the
          RNN core passed to the constructor. If false, the RNN core will be
          applied only at the output of this wrapper, and the synthetic returns
          will be computed from the inputs.
      name: String. A name for this Haiku module instance.
    """
    super().__init__(name=name)
    self._em = EpisodicMemory(memory_size, capacity)
    self._capacity = capacity
    hidden_layers = list(hidden_layers)
    self._synthetic_return = hk.nets.MLP(hidden_layers + [1])
    self._bias = hk.nets.MLP(hidden_layers + [1])
    self._gate = hk.Sequential([
        hk.nets.MLP(hidden_layers + [1]),
        jax.nn.sigmoid,
    ])
    self._apply_core_to_input = apply_core_to_input
    self._core = core
    self._alpha = alpha
    self._beta = beta
    self._loss = loss_func
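As a reading aid for the alpha and beta arguments described in the docstring above (a schematic sketch, not the wrapper's actual code), the augmented return combines the two terms as:

def augmented_return(synthetic_return, environment_return, alpha, beta):
    # Schematic combination implied by the docstring (illustrative only).
    return alpha * synthetic_return + beta * environment_return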
Example 23
    def lenet_fn(batch, is_training):
        """Network inspired by LeNet-5."""
        x, _ = batch

        cnn = hk.Sequential([
            hk.Conv2D(output_channels=6, kernel_shape=5, padding="SAME"),
            jax.nn.relu,
            hk.MaxPool(window_shape=3, strides=2, padding="VALID"),
            hk.Conv2D(output_channels=16, kernel_shape=5, padding="SAME"),
            jax.nn.relu,
            hk.MaxPool(window_shape=3, strides=2, padding="VALID"),
            hk.Conv2D(output_channels=120, kernel_shape=5, padding="SAME"),
            jax.nn.relu,
            hk.MaxPool(window_shape=3, strides=2, padding="VALID"),
            hk.Flatten(),
            hk.Linear(84),
            jax.nn.relu,
            hk.Linear(num_classes),
        ])
        return cnn(x)
Example 24
    def __init__(self,
                 layer_sizes: Sequence[int],
                 activate_final: bool = False):
        """Construct the MLP.

    Args:
      layer_sizes: a sequence of ints specifying the size of each layer.
      activate_final: whether or not to use the activation function on the final
        layer of the neural network.
    """
        super().__init__(name='feedforward_mlp_torso')

        self._network = hk.Sequential([
            hk.Linear(layer_sizes[0], w_init=uniform_initializer),
            hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
            jax.lax.tanh,
            hk.nets.MLP(layer_sizes[1:],
                        w_init=uniform_initializer,
                        activation=jax.nn.elu,
                        activate_final=activate_final),
        ])
Example 25
    def __call__(self, encoder_outputs, inputs):
        dimensionality = self.dimensionality
        results = inputs
        results += MultiHeadAttention(dimensionality=dimensionality,
                                      num_heads=self.num_heads,
                                      causal_mask=True)(results)
        results = LayerNorm()(results)

        results += MultiHeadAttention(dimensionality=dimensionality,
                                      num_heads=self.num_heads)(
                                          inputs=encoder_outputs,
                                          Q_inputs=results)
        results = LayerNorm()(results)

        ff = hk.Sequential([
            hk.Linear(dimensionality), jax.nn.relu,
            hk.Linear(dimensionality)
        ])
        results += ff(results)
        results = LayerNorm()(results)
        return results
Example 26
def lenet_fn(batch):
    """Network inspired by LeNet-5."""
    x, _ = batch
    x = x.astype(jnp.float32)
    cnn = hk.Sequential([
        hk.Conv2D(output_channels=32, kernel_shape=5, padding="SAME"),
        jax.nn.relu,
        hk.MaxPool(window_shape=3, strides=2, padding="VALID"),
        hk.Conv2D(output_channels=64, kernel_shape=5, padding="SAME"),
        jax.nn.relu,
        hk.MaxPool(window_shape=3, strides=2, padding="VALID"),
        hk.Conv2D(output_channels=128, kernel_shape=5, padding="SAME"),
        hk.MaxPool(window_shape=3, strides=2, padding="VALID"),
        hk.Flatten(),
        hk.Linear(1000),
        jax.nn.relu,
        hk.Linear(1000),
        jax.nn.relu,
        hk.Linear(10),
    ])
    return cnn(x)
Example 27
def func(S, is_training):
    rng1, rng2, rng3 = hk.next_rng_keys(3)
    rate = 0.25 if is_training else 0.
    batch_norm = hk.BatchNorm(False, False, 0.99)
    seq = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8),
        jax.nn.relu,
        partial(hk.dropout, rng1, rate),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8),
        jax.nn.relu,
        partial(hk.dropout, rng2, rate),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8),
        jax.nn.relu,
        partial(hk.dropout, rng3, rate),
        partial(batch_norm, is_training=is_training),
        hk.Linear(1, w_init=jnp.zeros),
        jnp.ravel,
    ))
    return seq(S)
Example 28
 def large_critic(x):
   # inspired by the ones used in RL Unplugged
   x = hk.Sequential([
       hk.Linear(400, w_init=rlu_uniform_initializer),
       hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
       jax.lax.tanh,
   ])(x)
   x = hk.Linear(1024, w_init=rlu_uniform_initializer)(x)
   for i in range(4):
     x = network_utils.ResidualLayerNormBlock(
         [1024, 1024],
         activation=jax.nn.relu,
         w_init=rlu_uniform_initializer,)(x)
   h = x
   # v = hk.Linear(1, w_init=rlu_uniform_initializer)(h)
   # v = hk.Linear(critic_output_dim)(h)
   all_vs = []
   for _ in range(critic_output_dim):
     head_v = hk.Linear(256, w_init=rlu_uniform_initializer)(h)
     head_v = jax.nn.relu(head_v)
     head_v = hk.Linear(1, w_init=rlu_uniform_initializer)(head_v)
     all_vs.append(head_v)
   v = jnp.concatenate(all_vs, axis=-1)
   return v, h
Example 29
    def __init__(self,
                 n_in: int,
                 n_out: int,
                 n_layers: int,
                 name: str = "EncoderBlock"):
        super().__init__(name=name)
        n_hid = n_out // 4
        self.post_gain = 1 / (n_layers**2)

        self.id_path = hk.Conv2D(
            n_out, 1, name="id_path",
            data_format="NCHW") if n_in != n_out else lambda x: x

        with hk.experimental.name_scope("res_path"):
            self.res_path = hk.Sequential([
                jax.nn.relu,
                hk.Conv2D(n_hid, 1, name="conv_1",
                          data_format="NCHW"), jax.nn.relu,
                hk.Conv2D(n_hid, 3, name="conv_2",
                          data_format="NCHW"), jax.nn.relu,
                hk.Conv2D(n_hid, 3, name="conv_3", data_format="NCHW"),
                jax.nn.relu,
                hk.Conv2D(n_out, 3, name="conv_4", data_format="NCHW")
            ])
Example 30
def func_type2(S, is_training):
    # custom haiku function: s -> q(s,.)
    value = hk.Sequential([...])
    return value(S)  # output shape: (batch_size, num_actions)
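The [...] above is a placeholder in the original snippet; a hypothetical way to fill in such a type-2 head (dense layers ending in one output per action, assuming env, jax, jnp and hk are in scope) would be:

def func_type2(S, is_training):
    # Hypothetical concrete version of the stub above (not the original code).
    value = hk.Sequential([
        hk.Flatten(),
        hk.Linear(64), jax.nn.relu,
        hk.Linear(64), jax.nn.relu,
        hk.Linear(env.action_space.n, w_init=jnp.zeros),
    ])
    return value(S)  # output shape: (batch_size, num_actions)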