Example #1
    def __call__(self, hidden_states, mask_time_indices=None, deterministic=True, temperature=1):
        batch_size, sequence_length, hidden_size = hidden_states.shape

        # project to codevector dim
        hidden_states = self.weight_proj(hidden_states)
        hidden_states = hidden_states.reshape(batch_size * sequence_length * self.num_groups, -1)

        if not deterministic:
            # sample codevector probs via Gumbel noise in a differentiable way
            gumbel_rng = self.make_rng("gumbel")
            gumbels = jax.random.gumbel(gumbel_rng, hidden_states.shape)
            codevector_probs = nn.softmax((hidden_states + gumbels) / temperature)

            # compute perplexity
            codevector_soft_dist = nn.softmax(
                hidden_states.reshape(batch_size * sequence_length, self.num_groups, -1), axis=-1
            )
            perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
        else:
            # take argmax in non-differentiable way
            # compute hard codevector distribution (one-hot)
            codevector_idx = hidden_states.argmax(axis=-1)
            codevector_probs = jax.nn.one_hot(codevector_idx, hidden_states.shape[-1]) * 1.0
            codevector_probs = codevector_probs.reshape(batch_size * sequence_length, self.num_groups, -1)
            perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)

        codevector_probs = codevector_probs.reshape(batch_size * sequence_length, -1)
        # use probs to retrieve codevectors
        codevectors_per_group = jnp.expand_dims(codevector_probs, axis=-1) * self.codevectors
        codevectors = codevectors_per_group.reshape(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
        codevectors = codevectors.sum(-2).reshape(batch_size, sequence_length, -1)

        return codevectors, perplexity
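A minimal, self-contained sketch of the two branches above (Gumbel-softmax sampling during training versus a hard argmax one-hot at inference), assuming only jax and flax.linen; the logits shape and temperature are illustrative.

import jax
import jax.numpy as jnp
import flax.linen as nn

logits = jax.random.normal(jax.random.PRNGKey(0), (4, 8))  # (positions, codebook size), illustrative
temperature = 2.0

# training branch: differentiable soft assignment via Gumbel noise
gumbels = jax.random.gumbel(jax.random.PRNGKey(1), logits.shape)
soft_probs = nn.softmax((logits + gumbels) / temperature, axis=-1)

# inference branch: hard one-hot assignment
hard_probs = jax.nn.one_hot(logits.argmax(axis=-1), logits.shape[-1])

print(soft_probs.sum(axis=-1))  # each row sums to 1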
Example #2
    def _predict_color(self, input_q, input_k, randomized):  # pylint: disable=arguments-differ
        """Function to predict the color by aggregating information from neighbouring views.

    Args:
      input_q: query, with shape (bs, 1, q_feature_dim)
      input_k: key, with shape (bs, near_cam, k_feature_dim')
      randomized: True during training.

    Returns:
      rgb: color prediction (bs, 3)
      neighbor_attn_weights: attention weights
    """
        input_q = self.query_transform2(input_q[:, None])

        input_k = self.key_transform2(input_k)
        input_q = jnp.concatenate([input_q, input_k], axis=-2)

        out = self.view_transformer(
            input_q,
            deterministic=not randomized,
        )

        refined_query = out[:, 0:1]
        refined_key = out[:, 1:]
        refined_query = jnp.tile(refined_query, (1, refined_key.shape[-2], 1))

        concat_key_query = jnp.concatenate([refined_query, refined_key],
                                           axis=-1)
        neighbor_attn_weights = self.view_correspondence(concat_key_query)
        neighbor_attn_weights = nn.softmax(neighbor_attn_weights, axis=-2)

        raw_rgb = self.rgb_dense((refined_key * neighbor_attn_weights).sum(-2))
        rgb = self.render_config.rgb_activation(raw_rgb)

        return rgb, neighbor_attn_weights
Example #3
 def __call__(self, x):
     initializer = nn.initializers.variance_scaling(scale=1.0 /
                                                    jnp.sqrt(3.0),
                                                    mode='fan_in',
                                                    distribution='uniform')
     x = x.astype(jnp.float32) / 255.
     x = nn.Conv(features=32,
                 kernel_size=(8, 8),
                 strides=(4, 4),
                 kernel_init=initializer)(x)
     x = nn.relu(x)
     x = nn.Conv(features=64,
                 kernel_size=(4, 4),
                 strides=(2, 2),
                 kernel_init=initializer)(x)
     x = nn.relu(x)
     x = nn.Conv(features=64,
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 kernel_init=initializer)(x)
     x = nn.relu(x)
     x = x.reshape((-1))  # flatten
     x = nn.Dense(features=512, kernel_init=initializer)(x)
     x = nn.relu(x)
     x = nn.Dense(features=self.num_actions * self.num_atoms,
                  kernel_init=initializer)(x)
     logits = x.reshape((self.num_actions, self.num_atoms))
     probabilities = nn.softmax(logits)
     q_values = jnp.mean(logits, axis=1)
     return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
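The snippets in this listing are bodies of Flax nn.Module.__call__ methods and are driven through init/apply. A hedged sketch of that pattern with a hypothetical minimal module (TinyClassifier and its sizes are not taken from any example here):

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyClassifier(nn.Module):  # hypothetical module, not from the examples
    num_classes: int = 3

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=16)(x)
        x = nn.relu(x)
        logits = nn.Dense(features=self.num_classes)(x)
        return nn.softmax(logits)

model = TinyClassifier()
x = jnp.ones((4, 8))                           # (batch, features), illustrative
params = model.init(jax.random.PRNGKey(0), x)  # build the parameters
probs = model.apply(params, x)                 # run the forward pass
print(probs.shape)                             # (4, 3); each row sums to 1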
Example #4
  def _get_avg_features(self, input_q, input_k, randomized):
    """Function that aggregate feature over the projection on the epipolar line.

    Args:
      input_q: query, with shape (bs, 1, q_feature_dim)
      input_k: key, with shape (bs, near_cam, projections, k_feature_dim)
      randomized: True during training

    Returns:
      out: Average features (bs, near_cam, _)
      epipolar_attn_weights: attention weights
    """
    # Change shape of query from (BS, Q) -> (BS, NearCam, 1, Q)
    input_q = jnp.tile(input_q[:, None, None], (1, input_k.shape[1], 1, 1))
    input_q = self.query_transform(input_q)
    input_k = self.key_transform(input_k)

    # Concatenate the query to the keys
    input_k = jnp.concatenate([input_q, input_k], axis=-2)
    out = self.epipolar_transformer(
        input_k,
        deterministic=not randomized,
    )
    refined_query = out[Ellipsis, 0:1, :]  # Get refined query
    refined_key = out[Ellipsis, 1:, :]

    refined_query = jnp.tile(refined_query, (1, 1, refined_key.shape[-2], 1))
    # Predict attention weights for averaging the keys
    concat_query_key = jnp.concatenate([refined_query, refined_key], axis=-1)
    epipolar_attn_weights = self.epipolar_correspondence(concat_query_key)
    epipolar_attn_weights = nn.softmax(epipolar_attn_weights, axis=-2)
    out = (epipolar_attn_weights * refined_key).sum(-2)

    return out, epipolar_attn_weights
Example #5
    def __call__(self, keys: Array, mask: Array) -> Array:
        """Applies model  to the input keys and mask.

    Args:
      keys: The inputs for which to compute an attention score. Shape:
        <float32>[batch_size, seq_length, embeddings_size].
      mask: A mask that determines which values in `keys` are valid. Only
        values for which the mask is True will get non-zero attention scores.
        <bool>[batch_size, seq_length].

    Returns:
      The normalized attention scores. <float32>[batch_size, seq_length].
    """
        hidden = nn.Dense(self.hidden_size, name='keys', use_bias=False)(keys)
        energy = nn.tanh(hidden)
        scores = nn.Dense(1, name='energy', use_bias=False)(energy)
        scores = scores.squeeze(
            -1)  # New shape: <float32>[batch_size, seq_len].
        scores = jnp.where(mask, scores,
                           -jnp.inf)  # Using exp(-inf) = 0 below.
        scores = nn.softmax(scores, axis=-1)

        # Captures the scores if 'intermediates' is mutable, otherwise does nothing.
        self.sow('intermediates', 'attention', scores)

        return scores
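A short sketch of the masking step used above, assuming only jax.numpy and flax.linen: positions set to -inf before the softmax receive exactly zero attention weight.

import jax.numpy as jnp
import flax.linen as nn

scores = jnp.array([2.0, 0.5, -1.0, 3.0])
mask = jnp.array([True, True, False, True])  # third position is padding

masked = jnp.where(mask, scores, -jnp.inf)   # exp(-inf) = 0 under the softmax
weights = nn.softmax(masked, axis=-1)
print(weights)        # weight at the masked position is 0.0
print(weights.sum())  # 1.0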
Example #6
 def quantize(self, latent_indices, soft_quantize=False):
     """Returns embedding tensor for a batch of indices."""
     embeddings = self.vq_emb.value
     w = embeddings.swapaxes(1, 0)
     if soft_quantize:
         # Given logits over latent states instead.
         return jnp.dot(nn.softmax(latent_indices), w)
     else:
         return w[latent_indices]
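A standalone sketch of the two lookup modes in quantize, with a free-standing embedding table in place of self.vq_emb (shapes are illustrative): hard indexing picks codebook rows, while the soft path mixes rows with softmax weights and matches the hard result when the logits are strongly peaked.

import jax.numpy as jnp
import flax.linen as nn

embeddings = jnp.arange(12.0).reshape(3, 4)  # (embedding_dim=3, num_codes=4), illustrative
w = embeddings.swapaxes(1, 0)                # (num_codes, embedding_dim)

# hard lookup from integer indices
indices = jnp.array([2, 0])
hard = w[indices]                            # (2, embedding_dim)

# soft lookup from (very peaked) logits over the codes
logits = jnp.array([[-1e3, -1e3, 1e3, -1e3],
                    [1e3, -1e3, -1e3, -1e3]])
soft = jnp.dot(nn.softmax(logits), w)        # (2, embedding_dim)

print(jnp.allclose(hard, soft))              # True for near-one-hot logits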
Example #7
    def __call__(self, inputs, is_training: bool):
        assert len(self.strides) == 3
        assert inputs.ndim == 3
        q_strides, k_strides, v_strides = self.strides
        b, l, c = inputs.shape
        out_ch = self.out_ch if self.out_ch is not None else c
        spatial_ch = int(jnp.ceil(jnp.sqrt(l)))
        inputs = jnp.pad(inputs, ((0, 0), (0, spatial_ch**2 - l), (0, 0)))
        inputs = rearrange(inputs, 'b (H W) c -> b H W c', W=spatial_ch)

        conv_proj = partial(ConvProjectionBlock,
                            out_ch=self.num_heads * self.head_ch,
                            kernel_size=self.kernel_size,
                            use_bias=self.use_bias,
                            bn_momentum=self.bn_momentum,
                            bn_epsilon=self.bn_epsilon,
                            dtype=self.dtype,
                            precision=self.precision,
                            kernel_init=self.kernel_init,
                            bias_init=self.bias_init)

        query = conv_proj(strides=q_strides)(inputs, is_training=is_training)
        key = conv_proj(strides=k_strides)(inputs, is_training=is_training)
        value = conv_proj(strides=v_strides)(inputs, is_training=is_training)

        query = rearrange(query, 'b H W (h d) -> b (H W) h d', h=self.num_heads)
        key = rearrange(key, 'b H W (h d) -> b (H W) h d', h=self.num_heads)
        value = rearrange(value, 'b H W (h d) -> b (H W) h d', h=self.num_heads)

        query = query / jnp.sqrt(self.head_ch)

        attn_weights = jnp.einsum('... q h d, ... k h d -> ... h q k',
                                  query,
                                  key,
                                  precision=self.precision)

        attn_weights = nn.softmax(attn_weights)

        attn_scores = jnp.einsum('... h q k, ... k h d -> ... q h d',
                                 attn_weights,
                                 value,
                                 precision=self.precision)

        if (self.num_heads * self.head_ch) == self.out_ch:
            output = rearrange(attn_scores, '... q h d -> ... q (h d)')
        else:
            output = nn.DenseGeneral(features=self.out_ch,
                                     axis=(-2, -1),
                                     use_bias=self.use_bias,
                                     dtype=self.dtype,
                                     precision=self.precision,
                                     kernel_init=self.kernel_init,
                                     bias_init=self.bias_init)(attn_scores)
        return output
Example #8
    def __call__(
        self,
        query,
        key,
        value,
        mask,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):

        bs, q_len, dim = query.shape
        k_len = key.shape[1]
        # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'
        # assert key.size() == value.size()

        dim_per_head = self.dim // self.n_heads

        mask_reshp = (bs, 1, 1, k_len)

        def shape(x):
            """separate heads"""
            return x.reshape(bs, -1, self.n_heads,
                             dim_per_head).transpose(0, 2, 1, 3)

        def unshape(x):
            """group heads"""
            return x.transpose(0, 2, 1, 3).reshape(bs, -1,
                                                   self.n_heads * dim_per_head)

        q = shape(self.q_lin(query))  # (bs, n_heads, q_len, dim_per_head)
        k = shape(self.k_lin(key))  # (bs, n_heads, k_len, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_len, dim_per_head)

        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_len, dim_per_head)
        scores = jnp.matmul(q, k.transpose(0, 1, 3,
                                           2))  # (bs, n_heads, q_len, k_len)
        mask = jnp.reshape(mask, mask_reshp)

        mask = mask.astype(scores.dtype)
        scores = scores - 1e30 * (1.0 - mask)

        weights = nn.softmax(scores, axis=-1)  # (bs, n_heads, q_len, k_len)
        weights = self.dropout(weights, deterministic=deterministic)

        context = jnp.matmul(weights, v)  # (bs, n_heads, q_len, dim_per_head)
        context = unshape(context)  # (bs, q_len, dim)
        context = self.out_lin(context)  # (bs, q_len, dim)

        if output_attentions:
            return (context, weights)
        else:
            return (context, )
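The masking here is additive rather than -inf based. A small sketch, assuming only jax.numpy and flax.linen, showing that subtracting a large constant at masked positions drives their softmax weight to numerical zero, the same effect as the jnp.where(..., -jnp.inf) pattern in Example #5.

import jax.numpy as jnp
import flax.linen as nn

scores = jnp.array([2.0, 0.5, -1.0, 3.0])
mask = jnp.array([1.0, 1.0, 0.0, 1.0])  # 0 marks padding

additive = scores - 1e30 * (1.0 - mask)
w_additive = nn.softmax(additive, axis=-1)

w_where = nn.softmax(jnp.where(mask > 0, scores, -jnp.inf), axis=-1)
print(jnp.allclose(w_additive, w_where))  # True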
Example #9
 def __call__(self, x, support):
   initializer = nn.initializers.xavier_uniform()
   x = x.astype(jnp.float32)
   x = nn.Conv(features=16, kernel_size=(3, 3), strides=(1, 1),
               kernel_init=initializer)(x)
   x = nn.relu(x)
   x = x.reshape(-1)  # flatten
   x = nn.Dense(features=self.num_actions * self.num_atoms,
                kernel_init=initializer)(x)
   logits = x.reshape((self.num_actions, self.num_atoms))
   probabilities = nn.softmax(logits)
   q_values = jnp.sum(support * probabilities, axis=1)
   return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
Example #10
 def __call__(self, x):
     dtype = jnp.float32
     x = x.reshape((x.shape[0], -1))
     x = nn.Dense(features=2 * common.BOARD_SIZE**2,
                  name='hidden1',
                  dtype=dtype)(x)
     x = nn.relu(x)
     x = nn.Dense(features=common.BOARD_SIZE**2,
                  name='hidden2',
                  dtype=dtype)(x)
     x = nn.relu(x)
     x = nn.Dense(features=common.BOARD_SIZE**2, name='logits',
                  dtype=dtype)(x)
     policy_probabilities = nn.softmax(x)
     return policy_probabilities
Example #11
 def __call__(self, x, support):
     x = x.astype(jnp.float32)
     x = x.reshape((-1))  # flatten
     if self.min_vals is not None:
         x -= self._min_vals
         x /= self._max_vals - self._min_vals
         x = 2.0 * x - 1.0  # Rescale in range [-1, 1].
     for layer in self.layers:
         x = layer(x)
         x = nn.relu(x)
     x = self.final_layer(x)
     logits = x.reshape((self.num_actions, self.num_atoms))
     probabilities = nn.softmax(logits)
     q_values = jnp.sum(support * probabilities, axis=1)
     return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
Example #12
    def __call__(self, inputs_q, inputs_kv, is_training: bool):
        assert inputs_q.ndim == inputs_kv.ndim == 3

        in_ch = inputs_q.shape[-1]
        assert in_ch % self.num_heads == 0
        head_ch = self.head_ch or int(in_ch / self.num_heads)
        out_ch = self.out_ch or in_ch

        dense = partial(nn.DenseGeneral,
                        axis=-1,
                        features=(self.num_heads, head_ch),
                        use_bias=self.use_bias,
                        dtype=self.dtype)

        query = dense(name='queries')(inputs_q)
        key = dense(name='keys')(inputs_kv)
        value = dense(name='values')(inputs_kv)

        query = query / jnp.sqrt(head_ch)

        attn_weights = jnp.einsum('... q h d, ... k h d -> ... h q k', query,
                                  key)

        if self.talking_heads:
            attn_weights = TalkingHeadsBlock(
                num_heads=self.num_heads)(attn_weights)

        attn_weights = nn.softmax(attn_weights)

        if self.talking_heads:
            attn_weights = TalkingHeadsBlock(
                num_heads=self.num_heads)(attn_weights)

        attn_weights = nn.Dropout(rate=self.attn_dropout_rate)(
            attn_weights, deterministic=not is_training)

        attn_scores = jnp.einsum('... h q k, ... k h d -> ... q h d',
                                 attn_weights, value)

        output = nn.DenseGeneral(features=out_ch,
                                 axis=(-2, -1),
                                 use_bias=self.use_bias,
                                 dtype=self.dtype)(attn_scores)

        output = nn.Dropout(rate=self.out_dropout_rate)(
            output, deterministic=not is_training)
        return output
Example #13
      def __call__(self, x):
        def custom_init(key, shape, dtype=jnp.float32):
          del key
          to_pick_first_action = onp.ones(shape, dtype)
          to_pick_first_action[:, :self.num_atoms] = onp.arange(
              1, self.num_atoms + 1)
          return to_pick_first_action

        x = x.astype(jnp.float32)
        x = x.reshape((-1))  # flatten
        x = linen.Dense(features=self.num_actions * self.num_atoms,
                        kernel_init=custom_init,
                        bias_init=linen.initializers.ones)(x)
        logits = x.reshape((self.num_actions, self.num_atoms))
        probabilities = linen.softmax(logits)
        qs = jnp.mean(logits, axis=1)
        return atari_lib.RainbowNetworkType(qs, logits, probabilities)
Example #14
 def __call__(self, x, support):
     initializer = nn.initializers.xavier_uniform()
     x = x.astype(jnp.float32)
     x = x.reshape((-1))  # flatten
     x -= gym_lib.ACROBOT_MIN_VALS
     x /= gym_lib.ACROBOT_MAX_VALS - gym_lib.ACROBOT_MIN_VALS
     x = 2.0 * x - 1.0  # Rescale in range [-1, 1].
     x = nn.Dense(features=512, kernel_init=initializer)(x)
     x = nn.relu(x)
     x = nn.Dense(features=512, kernel_init=initializer)(x)
     x = nn.relu(x)
     x = nn.Dense(features=self.num_actions * self.num_atoms,
                  kernel_init=initializer)(x)
     logits = x.reshape((self.num_actions, self.num_atoms))
     probabilities = nn.softmax(logits)
     q_values = jnp.sum(support * probabilities, axis=1)
     return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
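Several of the Rainbow-style networks in this listing end with the same step: a softmax over atoms gives a categorical distribution per action, and the Q-value is its expectation under the support. A free-standing sketch of just that step; the support range and atom count are illustrative, not taken from any example here.

import jax.numpy as jnp
import flax.linen as nn

num_actions, num_atoms = 3, 51
support = jnp.linspace(-10.0, 10.0, num_atoms)       # illustrative value support

logits = jnp.zeros((num_actions, num_atoms))         # e.g. the reshaped Dense output
probabilities = nn.softmax(logits)                   # per-action distribution over atoms
q_values = jnp.sum(support * probabilities, axis=1)  # expected return per action

print(q_values)  # ~0.0 per action for uniform logits on a symmetric support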
Example #15
    def __call__(self, inputs):
        cfg = self.config
        assert inputs.ndim == 3

        dense = partial(nn.DenseGeneral,
                        axis=-1,
                        features=(cfg.num_heads, cfg.dim_head),
                        use_bias=False,
                        kernel_init=cfg.kernel_init,
                        precision=cfg.precision)

        query, key, value = (dense(dtype=cfg.dtype)(inputs),
                             dense(dtype=cfg.dtype)(inputs),
                             dense(dtype=cfg.dtype)(inputs))

        query = query / jnp.sqrt(cfg.dim_head).astype(cfg.dtype)

        attn_weights = jnp.einsum('b q h d, b k h d -> b h q k',
                                  query,
                                  key,
                                  precision=cfg.precision)
        attn_weights = nn.softmax(attn_weights).astype(cfg.dtype)

        if cfg.shared_theta:
            attn_weights = self.theta_transform(attn_weights)
        else:
            attn_weights = ThetaTransform(config=cfg)(attn_weights)

        attn_weights = nn.LayerNorm()(attn_weights)

        out = jnp.einsum('b h q k, b q h d -> b k h d',
                         attn_weights,
                         value,
                         precision=cfg.precision)

        if (cfg.num_heads * cfg.dim_head) != cfg.emb_dim:
            out = nn.DenseGeneral(features=cfg.emb_dim,
                                  axis=(-2, -1),
                                  dtype=cfg.dtype,
                                  precision=cfg.precision,
                                  kernel_init=cfg.kernel_init,
                                  bias_init=cfg.bias_init)(out)
        else:
            out = rearrange(out, 'b k h d -> b k (h d)')

        return out
Example #16
    def __call__(self, x, context=None, mask=None, deterministic=False):
        h = self.heads
        dim = self.head_features * h

        q = nn.Dense(dim, use_bias=False)(x)
        k, v = nn.Dense(dim * 2, use_bias=False)(default(context, x)).split(2, axis=-1)

        q, k, v = map(
            lambda arr: rearrange(arr, "b n (h d) -> (b h) n d", h=h), (q, k, v)
        )
        sim = jnp.einsum("b i d, b j d -> b i j", q, k) * self.head_features ** -0.5
        attn = nn.softmax(sim, axis=-1)
        out = jnp.einsum("b i j, b j d -> b i d", attn, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)

        out = nn.Dense(x.shape[-1])(out)
        out = nn.Dropout(self.dropout)(out, deterministic=deterministic)
        return out
Example #17
    def _predict_color(self, input_q, input_k, learned_embedding, randomized):  # pylint: disable=arguments-differ
        """Function to predict the color by aggreagating information form neighbouring views.

    Args:
      input_q: query, with shape (bs, 1, q_feature_dim)
      input_k: key, with shape (bs, near_cam, k_feature_dim')
      learned_embedding: learned embedding for reference views
      randomized: True during training.

    Returns:
      rgb: color prediction (bs, 3)
      neighbor_attn_weights: attention weights
    """
        input_q = self.query_transform2(input_q[:, None])

        if learned_embedding is not None:
            # Optionally add training view camera embedding
            # learned_embedding has shape (B, N, P, _); the second-to-last dimension
            # was replicated P times so it could be concatenated to the key values.
            # Here we only need one of them to get shape (B, N, _).
            camera_embedding = learned_embedding[Ellipsis, 0, :]
            input_k = jnp.concatenate([input_k, camera_embedding], axis=-1)

        input_k = self.key_transform2(input_k)
        input_q = jnp.concatenate([input_q, input_k], axis=-2)

        out = self.view_transformer(
            input_q,
            deterministic=not randomized,
        )

        refined_query = out[:, 0:1]
        refined_key = out[:, 1:]
        refined_query = jnp.tile(refined_query, (1, refined_key.shape[-2], 1))

        concat_key_query = jnp.concatenate([refined_query, refined_key],
                                           axis=-1)
        neighbor_attn_weights = self.view_correspondence(concat_key_query)
        neighbor_attn_weights = nn.softmax(neighbor_attn_weights, axis=-2)

        raw_rgb = self.rgb_dense((refined_key * neighbor_attn_weights).sum(-2))
        rgb = self.render_config.rgb_activation(raw_rgb)

        return rgb, neighbor_attn_weights
Example #18
    def __call__(self, inputs_q):
        """Applies multi-head self-attention on the input data.
        Arguments:
            inputs_q:   [batch_size, height, width, dim]
        Returns:
            output:     [batch_size, height, width, dim]
        """
        cfg = self.config
        conv = partial(nn.Conv,
                       features=cfg.num_heads * cfg.head_dim,
                       kernel_size=(1, 1),
                       use_bias=False,
                       precision=cfg.precision,
                       kernel_init=cfg.kernel_init)

        query, key, value = (conv(dtype=cfg.dtype, name="query")(inputs_q),
                             conv(dtype=cfg.dtype, name="key")(inputs_q),
                             conv(dtype=cfg.dtype, name="value")(inputs_q))
        query, key, value = (rearrange(query,
                                       'b H W (h d) -> b h H W d',
                                       h=cfg.num_heads),
                             rearrange(key,
                                       'b H W (h d) -> b h H W d',
                                       h=cfg.num_heads),
                             rearrange(value,
                                       'b H W (h d) -> b h H W d',
                                       h=cfg.num_heads))

        query = query / jnp.sqrt(cfg.head_dim).astype(cfg.dtype)

        attn_weights = jnp.einsum('b h H W d, b h P Q d -> b h H W P Q',
                                  query,
                                  key,
                                  precision=cfg.precision)
        attn_weights = attn_weights + RelativeLogits(config=cfg)(query)
        attn_weights = nn.softmax(attn_weights, axis=(-2, -1)).astype(cfg.dtype)
        attn_out = jnp.einsum('b h H W P Q, b h P Q d -> b H W h d',
                              attn_weights,
                              value,
                              precision=cfg.precision)
        attn_out = rearrange(attn_out, 'b H W h d -> b H W (h d)')
        return attn_out
Example #19
    def __call__(self, x, support, eval_mode=False, key=None):
        # Generate a random number generation key if not provided
        if key is None:
            key = jax.random.PRNGKey(int(time.time() * 1e6))

        if not self.inputs_preprocessed:
            x = preprocess_atari_inputs(x)

        hidden_sizes = [32, 64, 64]
        kernel_sizes = [8, 4, 3]
        stride_sizes = [4, 2, 1]
        for hidden_size, kernel_size, stride_size in zip(
                hidden_sizes, kernel_sizes, stride_sizes):
            x = nn.Conv(features=hidden_size,
                        kernel_size=(kernel_size, kernel_size),
                        strides=(stride_size, stride_size),
                        kernel_init=nn.initializers.xavier_uniform())(x)
            x = nn.relu(x)
        x = x.reshape((-1))  # flatten

        net = feature_layer(key, self.noisy, eval_mode=eval_mode)
        x = net(x, features=512)  # Single hidden layer of size 512
        x = nn.relu(x)

        if self.dueling:
            adv = net(x, features=self.num_actions * self.num_atoms)
            value = net(x, features=self.num_atoms)
            adv = adv.reshape((self.num_actions, self.num_atoms))
            value = value.reshape((1, self.num_atoms))
            logits = value + (adv - (jnp.mean(adv, axis=0, keepdims=True)))
        else:
            x = net(x, features=self.num_actions * self.num_atoms)
            logits = x.reshape((self.num_actions, self.num_atoms))

        if self.distributional:
            probabilities = nn.softmax(logits)
            q_values = jnp.sum(support * probabilities, axis=1)
            return atari_lib.RainbowNetworkType(q_values, logits,
                                                probabilities)
        q_values = jnp.sum(logits, axis=1)  # Sum over all the num_atoms
        return atari_lib.DQNNetworkType(q_values)
Example #20
    def __call__(self, x):
        x = nn.Conv(features=16, kernel_size=(3, 3), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

        x = nn.Conv(features=32, kernel_size=(3, 3), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

        x = nn.Conv(features=64, kernel_size=(3, 3), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)

        x = nn.Dense(features=NB_CLASSES)(x)
        x = nn.softmax(x)

        return x
Example #21
      def __call__(self, x, support, eval_mode=False, key=None):

        def custom_init(key, shape, dtype=jnp.float32):
          del key
          to_pick_first_action = onp.ones(shape, dtype)
          to_pick_first_action[:, :self.num_atoms] = onp.arange(
              1, self.num_atoms + 1)
          return to_pick_first_action

        x = x.astype(jnp.float32)
        x = x.reshape((-1))  # flatten
        x = nn.Dense(
            features=self.num_actions * self.num_atoms,
            kernel_init=custom_init,
            bias_init=nn.initializers.ones)(
                x)
        logits = x.reshape((self.num_actions, self.num_atoms))
        if not self.distributional:
          qs = jnp.sum(logits, axis=-1)  # Sum over all the num_atoms
          return atari_lib.DQNNetworkType(qs)
        probabilities = nn.softmax(logits)
        qs = jnp.sum(support * probabilities, axis=1)
        return atari_lib.RainbowNetworkType(qs, logits, probabilities)
Example #22
    def __call__(self, x, pos_emb, mask):
        dim_in, h = x.shape[-1], self.heads
        scale = dim_in**-0.5

        norm = nn.LayerNorm()
        to_qkv = nn.Dense(features=self.dim_head * h * 3, use_bias=False)
        to_out = nn.Dense(features=dim_in)

        x = norm(x)
        qkv = np.split(to_qkv(x), 3, axis=-1)
        q, k, v = map(lambda t: rearrange(t, "i (h d) -> i h d", h=h), qkv)

        q = index_update(q, index[1:], apply_rotary_pos_emb(q[1:], pos_emb))
        k = index_update(k, index[1:], apply_rotary_pos_emb(k[1:], pos_emb))

        sim = einsum("i h d, j h d -> i j h", q, k) * scale

        mask = np.pad(mask, (1, 0), constant_values=True)
        mask = rearrange(mask, "j -> () j ()")

        if self.causal:
            i, j = sim.shape[:2]
            tri_mask = np.ones((i - 1, j - 1), dtype=bool)
            tri_mask = np.pad(tri_mask, ((1, 0), (1, 0)),
                              constant_values=False)
            causal_mask = np.triu(tri_mask, j - i + 1)
            causal_mask = rearrange(causal_mask, "i j -> i j ()")
            mask = ~causal_mask * mask

        sim = np.where(mask, sim, LARGE_NEG_VALUE)

        attn = nn.softmax(sim, axis=-2)

        out = einsum("i j h, j h d -> i h d", attn, v)

        out = rearrange(out, "i h d -> i (h d)")
        return to_out(out)
Example #23
def stable_softmax(x, tau, axis=-1):
    max_x = jnp.amax(x, axis=axis, keepdims=True)
    y = x - max_x
    return nn.softmax(y / tau, axis=axis)
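Softmax is invariant to subtracting a constant along the normalization axis, so the max-subtraction in stable_softmax changes only the numerics, not the result. A quick hedged check with illustrative values:

import jax.numpy as jnp
import flax.linen as nn

x = jnp.array([[1.0, 2.0, 3.0],
               [0.5, 0.5, 4.0]])
tau = 0.5

shifted = x - jnp.amax(x, axis=-1, keepdims=True)  # what stable_softmax does internally
a = nn.softmax(shifted / tau, axis=-1)
b = nn.softmax(x / tau, axis=-1)                   # plain temperature softmax
print(jnp.allclose(a, b))                          # True: softmax is shift-invariant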
Example #24
 def __call__(self, word):
     word_features = self.vocab_layer(word)
     word_features_act = nn.sigmoid(word_features)
     embed_features = self.embed_layer(word_features_act)
     embed_act = nn.softmax(embed_features)
     return embed_act
Example #25
    def forward(self, outputs, targets):
        """
        Performs the matching.
        Params:
            outputs: This is a dict that contains at least these entries:
                 "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                 objects in the target) containing the class labels "boxes": Tensor of dim [num_target_boxes, 4]
                 containing the target box coordinates
        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        bs, num_queries = outputs["logits"].shape[:2]  # B, 100
        # We flatten to compute the cost matrices in a batch
        out_prob = nn.softmax(
            jnp.reshape(outputs["logits"], (bs * num_queries, -1)),
            -1)  # [batch_size * num_queries, num_classes]
        out_bbox = jnp.reshape(
            outputs["pred_boxes"],
            (bs * num_queries, -1))  # [batch_size * num_queries, 4]
        # Also concat the target labels and boxes
        tgt_ids = jnp.concatenate([v["class_labels"]
                                   for v in targets])  # N_tgts
        tgt_bbox = jnp.concatenate([v["boxes"] for v in targets
                                    ])  # N_tgts (B * per batch tgts)
        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it as 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, so it can be omitted.
        class_cost = -out_prob[:,
                               tgt_ids]  # Out prob is B*Q, num_classes. This gets the -proba[target_class] for each of those heads
        # Compute the L1 cost between boxes
        n_outputs, n_targets = out_bbox.shape[0], tgt_bbox.shape[0]
        # bbox_cost = jnp.linalg.norm(jnp.tile(out_bbox, (n_targets, 1))- jnp.repeat(tgt_bbox, n_outputs, 0), 1, axis=-1) # L1 dist between this BBox and all the tgt bboxs B*NQ, N_tgts
        bbox_cost = jnp.linalg.norm(
            jnp.repeat(out_bbox, n_targets, 0) -
            jnp.tile(tgt_bbox, (n_outputs, 1)),
            1,
            axis=-1
        )  # L1 dist between this BBox and all the tgt bboxs B*NQ, N_tgts
        bbox_cost = bbox_cost.reshape(n_outputs, n_targets)
        giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox),
                                         center_to_corners_format(tgt_bbox))
        # Final cost matrix
        cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
        cost_matrix = jnp.reshape(cost_matrix, (bs, num_queries, -1))
        sizes = [len(v["boxes"]) for v in targets]

        # To replicate torch split, we need the actual chunk indices not the lengths
        chunks = [0]
        for s in sizes[:-1]:
            chunks.append(s + chunks[-1])

        indices = [
            linear_sum_assignment(c[i])
            for i, c in enumerate(cost_matrix.split(chunks[1:], -1))
        ]  # this splits the cost matrix up, i.e if B 1 has 9 labels and 2 has 6, then solves the linear sum assignment

        # returns (i,j) where i is the head idx (out of 100) and j is the target idx out of N_tgts
        return [(jnp.array(i, dtype=jnp.int32), jnp.array(j, dtype=jnp.int32))
                for i, j in indices]
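The core of this matcher is scipy's Hungarian solver: given a (num_queries, num_targets) cost matrix, linear_sum_assignment returns paired row and column indices minimizing the total cost. A minimal sketch on a toy cost matrix (values illustrative):

import numpy as np
from scipy.optimize import linear_sum_assignment

# cost[i, j]: cost of assigning prediction i to ground-truth box j
cost = np.array([[0.9, 0.1, 0.8],
                 [0.2, 0.7, 0.6],
                 [0.5, 0.4, 0.05]])

row_ind, col_ind = linear_sum_assignment(cost)
print(row_ind, col_ind)              # e.g. [0 1 2] [1 0 2]
print(cost[row_ind, col_ind].sum())  # minimal total assignment cost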
Example #26
    def __call__(self, inputs_q, inputs_kv, is_training: bool):
        assert len(self.strides) == 3
        q_strides, k_strides, v_strides = self.strides

        in_ch = inputs_q.shape[-1]
        assert in_ch % self.num_heads == 0
        head_ch = self.head_ch or int(in_ch / self.num_heads)
        out_ch = self.out_ch or in_ch

        inputs_q = zero_pad_and_reshape(inputs_q)
        inputs_kv = zero_pad_and_reshape(inputs_kv)

        conv_proj = partial(ConvProjectionBlock,
                            out_ch=self.num_heads * head_ch,
                            kernel_size=self.kernel_size,
                            use_bias=self.use_bias,
                            bn_momentum=self.bn_momentum,
                            bn_epsilon=self.bn_epsilon,
                            dtype=self.dtype,
                            precision=self.precision,
                            kernel_init=self.kernel_init,
                            bias_init=self.bias_init)

        query = conv_proj(strides=q_strides)(inputs_q, is_training=is_training)
        key = conv_proj(strides=k_strides)(inputs_kv, is_training=is_training)
        value = conv_proj(strides=v_strides)(inputs_kv,
                                             is_training=is_training)

        query = rearrange(query,
                          'b H W (h d) -> b (H W) h d',
                          h=self.num_heads)
        key = rearrange(key, 'b H W (h d) -> b (H W) h d', h=self.num_heads)
        value = rearrange(value,
                          'b H W (h d) -> b (H W) h d',
                          h=self.num_heads)

        query = query / jnp.sqrt(head_ch)

        attn_weights = jnp.einsum('... q h d, ... k h d -> ... h q k',
                                  query,
                                  key,
                                  precision=self.precision)

        if self.talking_heads:
            pre_softmax_transform = self.param(
                'pre_softmax', self.kernel_init,
                (self.num_heads, self.num_heads))
            attn_weights = jnp.einsum('... h q k, h i -> ... i q k',
                                      attn_weights,
                                      pre_softmax_transform,
                                      precision=self.precision)

        attn_weights = nn.softmax(attn_weights)

        if self.talking_heads:
            post_softmax_transform = self.param(
                'post_softmax', self.kernel_init,
                (self.num_heads, self.num_heads))
            attn_weights = jnp.einsum('... i q k, i h -> ... h q k',
                                      attn_weights,
                                      post_softmax_transform,
                                      precision=self.precision)

        attn_weights = nn.Dropout(rate=self.attn_dropout_rate)(
            attn_weights, deterministic=not is_training)

        attn_scores = jnp.einsum('... h q k, ... k h d -> ... q h d',
                                 attn_weights,
                                 value,
                                 precision=self.precision)

        output = nn.DenseGeneral(features=out_ch,
                                 axis=(-2, -1),
                                 use_bias=self.use_bias,
                                 dtype=self.dtype,
                                 precision=self.precision,
                                 kernel_init=self.kernel_init,
                                 bias_init=self.bias_init)(attn_scores)

        output = nn.Dropout(rate=self.out_drop_rate)(
            output, deterministic=not is_training)

        return output
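The talking-heads step above mixes attention maps across heads with a learned head-to-head matrix, both before and after the softmax. A minimal sketch of that mixing, with a random matrix standing in for the learned pre_softmax parameter:

import jax
import jax.numpy as jnp

num_heads = 4
attn = jax.random.uniform(jax.random.PRNGKey(0), (2, num_heads, 5, 5))  # (batch, heads, q, k), illustrative
mix = jax.random.normal(jax.random.PRNGKey(1), (num_heads, num_heads))  # stands in for the learned matrix

mixed = jnp.einsum('b h q k, h i -> b i q k', attn, mix)  # mix information across heads
print(mixed.shape)  # (2, 4, 5, 5)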
Example #27
    def __call__(self, inputs_q, inputs_kv, is_training: bool):
        assert inputs_q.ndim == 4
        assert inputs_kv.ndim == 4
        assert len(self.strides) == 3
        q_strides, k_strides, v_strides = self.strides

        in_ch = inputs_q.shape[-1]
        assert in_ch % self.num_heads == 0
        head_ch = self.head_ch or int(in_ch / self.num_heads)
        out_ch = self.out_ch or in_ch

        conv_proj = partial(ConvProjectionBlock,
                            out_ch=self.num_heads * head_ch,
                            kernel_size=self.kernel_size,
                            use_bias=self.use_bias,
                            bn_momentum=self.bn_momentum,
                            bn_epsilon=self.bn_epsilon,
                            dtype=self.dtype)

        query = conv_proj(strides=q_strides)(inputs_q, is_training=is_training)
        key = conv_proj(strides=k_strides)(inputs_kv, is_training=is_training)
        value = conv_proj(strides=v_strides)(inputs_kv,
                                             is_training=is_training)

        query = rearrange(query,
                          'b H W (h d) -> b (H W) h d',
                          h=self.num_heads)
        key = rearrange(key, 'b H W (h d) -> b (H W) h d', h=self.num_heads)
        value = rearrange(value,
                          'b H W (h d) -> b (H W) h d',
                          h=self.num_heads)

        query = query / jnp.sqrt(head_ch)

        attn_weights = jnp.einsum('... q h d, ... k h d -> ... h q k', query,
                                  key)

        if self.talking_heads:
            attn_weights = TalkingHeadsBlock(
                num_heads=self.num_heads)(attn_weights)

        attn_weights = nn.softmax(attn_weights)

        if self.talking_heads:
            attn_weights = TalkingHeadsBlock(
                num_heads=self.num_heads)(attn_weights)

        attn_weights = nn.Dropout(rate=self.attn_dropout_rate)(
            attn_weights, deterministic=not is_training)

        attn_scores = jnp.einsum('... h q k, ... k h d -> ... q h d',
                                 attn_weights, value)

        output = nn.DenseGeneral(features=out_ch,
                                 axis=(-2, -1),
                                 use_bias=self.use_bias,
                                 dtype=self.dtype)(attn_scores)

        output = nn.Dropout(rate=self.out_dropout_rate)(
            output, deterministic=not is_training)
        return output
Example #28
    def __call__(
        self,
        encoded_input: Array,
        mention_batch_positions: Array,
        mention_start_positions: Array,
        mention_end_positions: Array,
        mention_mask: Array,
        entity_embeddings: Array,
    ) -> Dict[str, Array]:
        """Perform attention update over entity embedding table.

    Args:
      encoded_input: [batch_size, n_tokens, hidden_size].
      mention_batch_positions: [n_mentions].
      mention_start_positions: [n_mentions].
      mention_end_positions: [n_mentions].
      mention_mask: [n_mentions] attention mask to prevent updates from padding
        mentions.
      entity_embeddings: entity embedding table.

    Returns:
      Updated input, mention encodings and entity attention scores.
    """

        mention_start_encodings = jut.matmul_2d_index_select(
            encoded_input, (mention_batch_positions, mention_start_positions))
        mention_end_encodings = jut.matmul_2d_index_select(
            encoded_input, (mention_batch_positions, mention_end_positions))
        mention_encodings = self.mention_query_projector(
            jnp.concatenate((mention_start_encodings, mention_end_encodings),
                            axis=-1))

        scores = jnp.einsum('qd,ed->qe', mention_encodings, entity_embeddings)
        attention_weights = nn.softmax(scores, axis=-1)

        retrieved_values = jnp.einsum('qe,ed->qd', attention_weights,
                                      entity_embeddings)
        retrieved_values = self.entity_projector(retrieved_values)
        retrieved_values = retrieved_values * jnp.expand_dims(mention_mask, -1)

        encoded_input = jut.matmul_2d_index_add(
            encoded_input, (mention_batch_positions, mention_start_positions),
            retrieved_values)
        encoded_input = self.layer_norm(encoded_input)

        # The cosine similarity is computed as dot product divided by norms of
        # both vectors.
        mention_encodings_norm = jnp.linalg.norm(mention_encodings, axis=-1)
        entity_embeddings_norm = jnp.linalg.norm(entity_embeddings, axis=-1)
        cos_scores = scores
        cos_scores = cos_scores / (_SMALL_NUMBER +
                                   jnp.expand_dims(mention_encodings_norm, 1))
        cos_scores = cos_scores / (_SMALL_NUMBER +
                                   jnp.expand_dims(entity_embeddings_norm, 0))

        return {
            'encoded_output': encoded_input,
            'mention_encodings': mention_encodings,
            'cosine_similarity': cos_scores,
            'attention_weights': attention_weights,
        }
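The cosine-similarity step above is just the raw dot-product scores divided by both vector norms. A tiny self-contained sketch; _SMALL_NUMBER here is a stand-in for the module-level constant assumed by the example:

import jax.numpy as jnp

_SMALL_NUMBER = 1e-8  # stand-in for the module-level constant
mention_encodings = jnp.array([[1.0, 2.0, 3.0]])                   # (n_mentions=1, d)
entity_embeddings = jnp.array([[1.0, 0.0, 0.0], [2.0, 4.0, 6.0]])  # (n_entities=2, d)

scores = jnp.einsum('qd,ed->qe', mention_encodings, entity_embeddings)
cos = scores / (_SMALL_NUMBER + jnp.linalg.norm(mention_encodings, axis=-1)[:, None])
cos = cos / (_SMALL_NUMBER + jnp.linalg.norm(entity_embeddings, axis=-1)[None, :])
print(cos)  # second entry is ~1.0 because the vectors are parallel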
Example #29
    def __call__(self, inputs_q, inputs_kv, is_training: bool):
        assert inputs_q.ndim == inputs_kv.ndim == 3

        in_ch = inputs_q.shape[-1]
        assert in_ch % self.num_heads == 0
        head_ch = self.head_ch or int(in_ch / self.num_heads)
        out_ch = self.out_ch or in_ch

        dense = partial(nn.DenseGeneral,
                        axis=-1,
                        features=(self.num_heads, head_ch),
                        use_bias=self.use_bias,
                        dtype=self.dtype,
                        precision=self.precision,
                        kernel_init=self.kernel_init,
                        bias_init=self.bias_init)

        query = dense(name='queries')(inputs_q)
        key = dense(name='keys')(inputs_kv)
        value = dense(name='values')(inputs_kv)

        query = query / jnp.sqrt(head_ch)

        attn_weights = jnp.einsum('... q h d, ... k h d -> ... h q k',
                                  query,
                                  key,
                                  precision=self.precision)
        if self.talking_heads:
            pre_softmax_transform = self.param('pre_softmax', self.kernel_init,
                                               (self.num_heads, self.num_heads))
            attn_weights = jnp.einsum('... h q k, h i -> ... i q k',
                                      attn_weights,
                                      pre_softmax_transform,
                                      precision=self.precision)

        attn_weights = nn.softmax(attn_weights)

        if self.talking_heads:
            post_softmax_transform = self.param(
                'post_softmax', self.kernel_init,
                (self.num_heads, self.num_heads))
            attn_weights = jnp.einsum('... i q k, i h -> ... h q k',
                                      attn_weights,
                                      post_softmax_transform,
                                      precision=self.precision)

        attn_weights = nn.Dropout(rate=self.attn_dropout_rate)(
            attn_weights, deterministic=not is_training)

        attn_scores = jnp.einsum('... h q k, ... k h d -> ... q h d',
                                 attn_weights,
                                 value,
                                 precision=self.precision)

        output = nn.DenseGeneral(features=out_ch,
                                 axis=(-2, -1),
                                 use_bias=self.use_bias,
                                 dtype=self.dtype,
                                 precision=self.precision,
                                 kernel_init=self.kernel_init,
                                 bias_init=self.bias_init)(attn_scores)

        output = nn.Dropout(rate=self.out_drop_rate)(
            output, deterministic=not is_training)

        return output
Example #30
    def __call__(self, x, rng):

        if self.net_conf == 'minatar':
            x = x.squeeze(3)
            x = x.astype(jnp.float32)
            x = nn.Conv(features=16,
                        kernel_size=(3, 3, 3),
                        strides=(1, 1, 1),
                        kernel_init=self.initzer)(x)
            x = jax.nn.relu(x)
            x = x.reshape((x.shape[0], -1))

        elif self.net_conf == 'atari':
            # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
            # have removed the true batch dimension.
            x = x.astype(jnp.float32) / 255.
            x = nn.Conv(features=32,
                        kernel_size=(8, 8),
                        strides=(4, 4),
                        kernel_init=self.initzer)(x)
            x = jax.nn.relu(x)
            x = nn.Conv(features=64,
                        kernel_size=(4, 4),
                        strides=(2, 2),
                        kernel_init=self.initzer)(x)
            x = jax.nn.relu(x)
            x = nn.Conv(features=64,
                        kernel_size=(3, 3),
                        strides=(1, 1),
                        kernel_init=self.initzer)(x)
            x = jax.nn.relu(x)
            x = x.reshape((-1))  # flatten

        elif self.net_conf == 'classic':
            #classic environments
            x = x.astype(jnp.float32)
            x = x.reshape((-1))

        if self.env is not None and self.env in env_inf:
            x = x - env_inf[self.env]['MIN_VALS']
            x /= env_inf[self.env]['MAX_VALS'] - env_inf[self.env]['MIN_VALS']
            x = 2.0 * x - 1.0

        if self.noisy:

            def net(x, features, rng):
                return NoisyNetwork(features, rng=rng, bias_in=True)(x)
        else:

            def net(x, features, rng):
                return nn.Dense(features, kernel_init=self.initzer)(x)

        for _ in range(self.hidden_layer):
            x = net(x, features=self.neurons, rng=rng)
            x = jax.nn.relu(x)

        if self.dueling:
            adv = net(x, features=self.num_actions * self.num_atoms, rng=rng)
            value = net(x, features=self.num_atoms, rng=rng)
            adv = adv.reshape((self.num_actions, self.num_atoms))
            value = value.reshape((1, self.num_atoms))
            #print('value:', value.shape)
            logits = value + (adv - (jnp.mean(adv, -2, keepdims=True)))
            probabilities = nn.softmax(logits)
            q_values = jnp.mean(logits, axis=1)

        else:
            x = net(x, features=self.num_actions * self.num_atoms, rng=rng)
            logits = x.reshape((self.num_actions, self.num_atoms))
            probabilities = nn.softmax(logits)
            q_values = jnp.mean(logits, axis=1)

        return atari_lib.RainbowNetworkType(q_values, logits, probabilities)