Example #1
0
    def update_state(self, inputs):
        """Write `inputs` into the cache held in `self.state`; return a window.

        `self.state` is a pair `(cache, idx)`. `inputs` is written into
        `cache` along axis 1 at offset `(idx + self._shift)` modulo
        `2 * self._total_kv_pooling`. When `self._sliding` is set, it is also
        written at offset `(idx + 2 * self._total_kv_pooling - 1)` modulo the
        same period. A slice of length `self._total_kv_pooling` along axis 1
        is then read back and returned. The stored index advances by
        `self._n_raw_tokens_generated`.

        `cache` must be 3-D, because three start indices are passed to
        `dynamic_slice`.
        NOTE(review): the modulo arithmetic suggests `cache` has length
        `2 * self._total_kv_pooling` on axis 1. Confirm this where the state
        is initialised.
        """
        cache, pos = self.state
        pool = self._total_kv_pooling
        period = 2 * pool

        cache = fastmath.dynamic_update_slice_in_dim(
            cache, inputs, (pos + self._shift) % period, axis=1)

        if self._sliding:
            # Also place `inputs` one slot before `pos`, wrapping modulo period.
            cache = fastmath.dynamic_update_slice_in_dim(
                cache, inputs, (pos + period - 1) % period, axis=1)
            start = pos % pool
        else:
            # Begin at the pool-aligned block containing `pos`, modulo period.
            start = (pos - pos % pool) % period

        window = fastmath.dynamic_slice(
            cache, [0, start, 0], [cache.shape[0], pool, cache.shape[2]])

        self.state = cache, pos + self._n_raw_tokens_generated
        return window
Example #2
0
 def _UpdateRow(x):
     """Return the part of `row_ed` that follows the encoder entries.

     `x` is a triple `(row_ed, row_e, row_d)`; the third element is unused.
     The number of non-zero entries in `row_e` is used as the starting row in
     `row_ed`. From there, a slice of size `(L2, H)` is taken. `L2` and `H`
     are free names resolved from the enclosing scope.
     """
     # Shapes as annotated by the author: (L, H), (L1, H) & (L2, H).
     # NOTE(review): `len_e` counts every non-zero entry of `row_e`. That
     # equals the encoder length only if `row_e` is a 1-D row of token ids
     # rather than an (L1, H) array. Confirm against the caller.
     row_ed, row_e, _ = x
     len_e = jnp.sum(row_e != 0, dtype=jnp.int32)
     # The column start shares len_e's dtype to avoid an int32/int64 mismatch.
     col_start = jnp.zeros_like(len_e)
     return fastmath.dynamic_slice(row_ed, (len_e, col_start), (L2, H))
Example #3
0
    def forward(self, inputs):
        """Slice out the freshly generated positions while decoding.

        Outside ``'predict'`` mode, `inputs` is returned unchanged.

        In ``'predict'`` mode, the method returns a window of
        `self._n_raw_tokens_generated` positions along axis 1, starting at
        offset `self.state`. `self.state` then advances by that count modulo
        `self._total_kv_pooling`.
        """
        if self._mode == 'predict':
            offset = self.state
            count = self._n_raw_tokens_generated
            window = fastmath.dynamic_slice(
                inputs, [0, offset, 0],
                [inputs.shape[0], count, inputs.shape[2]])
            self.state = (offset + count) % self._total_kv_pooling
            return window
        return inputs
Example #4
0
    def forward(self, x):
        """Executes this layer as part of a forward pass through the model.

        A gating projection (`m1`) scores `self._num_experts` expert blocks
        for every row of the flattened input. One block per row is chosen by
        argmax over the log-softmax scores plus Gumbel noise scaled by
        `self._temperature`.

        In 'train' mode, the hard one-hot choice is used with a
        straight-through gradient. When a random coin comes up non-positive,
        the soft probabilities are used instead.

        In 'predict' mode with exactly one flattened row, only the selected
        slices of `w1`/`w2` are multiplied. Otherwise the full `w1` product
        is masked per hidden unit.

        Args:
          x: Tensor of same shape and dtype as the input signature used to
            initialize this layer.

        Returns:
          Tensor of same shape and dtype as the input.
        """
        m1, w1, w2, b2 = self.weights
        orig_shape = x.shape
        # Fold every leading axis into one row axis (easier to operate on).
        flat = jnp.reshape(x, [-1, orig_shape[-1]])

        # Q: check if we need bias and/or put relu after the m1 dot?
        gate_logits = jnp.dot(flat, m1)
        # Log-softmax over experts, and the matching probabilities.
        log_probs = gate_logits - fastmath.logsumexp(
            gate_logits, axis=-1, keepdims=True)
        probs = jnp.exp(log_probs)

        # Gumbel-softmax with straight-through discretization.
        # TODO(lukaszkaiser, chowdhery): Extract this block and share
        noise_rng, coin_rng = fastmath.random.split(self.rng, 2)
        uniform = fastmath.random.uniform(
            noise_rng, probs.shape, jnp.float32, 1e-6, 1.0 - 1e-6)
        gumbel = -jnp.log(-jnp.log(uniform))
        chosen = jnp.argmax(log_probs + gumbel * self._temperature, axis=-1)
        hard_mask = tl.one_hot(chosen, self._num_experts)

        if self._mode == 'train':
            # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797:
            # the forward value is the hard mask, gradients flow via `probs`.
            straight_through = (fastmath.stop_gradient(hard_mask)
                                + (probs - fastmath.stop_gradient(probs)))
            # On about half of the calls, fall back to the soft mask for
            # training stability (see the paper above).
            # Q: is selecting 50% of batches the best? Other %? Mixed in-batch?
            coin = fastmath.random.uniform(
                coin_rng, (), jnp.float32, -1.0, 1.0)
            gate = jnp.where(coin > 0.0, straight_through, probs)
        else:
            gate = hard_mask
        gate = jnp.reshape(gate, [-1, self._num_experts, 1])
        rows = gate.shape[0]

        if self._mode == 'predict' and rows == 1:
            # This implementation mimicks inference for batch_size 1:
            # only the selected expert's weight slices are touched.
            start = chosen[0] * self._n_elements_in_block
            # w1 is [d_model, d_ff]; w1_block is [d_model, n_elements_in_block]
            w1_block = fastmath.dynamic_slice(
                w1, [0, start], [w1.shape[0], self._n_elements_in_block])
            hidden = jnp.dot(flat, w1_block)
            activated = jnp.where(hidden <= 0, jnp.zeros_like(hidden), hidden)
            # w2 is [d_ff, d_model]; w2_block is [n_elements_in_block, d_model]
            w2_block = fastmath.dynamic_slice(
                w2, [start, 0], [self._n_elements_in_block, w2.shape[-1]])
            w2_block = jnp.reshape(w2_block, [self._n_elements_in_block, -1])
            out = jnp.dot(activated, w2_block) + b2
        else:
            # Repeat each expert's gate value across its hidden units, then
            # mask the dense first-layer product with it.
            per_unit = jnp.broadcast_to(
                gate, (rows, gate.shape[1], self._n_elements_in_block))
            unit_mask = jnp.reshape(per_unit, (-1, self._d_ff))
            hidden = jnp.dot(flat, w1) * unit_mask  # [joint_batch, d_ff]
            activated = jnp.where(hidden <= 0, jnp.zeros_like(hidden), hidden)
            out = jnp.dot(activated, w2) + b2

        return jnp.reshape(out, orig_shape)  # un-flatten if needed