Example #1
 def func(S, is_training):
     flatten = hk.Flatten()
     batch_norm_m = hk.BatchNorm(create_scale=True,
                                 create_offset=True,
                                 decay_rate=0.95)
     batch_norm_v = hk.BatchNorm(create_scale=True,
                                 create_offset=True,
                                 decay_rate=0.95)
     batch_norm_m = partial(batch_norm_m, is_training=is_training)
     batch_norm_v = partial(batch_norm_v, is_training=is_training)
     mu = hk.Sequential((
         hk.Linear(7),
         batch_norm_m,
         jnp.tanh,
         hk.Linear(3),
         jnp.tanh,
         hk.Linear(onp.prod(self.env_boxspace.action_space.shape)),
         hk.Reshape(self.env_boxspace.action_space.shape),
     ))
     logvar = hk.Sequential((
         hk.Linear(7),
         batch_norm_v,
         jnp.tanh,
         hk.Linear(3),
         jnp.tanh,
         hk.Linear(onp.prod(self.env_boxspace.action_space.shape)),
         hk.Reshape(self.env_boxspace.action_space.shape),
     ))
     return {'mu': mu(flatten(S)), 'logvar': logvar(flatten(S))}
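Because hk.BatchNorm tracks moving-average statistics as Haiku state, a function like the one above has to be wrapped with hk.transform_with_state before it can be initialized and applied; names such as self.env_boxspace and onp (conventionally numpy imported as onp) resolve from each snippet's enclosing scope. A minimal, self-contained sketch of that wrapping, using a toy input shape:

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x, is_training):
    x = hk.Linear(8)(x)
    x = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.95)(x, is_training)
    return jnp.tanh(x)

# BatchNorm's moving averages live in Haiku state, hence transform_with_state.
model = hk.transform_with_state(forward)
rng = jax.random.PRNGKey(0)
x = jnp.ones((4, 16))  # toy batch of 4 flat observations
params, state = model.init(rng, x, is_training=True)
out, state = model.apply(params, state, rng, x, is_training=True)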
Example #2
 def func(S, is_training):
     env = self.env_discrete
     output_shape = (env.action_space.n, *env.observation_space.shape)
     flatten = hk.Flatten()
     batch_norm_m = hk.BatchNorm(create_scale=True,
                                 create_offset=True,
                                 decay_rate=0.95)
     batch_norm_v = hk.BatchNorm(create_scale=True,
                                 create_offset=True,
                                 decay_rate=0.95)
     batch_norm_m = partial(batch_norm_m, is_training=is_training)
     batch_norm_v = partial(batch_norm_v, is_training=is_training)
     mu = hk.Sequential((
         hk.Linear(7),
         batch_norm_m,
         jnp.tanh,
         hk.Linear(3),
         jnp.tanh,
         hk.Linear(onp.prod(output_shape)),
         hk.Reshape(output_shape),
     ))
     logvar = hk.Sequential((
         hk.Linear(7),
         batch_norm_v,
         jnp.tanh,
         hk.Linear(3),
         jnp.tanh,
         hk.Linear(onp.prod(output_shape)),
         hk.Reshape(output_shape),
     ))
     X = flatten(S)
     return {'mu': mu(X), 'logvar': logvar(X)}
Example #3
    def __init__(self, in_planes, planes, stride=1):
        super().__init__()

        self.conv1 = hk.Conv2D(output_channels=planes,
                               kernel_shape=3,
                               stride=stride,
                               padding='SAME',
                               with_bias=False,
                               data_format='NCHW')
        self.bn1 = hk.BatchNorm(create_scale=True,
                                create_offset=True,
                                decay_rate=0.9,
                                data_format='NC...')

        self.conv2 = hk.Conv2D(output_channels=planes,
                               kernel_shape=3,
                               stride=1,
                               padding='SAME',
                               with_bias=False,
                               data_format='NCHW')
        self.bn2 = hk.BatchNorm(create_scale=True,
                                create_offset=True,
                                decay_rate=0.9,
                                data_format='NC...')

        self.in_planes = in_planes
        self.planes = planes
        self.stride = stride
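The constructor above only builds the modules; a hedged sketch (assumed, not part of the source) of the matching forward pass would thread is_training through both BatchNorm layers:

    def __call__(self, x, is_training):
        # conv -> batch norm -> activation, twice
        out = jax.nn.relu(self.bn1(self.conv1(x), is_training))
        out = self.bn2(self.conv2(out), is_training)
        # a shortcut/residual connection would typically be added here before the final ReLU
        return jax.nn.relu(out)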
Example #4
 def __init__(self, vocab_size, lstm_dim, dropout_rate, is_training=True):
   super().__init__()
   self.is_training = is_training
   self.embed = hk.Embed(vocab_size, lstm_dim)
   self.conv1 = hk.Conv1D(lstm_dim, 3, padding='SAME')
   self.conv2 = hk.Conv1D(lstm_dim, 3, padding='SAME')
   self.conv3 = hk.Conv1D(lstm_dim, 3, padding='SAME')
   self.bn1 = hk.BatchNorm(True, True, 0.9)
   self.bn2 = hk.BatchNorm(True, True, 0.9)
   self.bn3 = hk.BatchNorm(True, True, 0.9)
   self.lstm_fwd = hk.LSTM(lstm_dim)
   self.lstm_bwd = hk.ResetCore(hk.LSTM(lstm_dim))
   self.dropout_rate = dropout_rate
Example #5
    def __call__(self, x, key, dropout_rates, is_training):
        out = hk.Flatten()(x)

        if is_training and (dropout_rates is not None):
            keys = jax.random.split(key, self.nlayers)

        for l in range(self.nlayers):
            out = hk.Linear(self.nhid, with_bias=self.with_bias)(out)
            if self.batch_norm:
                out = hk.BatchNorm(create_scale=False,
                                   create_offset=False,
                                   decay_rate=0.9)(out, is_training)

            if is_training and (dropout_rates is not None):
                out = dropout(keys[l], dropout_rates[l], out)

            if self.activation == 'relu':
                out = jax.nn.relu(out)
            elif self.activation == 'sigmoid':
                out = jax.nn.sigmoid(out)
            elif self.activation == 'tanh':
                out = jnp.tanh(out)
            elif self.activation == 'linear':
                out = out

        out = hk.Linear(10, with_bias=self.with_bias)(out)
        return out
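The dropout helper called in the loop is not shown; its (key, rate, array) argument order matches hk.dropout, so a minimal stand-in (an assumption, not the original helper) is simply:

def dropout(key, rate, x):
    # defers to Haiku's functional dropout; shown only to make the snippet self-contained
    return hk.dropout(key, rate, x)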
Example #6
    def __init__(self):
        super().__init__()
        bn_config = {
            'create_scale': True,
            'create_offset': True,
            'decay_rate': 0.999
        }

        # Definition of the modules.
        self.conv_block = hk.Sequential([
            hk.Conv2D(1, (3, 3), stride=3, rate=1),
            jax.nn.relu,
            hk.Conv2D(1, (3, 3), stride=3, rate=1),
            jax.nn.relu,
        ])

        self.conv_res_block = hk.Sequential([
            hk.Conv2D(1, (1, 1), stride=1, rate=1),
            jax.nn.relu,
            hk.Conv2D(1, (1, 1), stride=1, rate=1),
            jax.nn.relu,
        ])

        self.reshape_mod = hk.Flatten()

        self.lin_res_block = [(hk.Linear(16),
                               hk.BatchNorm(name='lin_batchnorm_0',
                                            **bn_config))]

        self.final_linear = hk.Linear(10)
Example #7
    def __call__(self, node_feats: jnp.ndarray, adj: jnp.ndarray,
                 is_training: bool) -> jnp.ndarray:
        """Update node features.

        Parameters
        ----------
        node_feats : ndarray of shape (batch_size, N, in_feats)
            Batch input node features.
            N is the total number of nodes in the batch of graphs.
        adj : ndarray of shape (batch_size, N, N)
            Batch adjacency matrix.
        is_training : bool
            Whether the model is training or not.

        Returns
        -------
        new_node_feats : ndarray of shape (batch_size, N, out_feats)
            Batch new node features.
        """
        dropout = self.dropout if is_training is True else 0.0

        # for batch data
        new_node_feats = jax.vmap(self._update_nodes)(node_feats, adj)
        if self.bias:
            new_node_feats += self.b
        new_node_feats = self.activation(new_node_feats)

        if dropout != 0.0:
            new_node_feats = hk.dropout(hk.next_rng_key(), dropout,
                                        new_node_feats)
        if self.batch_norm:
            new_node_feats = hk.BatchNorm(True, True, 0.9)(new_node_feats,
                                                           is_training)

        return new_node_feats
Example #8
 def __call__(self, inputs, is_training):
     dropout_rate = self._dropout_rate if is_training else 0.0
     h = hk.dropout(hk.next_rng_key(), dropout_rate, inputs)
     h = hk.Linear(self._vocab_size, with_bias=False)(h)
     return hk.BatchNorm(create_scale=False,
                         create_offset=False,
                         decay_rate=0.9)(h, is_training)
Example #9
    def __call__(self, x, mask_props=None, is_training=True):
        out = hk.Flatten()(x)

        for l in range(self.nlayers):
            out = hk.Linear(self.nhid, with_bias=self.with_bias)(out)
            if self.batch_norm:
                out = hk.BatchNorm(create_scale=False,
                                   create_offset=False,
                                   decay_rate=0.9)(out, is_training)

            if mask_props is not None:
                num_units = jnp.floor(mask_props[l] * out.shape[1]).astype(
                    jnp.int32)
                mask = jnp.arange(out.shape[1]) < num_units
                out = jnp.where(mask, out, jnp.zeros(out.shape))

            if self.activation == 'relu':
                out = jax.nn.relu(out)
            elif self.activation == 'sigmoid':
                out = jax.nn.sigmoid(out)
            elif self.activation == 'tanh':
                out = jnp.tanh(out)
            elif self.activation == 'linear':
                out = out

        out = hk.Linear(10, with_bias=self.with_bias)(out)
        return out
Example #10
def func_boxspace(S, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    mu = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8),
        jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8),
        jnp.tanh,
        hk.Linear(onp.prod(boxspace.shape)),
        hk.Reshape(boxspace.shape),
    ))
    logvar = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8),
        jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8),
        jnp.tanh,
        hk.Linear(onp.prod(boxspace.shape)),
        hk.Reshape(boxspace.shape),
    ))
    return {'mu': mu(S), 'logvar': logvar(S)}
Example #11
 def single_mlp(inner_name: str):
     """Creates a single MLP performing the update."""
     mlp = hk.nets.MLP(output_sizes=output_sizes,
                       name=inner_name,
                       activation=activation)
     mlp = jraph.concatenated_args(mlp)
     if normalization_type == 'layer_norm':
         norm = hk.LayerNorm(axis=-1,
                             create_scale=True,
                             create_offset=True,
                             name=name + '_layer_norm')
     elif normalization_type == 'batch_norm':
         batch_norm = hk.BatchNorm(
             create_scale=True,
             create_offset=True,
             decay_rate=0.9,
             name=f'{inner_name}_batch_norm',
             cross_replica_axis=None if hk.running_init() else 'i',
         )
         norm = lambda x: batch_norm(x, is_training)
     elif normalization_type == 'none':
         return mlp
     else:
         raise ValueError(
             f'Unknown normalization type {normalization_type}')
     return jraph.concatenated_args(hk.Sequential([mlp, norm]))
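The cross_replica_axis=None if hk.running_init() else 'i' argument makes BatchNorm average its statistics across devices at apply time; the axis name 'i' must match the jax.pmap axis the transformed function runs under. A hedged sketch of that setup (net_fn is a placeholder for the network built from single_mlp):

net = hk.transform_with_state(net_fn)
per_device_apply = jax.pmap(net.apply, axis_name='i')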
Example #12
 def fn(x):
     if dropout:
         x = hk.dropout(hk.next_rng_key(), 0.5, x)
     if batchnorm:
         x = hk.BatchNorm(create_offset=True, create_scale=True, decay_rate=0.001)(
             x, is_training=True
         )
     return x
Example #13
def func_discrete_type1(S, A, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    seq = hk.Sequential(
        (hk.Flatten(), hk.Linear(8), jax.nn.relu,
         partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
         partial(batch_norm, is_training=is_training), hk.Linear(8), jnp.tanh,
         hk.Linear(discrete.n), jax.nn.softmax))
    X = jax.vmap(jnp.kron)(S, A)
    return seq(X)
Example #14
def func_discrete_type2(S, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    seq = hk.Sequential(
        (hk.Flatten(), hk.Linear(8), jax.nn.relu,
         partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
         partial(batch_norm, is_training=is_training), hk.Linear(8), jnp.tanh,
         hk.Linear(discrete.n * discrete.n),
         hk.Reshape((discrete.n, discrete.n)), jax.nn.softmax))
    return seq(S)
Example #15
 def func(S, is_training):
     flatten = hk.Flatten()
     batch_norm = hk.BatchNorm(create_scale=True,
                               create_offset=True,
                               decay_rate=0.95)
     batch_norm = partial(batch_norm, is_training=is_training)
     seq = hk.Sequential(
         (hk.Linear(7), batch_norm, jnp.tanh, hk.Linear(3), jnp.tanh,
          hk.Linear(1), jnp.ravel))
     return seq(flatten(S))
Example #16
def func(S, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    logits = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8), jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8), jnp.tanh,
        hk.Linear(num_bins),
    ))
    return {'logits': logits(S)}
Example #17
 def func(S, is_training):
     flatten = hk.Flatten()
     batch_norm = hk.BatchNorm(create_scale=True,
                               create_offset=True,
                               decay_rate=0.95)
     batch_norm = partial(batch_norm, is_training=is_training)
     seq = hk.Sequential(
         (hk.Linear(7), batch_norm, jnp.tanh, hk.Linear(3), jnp.tanh,
          hk.Linear(self.env_discrete.action_space.n * 51),
          hk.Reshape((self.env_discrete.action_space.n, 51))))
     return seq(flatten(S))
Example #18
 def __init__(self, c_in, c_out):
     super().__init__()
     self.conv = hk.Conv2D(output_channels=c_out,
                           kernel_shape=3,
                           stride=1,
                           with_bias=False,
                           padding='SAME',
                           data_format='NCHW')
     self.bn = hk.BatchNorm(create_scale=True,
                            create_offset=True,
                            decay_rate=0.9,
                            data_format='NC...')
Example #19
def _call_layers(
        cfg,
        inp: Tensor,
        batch_norm: bool = True,
        include_top: bool = True,
        initial_weights: Optional[hk.Params] = None,
        output_feature_maps: bool = False) -> Union[Tensor, List[Tensor]]:

    x = inp
    partial_results = []

    # Ignore max pooling if we do not append the classifier
    if not include_top:
        cfg = cfg[:-1]

    i = 0
    base_name = 'vgg16/conv2_d'
    for v in cfg:
        if v == 'M':
            partial_results.append(x)
            x = hk.MaxPool(window_shape=2, strides=2, padding="VALID")(x)
        else:
            if i == 0:
                param_name = base_name
            else:
                param_name = base_name + f'_{i}'

            i += 1

            w_init = (None
                      if initial_weights is None else hk.initializers.Constant(
                          constant=initial_weights[param_name]['w']))
            b_init = (None
                      if initial_weights is None else hk.initializers.Constant(
                          constant=initial_weights[param_name]['b']))
            x = hk.Conv2D(v,
                          kernel_shape=3,
                          stride=1,
                          padding='SAME',
                          w_init=w_init,
                          b_init=b_init)(x)

            if batch_norm:
                # hk.BatchNorm also expects an is_training flag; assumed True here
                x = hk.BatchNorm(True, True, decay_rate=0.999)(x, is_training=True)

            x = jax.nn.relu(x)

    partial_results.append(x)

    if not output_feature_maps:
        return partial_results[-1]
    else:
        return partial_results
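The cfg argument is a VGG-style layer specification in which integers are Conv2D channel counts and 'M' marks a max-pool. For VGG16 it would presumably be the standard layout (an assumption; the source does not show it):

VGG16_CFG = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
             512, 512, 512, 'M', 512, 512, 512, 'M']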
Example #20
def func_type2(S, is_training):
    seq = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8),
        jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(hk.BatchNorm(False, False, 0.99), is_training=is_training),
        hk.Linear(8),
        jax.nn.relu,
        hk.Linear(env_discrete.action_space.n),
    ))
    return seq(S)
Example #21
    def __init__(self,
                 output_channels: Sequence[int],
                 kernel_shapes: Union[int, Sequence[int]] = 3,
                 strides: Union[int, Sequence[int]] = 1,
                 padding: Union[str, Sequence[str]] = "SAME",
                 data_format: str = "NHWC",
                 with_batch_norm: bool = False,
                 activate_final: bool = False,
                 activation: Activation = "leaky_relu",
                 name: Optional[str] = None):
        super().__init__(name=name)
        self.output_channels = tuple(output_channels)
        self.num_layers = len(self.output_channels)
        self.kernel_shapes = utils.bcast_if(kernel_shapes, int,
                                            self.num_layers)
        self.strides = utils.bcast_if(strides, int, self.num_layers)
        self.padding = utils.bcast_if(padding, str, self.num_layers)
        self.data_format = data_format
        self.with_batch_norm = with_batch_norm
        self.activate_final = activate_final
        self.activation = utils.get_activation(activation)

        if len(self.kernel_shapes) != self.num_layers:
            raise ValueError(
                f"Kernel shapes is of size {len(self.kernel_shapes)}, "
                f"while output_channels is of size {self.num_layers}.")
        if len(self.strides) != self.num_layers:
            raise ValueError(
                f"Strides is of size {len(self.strides)}, while "
                f"output_channels is of size {self.num_layers}.")
        if len(self.padding) != self.num_layers:
            raise ValueError(f"Padding is of size {len(self.padding)}, while "
                             f"output_channels is of size {self.num_layers}.")

        self.conv_modules = []
        self.bn_modules = []
        for i in range(self.num_layers):
            self.conv_modules.append(
                hk.Conv2D(output_channels=self.output_channels[i],
                          kernel_shape=self.kernel_shapes[i],
                          stride=self.strides[i],
                          padding=self.padding[i],
                          data_format=data_format,
                          name=f"conv_2d_{i}"))
            if with_batch_norm:
                self.bn_modules.append(
                    hk.BatchNorm(create_offset=True,
                                 create_scale=False,
                                 decay_rate=0.999,
                                 name=f"batch_norm_{i}"))
            else:
                self.bn_modules.append(None)
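A hedged sketch (assumed; the source's forward pass is not shown) of how the paired conv_modules and bn_modules built above would be applied:

    def __call__(self, x, is_training):
        for i, conv in enumerate(self.conv_modules):
            x = conv(x)
            if self.bn_modules[i] is not None:
                x = self.bn_modules[i](x, is_training)
            # activate every layer except (optionally) the last one
            if i < self.num_layers - 1 or self.activate_final:
                x = self.activation(x)
        return x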
Example #22
def func_boxspace_type1(S, A, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    seq = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8), jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8), jnp.tanh,
        hk.Linear(onp.prod(boxspace.shape)),
        hk.Reshape(boxspace.shape),
    ))
    X = jax.vmap(jnp.kron)(S, A)
    return seq(X)
Example #23
 def __call__(
     self,
     graph: tp.Union[jnp.ndarray, JAXSparse],
     node_features: jnp.ndarray,
     is_training: tp.Optional[bool] = None,
 ):
     x = node_features
     for f in self.hidden_filters:
         x = GCN(f)(graph, x)
         x = hk.BatchNorm(True, True, 0.9)(x, is_training=is_training)
         x = jax.nn.relu(x)
         x = dropout(x, self.dropout_rate, is_training=is_training)
     logits = GCN(self.num_classes)(graph, x)
     return logits
Example #24
    def __call__(self, inputs, is_training):
        dropout_rate = self._dropout_rate if is_training else 0.0

        h = jax.nn.softplus(hk.Linear(self._hidden)(inputs))
        h = jax.nn.softplus(hk.Linear(self._hidden)(h))
        h = hk.dropout(hk.next_rng_key(), dropout_rate, h)
        h = hk.Linear(self._num_topics)(h)

        # NB: here we set `create_scale=False` and `create_offset=False` to reduce
        # the number of learning parameters
        log_concentration = hk.BatchNorm(create_scale=False,
                                         create_offset=False,
                                         decay_rate=0.9)(h, is_training)
        return jnp.exp(log_concentration)
Example #25
 def func(S, A, is_training):
     flatten = hk.Flatten()
     batch_norm = hk.BatchNorm(create_scale=True,
                               create_offset=True,
                               decay_rate=0.95)
     batch_norm = partial(batch_norm, is_training=is_training)
     seq = hk.Sequential((
         hk.Linear(7),
         batch_norm,
         jnp.tanh,
         hk.Linear(51),
     ))
     print(S.shape, A.shape)
     X = jnp.concatenate((flatten(S), flatten(A)), axis=-1)
     return {'logits': seq(X)}
Example #26
def func_type1(S, A, is_training):
    seq = hk.Sequential((
        hk.Linear(8),
        jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(hk.BatchNorm(False, False, 0.99), is_training=is_training),
        hk.Linear(8),
        jax.nn.relu,
        hk.Linear(1),
        jnp.ravel,
    ))
    S = hk.Flatten()(S)
    A = hk.Flatten()(A)
    X = jnp.concatenate((S, A), axis=-1)
    return seq(X)
Example #27
  def __init__(self, is_training=True):
    super().__init__()
    self.is_training = is_training
    self.encoder = TokenEncoder(FLAGS.vocab_size, FLAGS.acoustic_encoder_dim, 0.5, is_training)
    self.decoder = hk.deep_rnn_with_skip_connections([
        hk.LSTM(FLAGS.acoustic_decoder_dim),
        hk.LSTM(FLAGS.acoustic_decoder_dim)
    ])
    self.projection = hk.Linear(FLAGS.mel_dim)

    # prenet
    self.prenet_fc1 = hk.Linear(256, with_bias=True)
    self.prenet_fc2 = hk.Linear(256, with_bias=True)
    # postnet
    self.postnet_convs = [hk.Conv1D(FLAGS.postnet_dim, 5) for _ in range(4)] + [hk.Conv1D(FLAGS.mel_dim, 5)]
    self.postnet_bns = [hk.BatchNorm(True, True, 0.9) for _ in range(4)] + [None]
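hk.deep_rnn_with_skip_connections returns an RNNCore, so the decoder above would typically be driven with hk.dynamic_unroll. A hedged usage sketch (dimensions and variable names are placeholders):

decoder = hk.deep_rnn_with_skip_connections([hk.LSTM(512), hk.LSTM(512)])
initial_state = decoder.initial_state(batch_size)
# inputs has shape [time, batch, features] since time_major defaults to True
outputs, final_state = hk.dynamic_unroll(decoder, inputs, initial_state)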
Example #28
def mlp(
    x: tp.Union[jnp.ndarray, JAXSparse],
    is_training: bool,
    ids=None,
    num_classes: tp.Optional[int] = None,
    hidden_filters: tp.Union[int, tp.Iterable[int]] = 64,
    dropout_rate: float = 0.8,
    use_batch_norm: bool = False,
    use_layer_norm: bool = False,
    use_renormalize: bool = False,
    use_gathered_batch_norm: bool = False,
    activation: Activation = jax.nn.relu,
    final_activation: Activation = lambda x: x,
    input_dropout_rate: tp.Optional[float] = None,
    batch_norm_decay: float = 0.9,
    renorm_scale: bool = True,
    w_init=None,
):
    assert (sum((use_batch_norm, use_layer_norm, use_renormalize,
                 use_gathered_batch_norm)) <= 1)
    if input_dropout_rate is None:
        input_dropout_rate = dropout_rate
    if isinstance(hidden_filters, int):
        hidden_filters = (hidden_filters, )

    x = dropout(x, input_dropout_rate, is_training=is_training)
    for filters in hidden_filters:
        x = Linear(filters, w_init=w_init)(x)
        if use_batch_norm:
            x = hk.BatchNorm(renorm_scale, True, batch_norm_decay)(x,
                                                                   is_training)
        if use_layer_norm:
            x = hk.LayerNorm(0, renorm_scale, True)(x)
        if use_renormalize:
            x = Renormalize(renorm_scale, True)(x)
        if use_gathered_batch_norm:
            assert ids is not None
            x = GatheredBatchNorm(True, True,
                                  batch_norm_decay)(x,
                                                    is_training=is_training,
                                                    ids=ids)
        x = activation(x)
        x = dropout(x, dropout_rate, is_training=is_training)
    if num_classes is not None:
        x = hk.Linear(num_classes, w_init=w_init)(x)
    return final_activation(x)
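The dropout wrapper used here (and in Example #23) takes (x, rate, is_training=...) rather than Haiku's (rng, rate, x); a plausible definition, given purely as an assumption, is:

def dropout(x, rate, is_training):
    # no-op at evaluation time; otherwise defer to hk.dropout with a fresh RNG key
    return hk.dropout(hk.next_rng_key(), rate, x) if is_training else x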
Example #29
 def func(S, A, is_training):
     flatten = hk.Flatten()
     batch_norm = hk.BatchNorm(create_scale=True,
                               create_offset=True,
                               decay_rate=0.95)
     batch_norm = partial(batch_norm, is_training=is_training)
     seq = hk.Sequential((
         hk.Linear(7),
         batch_norm,
         jnp.tanh,
         hk.Linear(3),
         jnp.tanh,
         hk.Linear(1),
         jnp.ravel,
     ))
     X = jnp.concatenate((flatten(S), flatten(A)), axis=-1)
     return seq(X)
Example #30
    def __init__(self, num_blocks=[5, 5, 5], num_classes=10):
        super().__init__()

        self.conv1 = hk.Conv2D(
            output_channels=16,
            # output_channels=64,
            kernel_shape=3,
            stride=1,
            with_bias=False,
            padding='SAME',
            data_format='NCHW')
        self.bn1 = hk.BatchNorm(create_scale=True,
                                create_offset=True,
                                decay_rate=0.9,
                                data_format='NC...')
        self.layer1 = MultiBlock(16, 16, [1] + [1] * (num_blocks[0] - 1))
        self.layer2 = MultiBlock(16, 32, [2] + [1] * (num_blocks[1] - 1))
        self.layer3 = MultiBlock(32, 64, [2] + [1] * (num_blocks[2] - 1))
        self.linear = hk.Linear(num_classes)