Example #1
    def __call__(self, h: jnp.ndarray, mask: Optional[jnp.ndarray],
                 is_training: bool) -> jnp.ndarray:
        """Connects the transformer.
        Args:
          h: Inputs, [B, T, H].
          mask: Padding mask, [B, T].
          is_training: Whether we're training or not.
        Returns:
          Array of shape [B, T, H].
        """

        init_scale = 2. / self._num_layers
        dropout_rate = self._dropout_rate if is_training else 0.
        if mask is not None:
            mask = mask[:, None, None, :]

        for i in range(self._num_layers):
            h_norm = layer_norm(h, name=f'h{i}_ln_1')
            h_attn = SelfAttention(num_heads=self._num_heads,
                                   key_size=64,
                                   w_init_scale=init_scale,
                                   name=f'h{i}_attn')(h_norm, mask=mask)
            h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
            h = h + h_attn
            h_norm = layer_norm(h, name=f'h{i}_ln_2')
            h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
            h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
            h = h + h_dense
        h = layer_norm(h, name='ln_f')

        return h
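None of these snippets runs on its own: hk.next_rng_key() only works inside a function wrapped with hk.transform (or hk.transform_with_state), which supplies the key that gets split on each call. Below is a minimal, self-contained sketch of that pattern; the small forward function is a stand-in for illustration, not the transformer module above.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x, is_training):
    # Stand-in for the block above: a dense layer followed by dropout that
    # draws its key from hk.next_rng_key().
    h = hk.Linear(64)(x)
    rate = 0.1 if is_training else 0.0
    return hk.dropout(hk.next_rng_key(), rate, h)

model = hk.transform(forward)
rng = jax.random.PRNGKey(0)
x = jnp.ones([2, 16, 64])                  # [B, T, H]
params = model.init(rng, x, True)
out = model.apply(params, rng, x, True)    # this rng is what hk.next_rng_key() splits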
Example #2
 def __call__(self, x, lengths):
     x = self.embed(x)
     x = jax.nn.relu(self.bn1(self.conv1(x), is_training=self.is_training))
     x = hk.dropout(hk.next_rng_key(), self.dropout_rate,
                    x) if self.is_training else x
     x = jax.nn.relu(self.bn2(self.conv2(x), is_training=self.is_training))
     x = hk.dropout(hk.next_rng_key(), self.dropout_rate,
                    x) if self.is_training else x
     x = jax.nn.relu(self.bn3(self.conv3(x), is_training=self.is_training))
     x = hk.dropout(hk.next_rng_key(), self.dropout_rate,
                    x) if self.is_training else x
     B, L, D = x.shape
     mask = jnp.arange(0, L)[None, :] >= (lengths[:, None] - 1)
     h0c0_fwd = self.lstm_fwd.initial_state(B)
     new_hx_fwd, new_hxcx_fwd = hk.dynamic_unroll(self.lstm_fwd,
                                                  x,
                                                  h0c0_fwd,
                                                  time_major=False)
     x_bwd, mask_bwd = jax.tree_map(lambda x: jnp.flip(x, axis=1),
                                    (x, mask))
     h0c0_bwd = self.lstm_bwd.initial_state(B)
     new_hx_bwd, new_hxcx_bwd = hk.dynamic_unroll(self.lstm_bwd,
                                                  (x_bwd, mask_bwd),
                                                  h0c0_bwd,
                                                  time_major=False)
     x = jnp.concatenate((new_hx_fwd, jnp.flip(new_hx_bwd, axis=1)),
                         axis=-1)
     return x
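hk.dynamic_unroll drives a recurrent core over a full sequence and returns the stacked outputs together with the final state; with time_major=False the inputs are laid out [B, T, F]. A reduced sketch of the forward-direction half of the bidirectional pattern above, with hypothetical sizes and no dropout:

import haiku as hk
import jax
import jax.numpy as jnp

def lstm_forward(x):
    lstm = hk.LSTM(32)
    state = lstm.initial_state(x.shape[0])    # batch size B
    out, _ = hk.dynamic_unroll(lstm, x, state, time_major=False)
    return out                                # [B, T, 32]

f = hk.transform(lstm_forward)
x = jnp.ones([4, 10, 8])
params = f.init(jax.random.PRNGKey(0), x)
out = f.apply(params, None, x)  # no hk.next_rng_key() in this sketch, so rng may be None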
Example #3
    def __call__(self, h: jnp.ndarray, mask: Optional[jnp.ndarray],
                 is_training: bool) -> jnp.ndarray:
        """Connects the transformer.

    Args:
      h: Inputs, [B, T, H].
      mask: Padding mask, [B, T].
      is_training: Whether we're training or not.

    Returns:
      Array of shape [B, T, H].
    """

        init_scale = 2. / np.sqrt(self._num_layers)
        dropout_rate = self._dropout_rate if is_training else 0.
        if mask is not None:
            mask = mask[:, None, None, :]

        h = layer_norm(h)
        for _ in range(self._num_layers):
            h_attn = CausalSelfAttention(self._num_heads, init_scale)(h, mask)
            h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
            h = layer_norm(h + h_attn)
            h_dense = DenseBlock(init_scale)(h)
            h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
            h = layer_norm(h + h_dense)

        return h
Example #4
 def forward(self, x, is_training):
     # Block 1
     x = jax.nn.relu(self.conv1_1(x))
     x = self.bn1_1(x, is_training)
     x = jax.nn.relu(self.conv1_2(x))
     x = self.bn1_2(x, is_training)
     x = hk.max_pool(x, 2, 2, "SAME")
     if is_training:
         x = hk.dropout(hk.next_rng_key(), 0.2, x)
     # Block 2
     x = jax.nn.relu(self.conv2_1(x))
     x = self.bn2_1(x, is_training)
     x = jax.nn.relu(self.conv2_2(x))
     x = self.bn2_2(x, is_training)
     x = hk.max_pool(x, 2, 2, "SAME")
     if is_training:
         x = hk.dropout(hk.next_rng_key(), 0.3, x)
     # Block 3
     x = jax.nn.relu(self.conv3_1(x))
     x = self.bn3_1(x, is_training)
     x = jax.nn.relu(self.conv3_2(x))
     x = self.bn3_2(x, is_training)
     x = hk.max_pool(x, 2, 2, "SAME")
     if is_training:
         x = hk.dropout(hk.next_rng_key(), 0.4, x)
     # Linear part
     x = hk.Flatten()(x)
     x = jax.nn.relu(self.lin1(x))
     x = self.bn4(x, is_training)
     if is_training:
         x = hk.dropout(hk.next_rng_key(), 0.5, x)
     x = self.lin2(x)
     return x  # logits
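Because this forward pass mixes stateful batch normalization with dropout keys, it has to be wrapped with hk.transform_with_state rather than plain hk.transform. A rough, self-contained sketch of that wiring; the conv and linear sizes are made up, not the ones above:

import haiku as hk
import jax
import jax.numpy as jnp

def net_fn(x, is_training):
    # Stand-in for one block above: conv -> batch norm -> dropout -> classifier.
    x = jax.nn.relu(hk.Conv2D(16, 3)(x))
    x = hk.BatchNorm(True, True, 0.9)(x, is_training)
    if is_training:
        x = hk.dropout(hk.next_rng_key(), 0.2, x)
    return hk.Linear(10)(hk.Flatten()(x))

net = hk.transform_with_state(net_fn)
rng = jax.random.PRNGKey(0)
images = jnp.ones([8, 32, 32, 3])
params, state = net.init(rng, images, True)
logits, state = net.apply(params, state, rng, images, True)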
Example #5
    def net_fn(inputs):
        """Function representing a linear layer with learned noise distribution."""
        num_inputs = inputs.shape[-1]
        mu_initializer = _dqn_default_initializer(num_inputs)
        mu_layer = hk.Linear(num_outputs,
                             name='mu',
                             with_bias=with_bias,
                             w_init=mu_initializer,
                             b_init=mu_initializer)
        sigma_initializer = hk.initializers.Constant(  #
            weight_init_stddev / jnp.sqrt(num_inputs))
        sigma_layer = hk.Linear(num_outputs,
                                name='sigma',
                                with_bias=True,
                                w_init=sigma_initializer,
                                b_init=sigma_initializer)

        # Broadcast noise over batch dimension.
        input_noise_sqrt = make_noise_sqrt(hk.next_rng_key(), [1, num_inputs])
        output_noise_sqrt = make_noise_sqrt(hk.next_rng_key(),
                                            [1, num_outputs])

        # Factorized Gaussian noise.
        mu = mu_layer(inputs)
        noisy_inputs = input_noise_sqrt * inputs
        sigma = sigma_layer(noisy_inputs) * output_noise_sqrt
        return mu + sigma
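The make_noise_sqrt helper is not shown in this example. In NoisyNet-style layers the factorized noise is typically f(eps) = sign(eps) * sqrt(|eps|) applied to standard Gaussian samples; a sketch under that assumption (the real helper may differ):

import jax
import jax.numpy as jnp

def make_noise_sqrt(rng, shape):
    # Assumed factorized-Gaussian transform f(x) = sign(x) * sqrt(|x|).
    noise = jax.random.normal(rng, shape)
    return jnp.sign(noise) * jnp.sqrt(jnp.abs(noise))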
Example #6
def func_boxspace(S, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    mu = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8),
        jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8),
        jnp.tanh,
        hk.Linear(onp.prod(boxspace.shape)),
        hk.Reshape(boxspace.shape),
    ))
    logvar = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8),
        jax.nn.relu,
        partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
        partial(batch_norm, is_training=is_training),
        hk.Linear(8),
        jnp.tanh,
        hk.Linear(onp.prod(boxspace.shape)),
        hk.Reshape(boxspace.shape),
    ))
    return {'mu': mu(S), 'logvar': logvar(S)}
Example #7
    def net_fn(inputs):
        """Function representing multi-head DQN Q-network."""
        network = hk.Sequential([
            dqn_torso(),
            dqn_value_head(num_heads * num_actions),
        ])
        network_output = network(inputs)
        multi_head_output = jnp.reshape(network_output,
                                        (-1, num_heads, num_actions))
        mask = jax.random.choice(key=hk.next_rng_key(),
                                 a=2,
                                 shape=(
                                     multi_head_output.shape[0],
                                     num_heads,
                                 ),
                                 p=binomial_probabilities)
        random_head_indices = jax.random.choice(
            key=hk.next_rng_key(),
            a=num_heads,
            shape=(multi_head_output.shape[0], ))
        random_head_q_value = jnp.reshape(
            multi_head_output[:, random_head_indices], (-1, num_actions))

        # TODO: make the q values (used for eval) the output of voting or weighted mean.
        # Currently random head q value used as placeholder
        return MultiHeadQNetworkOutputs(
            q_values=jnp.mean(multi_head_output, axis=1),
            multi_head_output=multi_head_output,
            random_head_q_value=random_head_q_value)
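The bootstrap mask above is drawn with jax.random.choice over {0, 1}, with binomial_probabilities giving the per-value probabilities. A standalone sketch of just that sampling step, assuming binomial_probabilities = [1 - p, p]:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
batch_size, num_heads, p = 32, 10, 0.5
binomial_probabilities = jnp.array([1.0 - p, p])
# Each entry is 1 with probability p, independently per (example, head).
mask = jax.random.choice(key, a=2, shape=(batch_size, num_heads),
                         p=binomial_probabilities)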
Example #8
    def __call__(self, h: jnp.ndarray, mask: Optional[jnp.ndarray],
                 is_training: bool) -> jnp.ndarray:
        """Connects the transformer.

    Args:
      h: Inputs, [B, T, H].
      mask: Padding mask, [B, T].
      is_training: Whether we're training or not.

    Returns:
      Array of shape [B, T, H].
    """

        init_scale = 2. / np.sqrt(self._num_layers)
        dropout_rate = self._dropout_rate if is_training else 0.
        if mask is not None:
            mask = mask[:, None, None, :]

        # Note: names chosen to approximately match those used in the GPT-2 code;
        # see https://github.com/openai/gpt-2/blob/master/src/model.py.
        for i in range(self._num_layers):
            h_norm = layer_norm(h, name=f'h{i}_ln_1')
            h_attn = CausalSelfAttention(self._num_heads,
                                         init_scale,
                                         name=f'h{i}_attn')(h_norm, mask)
            h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
            h = h + h_attn
            h_norm = layer_norm(h, name=f'h{i}_ln_2')
            h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
            h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
            h = h + h_dense
        h = layer_norm(h, name='ln_f')

        return h
Example #9
    def get_latents(self, encodings, probs_b, training):
        """Read out latents (z) form input encodings for a single segment."""
        readout_mask = probs_b[:, 1:, None]  # Offset readout by 1 to left.
        readout = (encodings[:, :-1] * readout_mask).sum(1)
        hidden = nn.relu(self.head_z_1(readout))
        logits_z = self.head_z_2(hidden)

        # Gaussian latents.
        if self.latent_dist == 'gaussian':
            if training:
                mu, log_var = jnp.split(logits_z, 2, axis=1)
                sample_z = utils.gaussian_sample(hk.next_rng_key(), mu,
                                                 log_var)
            else:
                sample_z = logits_z[:, :self.latent_dim]

        # Concrete / Gumbel softmax latents.
        elif self.latent_dist == 'concrete':
            if training:
                sample_z = utils.gumbel_softmax_sample(hk.next_rng_key(),
                                                       logits_z,
                                                       temp=self.temp_z)
            else:
                sample_z_idx = jnp.argmax(logits_z, axis=1)
                sample_z = utils.to_one_hot(sample_z_idx, logits_z.shape[1])
        else:
            raise ValueError('Invalid argument for `latent_dist`.')

        return logits_z, sample_z
Example #10
 def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
     hiddens = x.shape[-1]
     x = conv1d(x, num_units=self._dense_dim, init_scale=self._init_scale)
     x = jax.nn.relu(x)
     x = hk.dropout(hk.next_rng_key(), self._dropout_prob, x)
     x = conv1d(x, num_units=hiddens, init_scale=self._init_scale)
     return hk.dropout(hk.next_rng_key(), self._dropout_prob, x)
Example #11
    def __call__(self, x: jnp.ndarray) -> VAEOutput:
        x = x.astype(jnp.float32)
        mean, stddev = Encoder(self._hidden_size, self._latent_size)(x)
        z = mean + stddev * jax.random.normal(hk.next_rng_key(), mean.shape)
        logits = Decoder(self._hidden_size, self._output_shape)(z)

        p = jax.nn.sigmoid(logits)
        image = jax.random.bernoulli(hk.next_rng_key(), p)

        return VAEOutput(image, mean, stddev, logits)
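Each stochastic step above draws its own key: one for the reparameterized latent and one for the Bernoulli image sample. Under hk.transform, successive hk.next_rng_key() calls return independent splits of the key passed to apply, as in this stripped-down sketch of the same two-key pattern:

import haiku as hk
import jax
import jax.numpy as jnp

def sample_twice(mean, stddev):
    # One fresh key per stochastic operation.
    z = mean + stddev * jax.random.normal(hk.next_rng_key(), mean.shape)
    image = jax.random.bernoulli(hk.next_rng_key(), jax.nn.sigmoid(z))
    return z, image

f = hk.transform(sample_twice)
mean, stddev = jnp.zeros([4, 8]), jnp.ones([4, 8])
params = f.init(jax.random.PRNGKey(0), mean, stddev)
z, image = f.apply(params, jax.random.PRNGKey(1), mean, stddev)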
Example #12
        def forward_fn(x: jnp.ndarray) -> jnp.ndarray:
            linear_1 = linear_with_dropout(3, 0.5)
            transformed_linear = hk.transform(linear_1)

            inner_params = hk.experimental.lift(transformed_linear.init)(
                hk.next_rng_key(), x, True)

            def fun(_params, _rng, h):
                return transformed_linear.apply(_params, _rng, h, True)

            z = deq(inner_params, hk.next_rng_key(), x, fun, max_iter)
            return hk.Linear(output_size, name='l2', with_bias=False)(z)
Example #13
        def forward_fn(x: jnp.ndarray) -> jnp.ndarray:
            linear_1 = hk.Linear(output_size,
                                 name='l1',
                                 w_init=hk.initializers.Constant(1),
                                 b_init=hk.initializers.Constant(1))
            transformed_linear = hk.transform(linear_1)
            inner_params = hk.experimental.lift(transformed_linear.init)(
                hk.next_rng_key(), x)

            z = deq(inner_params, hk.next_rng_key(), x,
                    transformed_linear.apply, max_iter)
            return z
Example #14
    def _elbo_fun(input_data):
        if _ENCODER.value is EncoderArch.color_mnist_mlp_encoder:
            encoder = encoders.ColorMnistMLPEncoder(_LATENT_DIM.value)

        if _DECODER.value is DecoderArch.color_mnist_mlp_decoder:
            decoder = decoders.ColorMnistMLPDecoder(_OBS_VAR.value)

        vae_obj = vae.VAE(encoder, decoder, _RHO.value)

        if _MODEL.value is Model.vae:
            return vae_obj.vae_elbo(input_data, hk.next_rng_key())
        else:
            return vae_obj.avae_elbo(input_data, hk.next_rng_key())
Example #15
  def init_fn(rng: Optional[PRNGKey],
              inputs: Mapping[str, jnp.ndarray],
              batch_axes=(),
              return_initial_output=False,
              **kwargs
  ) -> Tuple[Params, State]:
    """ Initializes your function collecting parameters and state. """
    rng = to_prng_sequence(rng, err_msg=INIT_RNG_ERROR)
    with new_custom_context(rng=rng) as ctx:
      # Create the model
      model = create_fun()

      # Load the batch axes for the inputs
      Layer.batch_axes = batch_axes

      key = hk.next_rng_key()

      # Initialize the model
      outputs = model(inputs, key, **kwargs)

      # Unset the batch axes
      Layer.batch_axes = ()

    nonlocal constants
    params, state, constants = ctx.collect_params(), ctx.collect_initial_state(), ctx.collect_constants()

    if return_initial_output:
      return params, state, outputs

    return params, state
Example #16
    def __call__(self, node_feats: jnp.ndarray, adj: jnp.ndarray,
                 is_training: bool) -> jnp.ndarray:
        """Predict logits or values

        Parameters
        ----------
        node_feats : ndarray of shape (batch_size, N, in_feats)
            Batch input node features.
            N is the total number of nodes in the batch of graphs.
        adj : ndarray of shape (batch_size, N, N)
            Batch adjacency matrix.
        is_training : bool
            Whether the model is training or not.

        Returns
        -------
        out : ndarray of shape (batch_size, n_out)
            Predicator output.
        """
        predicator_dropout = self.predicator_dropout if is_training is True else 0.0
        node_feats = self.gcn(node_feats, adj, is_training)
        # pooling
        graph_feat = self.pooling(node_feats)
        if predicator_dropout != 0.0:
            graph_feat = hk.dropout(hk.next_rng_key(), predicator_dropout,
                                    graph_feat)
        graph_feat = self.fc(graph_feat)
        graph_feat = self.activation(graph_feat)
        out = self.out(graph_feat)
        return out
Example #17
    def __call__(self, q: jnp.ndarray, k: jnp.ndarray) -> jnp.ndarray:
        """Computes the relative position embedding.

    Args:
      q: The query.
      k: The key.

    Returns:
      Relative position embedding.
    """
        # Use key instead of query to obtain the length.
        batch_size, key_length, num_heads, head_dim = list(k.shape)
        # Content based addressing and global content bias
        content_score = jnp.einsum('bthd,bThd->bhtT', q + self._r_w_bias, k)

        # Relative position encoding
        positional_encodings = self._sinusoidal_pos_emb(key_length, batch_size)
        positional_encodings = hk.dropout(hk.next_rng_key(),
                                          self._dropout_rate,
                                          positional_encodings)
        rel_pos_emb = hk.Conv1D(
            output_channels=self._dim,
            kernel_shape=1,
            with_bias=False,
            w_init=init.RandomNormal(
                stddev=self._init_scale))(positional_encodings)
        rel_pos_emb = jnp.reshape(
            rel_pos_emb, [batch_size, key_length, num_heads, head_dim])

        # Content dependent positional bias and global positional bias
        rel_pos_score = jnp.einsum('bthd,bThd->bhtT', q + self._r_r_bias,
                                   rel_pos_emb)
        rel_pos_score = relative_shift(rel_pos_score)
        assert content_score.shape == rel_pos_score.shape
        return content_score + rel_pos_score
Example #18
    def generate_initial(self, context, length):
        # slice last token off the context (we use that in generate_once to generate the first new token)
        last = context[-1:]
        context = context[:-1]

        input_len = context.shape[0]

        if self.rpe is not None:
            attn_bias = self.rpe(input_len, input_len, self.heads_per_shard,
                                 32)
        else:
            attn_bias = 0

        x = self.embed(context)

        states = []

        for l in self.transformer_layers:
            res, layer_state = l.get_init_decode_state(x, length - 1,
                                                       attn_bias)
            x = x + res
            states.append(layer_state)

        return self.proj(x), (last.astype(jnp.uint32), states,
                              hk.next_rng_key())
Example #19
    def get_boundaries(self, encodings, segment_id, lengths, training):
        """Get boundaries (b) for a single segment in batch."""
        if segment_id == self.max_num_segments - 1:
            # Last boundary is always placed on last sequence element.
            logits_b = None
            # sample_b = jnp.zeros_like(encodings[:, :, 0]).scatter_(
            #     1, jnp.expand_dims(lengths, -1) - 1, 1)
            sample_b = jnp.zeros_like(encodings[:, :, 0])
            # jax.ops.index_update has been removed from JAX; the .at[...] API
            # is the current equivalent.
            sample_b = sample_b.at[jnp.arange(len(lengths)), lengths - 1].set(1)
        else:
            hidden = nn.relu(self.head_b_1(encodings))
            logits_b = jnp.squeeze(self.head_b_2(hidden), -1)
            # Mask out first position with large neg. value.
            neg_inf = jnp.ones((encodings.shape[0], 1)) * utils.NEG_INF
            # TODO(tkipf): Mask out padded positions with large neg. value.
            logits_b = jnp.concatenate([neg_inf, logits_b[:, 1:]], axis=1)
            if training:
                sample_b = utils.gumbel_softmax_sample(hk.next_rng_key(),
                                                       logits_b,
                                                       temp=self.temp_b)
            else:
                sample_b_idx = jnp.argmax(logits_b, axis=1)
                sample_b = nn.one_hot(sample_b_idx, logits_b.shape[1])

        return logits_b, sample_b
Example #20
 def __call__(self, inputs, is_training):
     dropout_rate = self._dropout_rate if is_training else 0.0
     h = hk.dropout(hk.next_rng_key(), dropout_rate, inputs)
     h = hk.Linear(self._vocab_size, with_bias=False)(h)
     return hk.BatchNorm(create_scale=False,
                         create_offset=False,
                         decay_rate=0.9)(h, is_training)
Example #21
    def __call__(self, node_feats: jnp.ndarray, adj: jnp.ndarray,
                 graph_idx: jnp.array, is_training: bool) -> jnp.ndarray:
        """Predict logits or values

        Parameters
        ----------
        node_feats : ndarray of shape (N, in_feats)
            Batch input node features.
            N is the total number of nodes in the batch
        adj : ndarray of shape (2, E)
            Batch adjacency list.
            E is the total number of edges in the batch
        graph_idx : ndarray of shape (N,)
            This index indicates which graph each node in node_feats belongs to.
            When two nodes share the same graph idx, they belong to the same graph.
        is_training : bool
            Whether the model is training or not.

        Returns
        -------
        out : ndarray of shape (batch_size, n_out)
            Predicator output.
        """
        predicator_dropout = self.predicator_dropout if is_training is True else 0.0
        node_feats = self.gcn(node_feats, adj, is_training)
        # pooling
        graph_feat = self.pooling(node_feats, graph_idx)
        if predicator_dropout != 0.0:
            graph_feat = hk.dropout(hk.next_rng_key(), predicator_dropout,
                                    graph_feat)
        graph_feat = self.fc(graph_feat)
        graph_feat = self.activation(graph_feat)
        out = self.out(graph_feat)
        return out
Example #22
        def __call__(self, x: jnp.ndarray, is_training: bool) -> jnp.ndarray:

            dropout_rate = self._dropout_rate if is_training else 0.
            h = hk.Linear(self.output_size,
                          w_init=hk.initializers.Constant(1),
                          b_init=hk.initializers.Constant(1))(x)
            return hk.dropout(hk.next_rng_key(), dropout_rate, h)
Example #23
 def wrapped(*args):
     out = fn(*args)
     if is_training:
         mask = hk.dropout(hk.next_rng_key(), dropout_rate,
                           jnp.ones([out.shape[0], 1]))
         out = out * mask
     return out
Example #24
 def __call__(self, x, is_training=True, return_metrics=False):
     """Return the output of the final layer without any [log-]softmax."""
     # Stem
     outputs = {}
     out = self.initial_conv(x)
     out = hk.max_pool(out,
                       window_shape=(1, 3, 3, 1),
                       strides=(1, 2, 2, 1),
                       padding='SAME')
     if return_metrics:
         outputs.update(base.signal_metrics(out, 0))
     # Blocks
     for i, block in enumerate(self.blocks):
         out, res_avg_var = block(out, is_training=is_training)
         if return_metrics:
             outputs.update(base.signal_metrics(out, i + 1))
             outputs[f'res_avg_var_{i}'] = res_avg_var
     # Final-conv->activation, pool, dropout, classify
     pool = jnp.mean(self.activation(out), [1, 2])
     outputs['pool'] = pool
     # Optionally apply dropout
     if self.drop_rate > 0.0 and is_training:
         pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
     outputs['logits'] = self.fc(pool)
     return outputs
Example #25
    def layer(
        self,
        x: jnp.ndarray,
        latents: jnp.ndarray,
        output_channels: int,
        upsample: bool = False,
    ) -> jnp.ndarray:
        if upsample:
            conv = UpsampleConv2D(
                output_channels=output_channels,
                kernel_shape=3,
                upsample_factor=2,
                resample_kernel=self.resample_kernel,
            )
        else:
            conv = ModulatedConv2D(output_channels=output_channels,
                                   kernel_shape=3,
                                   padding="SAME")
        y = conv(x, latents)

        if self.data_format == ChannelOrder.channels_first:
            noise_shape = (y.shape[0], 1, y.shape[2], y.shape[3])
        else:
            noise_shape = (y.shape[0], y.shape[1], y.shape[2], 1)

        key = hk.next_rng_key()
        noise = jax.random.normal(key, shape=noise_shape, dtype=y.dtype)
        noise_strength = hk.get_parameter("noise_strength", (1, 1, 1, 1),
                                          dtype=y.dtype,
                                          init=jnp.zeros)
        y += noise_strength * noise
        return self.activation_function(y)
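The learned noise strength is a single scalar registered with hk.get_parameter and broadcast against per-pixel Gaussian noise drawn from hk.next_rng_key(). A reduced sketch of just that piece, assuming NHWC layout:

import haiku as hk
import jax
import jax.numpy as jnp

def add_layer_noise(y):
    # One noise value per spatial location, shared across channels.
    noise = jax.random.normal(
        hk.next_rng_key(), (y.shape[0], y.shape[1], y.shape[2], 1), y.dtype)
    strength = hk.get_parameter("noise_strength", (1, 1, 1, 1),
                                dtype=y.dtype, init=jnp.zeros)
    return y + strength * noise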
Example #26
    def __call__(self, node_feats: jnp.ndarray, adj: jnp.ndarray,
                 is_training: bool) -> jnp.ndarray:
        """Update node features.

        Parameters
        ----------
        node_feats : ndarray of shape (batch_size, N, in_feats)
            Batch input node features.
            N is the total number of nodes in the batch of graphs.
        adj : ndarray of shape (batch_size, N, N)
            Batch adjacency matrix.
        is_training : bool
            Whether the model is training or not.

        Returns
        -------
        new_node_feats : ndarray of shape (batch_size, N, out_feats)
            Batch new node features.
        """
        dropout = self.dropout if is_training is True else 0.0

        # for batch data
        new_node_feats = jax.vmap(self._update_nodes)(node_feats, adj)
        if self.bias:
            new_node_feats += self.b
        new_node_feats = self.activation(new_node_feats)

        if dropout != 0.0:
            new_node_feats = hk.dropout(hk.next_rng_key(), dropout,
                                        new_node_feats)
        if self.batch_norm:
            new_node_feats = hk.BatchNorm(True, True, 0.9)(new_node_feats,
                                                           is_training)

        return new_node_feats
Example #27
 def fn(x):
     if dropout:
         x = hk.dropout(hk.next_rng_key(), 0.5, x)
     if batchnorm:
         x = hk.BatchNorm(create_offset=True, create_scale=True, decay_rate=0.001)(
             x, is_training=True
         )
     return x
Example #28
 def maybe_dropedge(x):
     """Dropout on edge messages."""
     if not is_training:
         return x
     return x * hk.dropout(
         hk.next_rng_key(),
         dropedge_rate,
         jnp.ones([x.shape[0], 1]),
     )
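Examples #23 and #28 use the same trick: applying hk.dropout to an all-ones [N, 1] tensor produces a per-row keep/drop mask (survivors rescaled by 1 / (1 - rate)), so multiplying by it drops entire rows, here whole edge messages, rather than individual elements. A small sketch of the effect:

import haiku as hk
import jax
import jax.numpy as jnp

def drop_rows(x, rate):
    # The mask has one entry per row; hk.dropout zeroes roughly `rate` of them
    # and scales the survivors by 1 / (1 - rate).
    mask = hk.dropout(hk.next_rng_key(), rate, jnp.ones([x.shape[0], 1]))
    return x * mask

f = hk.transform(lambda x: drop_rows(x, 0.5))
x = jnp.ones([6, 4])
params = f.init(jax.random.PRNGKey(0), x)
print(f.apply(params, jax.random.PRNGKey(1), x))  # each row is all 0.0 or all 2.0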
Example #29
def func_discrete_type2(S, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    seq = hk.Sequential(
        (hk.Flatten(), hk.Linear(8), jax.nn.relu,
         partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
         partial(batch_norm, is_training=is_training), hk.Linear(8), jnp.tanh,
         hk.Linear(discrete.n * discrete.n),
         hk.Reshape((discrete.n, discrete.n)), jax.nn.softmax))
    return seq(S)
Example #30
def func_discrete_type1(S, A, is_training):
    batch_norm = hk.BatchNorm(False, False, 0.99)
    seq = hk.Sequential(
        (hk.Flatten(), hk.Linear(8), jax.nn.relu,
         partial(hk.dropout, hk.next_rng_key(), 0.25 if is_training else 0.),
         partial(batch_norm, is_training=is_training), hk.Linear(8), jnp.tanh,
         hk.Linear(discrete.n), jax.nn.softmax))
    X = jax.vmap(jnp.kron)(S, A)
    return seq(X)
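A subtlety in Examples #6, #29 and #30: partial(hk.dropout, hk.next_rng_key(), ...) evaluates hk.next_rng_key() eagerly, while the hk.Sequential is being built, so the key is fixed at construction time rather than drawn when the dropout layer actually runs. For a single forward pass the result is still a fresh key either way; a lazier variant (a sketch, not what those repos do) defers the draw into the callable:

import haiku as hk
import jax

def dropout_layer(rate):
    # Draw the key only when the layer is applied.
    return lambda x: hk.dropout(hk.next_rng_key(), rate, x)

def func(S, is_training):
    seq = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8),
        jax.nn.relu,
        dropout_layer(0.25 if is_training else 0.),
        hk.Linear(8),
    ))
    return seq(S)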