Example #1
    def forward(self, inputs):
        state = self.state
        depth = inputs.shape[-1]

        if self._mode == 'predict':
            emb = self._get_embeddings(t=state)
            emb = emb[:, jnp.newaxis, :]
            state = state + 1
        else:
            input_len = inputs.shape[-2]
            emb = self._get_embeddings(
                t=jnp.arange(input_len, dtype=jnp.int32))
            # Leave batch axis as 1 for broadcasting:
            emb = emb[jnp.newaxis, :, :]
            emb = jnp.broadcast_to(emb, inputs.shape[:-1] + (3,))

        # Replace the last num_features channels of input.
        inputs = jnp.concatenate([inputs[..., :-self.num_features], emb], -1)
        if inputs.shape[-1] > depth:
            logging.warning('dropping feature(s): %d down to %d',
                            inputs.shape[-1], depth)
            inputs = inputs[..., -depth:]

        assert inputs.shape[-1] == depth, inputs.shape
        self.state = state
        return inputs
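
For reference, a minimal standalone sketch of the broadcast-and-concatenate step above, with toy shapes; the hard-coded 3 mirrors the snippet's broadcast target, and all array names are illustrative:

import jax.numpy as jnp

batch, length, depth = 2, 5, 8
x = jnp.zeros((batch, length, depth))
emb = jnp.ones((length, 3))                        # one 3-feature embedding per position
emb = emb[jnp.newaxis, :, :]                       # leave batch axis as 1 for broadcasting
emb = jnp.broadcast_to(emb, x.shape[:-1] + (3,))   # -> (batch, length, 3)
out = jnp.concatenate([x[..., :-3], emb], -1)      # replace the last 3 channels
assert out.shape == (batch, length, depth)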
Example #2
 def f(x):
   if n_devices > 1 and fastmath.is_backend(fastmath.Backend.JAX):
     return _multi_device_put(x)
   elif n_devices > 1:
     return jnp.broadcast_to(x, (n_devices,) + jnp.asarray(x).shape)
   else:
     return x
Example #3
 def f(x):
   if n_devices > 1 and fastmath.is_backend(fastmath.Backend.JAX):
     return jax.device_put_replicated(x, jax.local_devices())
   elif n_devices > 1:
     return jnp.broadcast_to(x, (n_devices,) + jnp.asarray(x).shape)
   else:
     return x
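
Examples #2 and #3 replicate a value once per accelerator. A rough standalone sketch of the same idea in plain JAX (no Trax helpers); note that device_put_replicated places one copy on each local device, while the broadcast_to fallback only adds a leading device axis without explicit placement:

import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()
x = jnp.arange(4.0)

# Per-device copies, sharded across the local devices:
replicated = jax.device_put_replicated(x, jax.local_devices())

# Pure-numpy fallback: same values, leading device axis, no placement:
stacked = jnp.broadcast_to(x, (n_devices,) + x.shape)

assert replicated.shape == stacked.shape == (n_devices,) + x.shape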
Example #4
  def forward(self, inputs):
    rng, state = self.rng, self.state
    embs = []
    for ax_emb in self.weights:
      ax_emb = jnp.broadcast_to(
          ax_emb, (inputs.shape[0],) + self._shape + (ax_emb.shape[-1],))
      embs.append(ax_emb)

    if self._mode == 'predict':
      assert self._dropout == 0.0
      emb = jnp.concatenate(embs, -1)
      emb = jnp.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
      emb = jax.lax.dynamic_slice_in_dim(emb, state, inputs.shape[1], axis=1)
      self.state = state + inputs.shape[1]
      return inputs + emb
    elif self._dropout == 0:
      # TODO(kitaev): concat-then-reshape (as is the case with dropout enabled)
      # leads to memory blow-up on TPU.
      # emb = jnp.concatenate(embs, -1)
      # return inputs + jnp.reshape(emb, inputs.shape), state
      return inputs + jnp.concatenate(
          [jnp.reshape(emb, inputs.shape[:-1] + (emb.shape[-1],))
           for emb in embs
          ], -1)
    else:
      emb = jnp.concatenate(embs, -1)
      noise_shape = list(emb.shape)
      for dim in self._dropout_broadcast_dims:
        noise_shape[dim] = 1
      keep_prob = 1.0 - self._dropout
      keep = fastmath.random.bernoulli(rng, keep_prob, tuple(noise_shape))
      multiplier = keep.astype(inputs.dtype) / keep_prob
      return inputs + jnp.reshape(emb * multiplier, inputs.shape)
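
The final branch of Example #4 applies dropout with a noise shape collapsed along selected axes, so one mask is shared across those dimensions. A toy sketch of that broadcasted-dropout trick, with illustrative shapes and plain jax.random in place of fastmath:

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
emb = jnp.ones((2, 16, 32))           # (batch, length, d_feature)
dropout, broadcast_dims = 0.1, (0,)   # share one mask across the batch axis

noise_shape = list(emb.shape)
for dim in broadcast_dims:
  noise_shape[dim] = 1                # -> (1, 16, 32)
keep_prob = 1.0 - dropout
keep = jax.random.bernoulli(rng, keep_prob, tuple(noise_shape))
emb = emb * keep.astype(emb.dtype) / keep_prob   # dropped positions shared over the batch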
Example #5
 def f(x):
     if n_devices > 1 and fastmath.backend_name() == 'jax':
         return _multi_device_put(x)
     elif n_devices > 1:
         return jnp.broadcast_to(x, (n_devices,) + x.shape)
     else:
         return x
Example #6
 def _params(self, inputs):
   """Extracts the mean and std parameters from the inputs."""
   assert inputs.shape[-1] == self.n_inputs
   n_dims = self._n_dims
   # Split the distribution inputs into two parts: mean and std.
   mean = inputs[..., :n_dims]
   if self._learn_std is not None:
     std = inputs[..., n_dims:]
     # Std is non-negative, so let's softplus it.
     std = tl.Softplus()(std + self._std)
   else:
     std = self._std
   # In case of constant or shared std, upsample it to the same dimensionality
   # as the means.
   std = jnp.broadcast_to(std, mean.shape)
   return (mean, std)
Example #7
 def _params(self, inputs):
     """Extracts the mean and std parameters from the inputs."""
     if inputs.shape[-1] != self.n_inputs:
         raise ValueError(
             'Invalid distribution parametrization - expected {} parameters, '
             'got {}. Input shape: {}.'.format(self.n_inputs,
                                               inputs.shape[-1],
                                               inputs.shape))
     n_dims = self._n_dims
     # Split the distribution inputs into two parts: mean and std.
     mean = inputs[..., :n_dims]
     if self._learn_std is not None:
         std = inputs[..., n_dims:]
         # Std is non-negative, so let's softplus it.
         std = tl.Softplus()(std + self._std)
     else:
         std = self._std
     # In case of constant or shared std, upsample it to the same dimensionality
     # as the means.
     std = jnp.broadcast_to(std, mean.shape)
     return (mean, std)
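
Examples #6 and #7 split the distribution parameters into means and (optionally learned) stds, then broadcast a constant or shared std up to the means' shape. A small self-contained sketch, with jax.nn.softplus standing in for tl.Softplus and an arbitrary 1.0 playing the role of self._std:

import jax.numpy as jnp
from jax.nn import softplus

n_dims = 3
inputs = jnp.zeros((4, 2 * n_dims))      # rows of [mean | raw std] parameters

mean = inputs[..., :n_dims]
learn_std = True
if learn_std:
  # Learned std: softplus keeps it non-negative; 1.0 stands in for self._std.
  std = softplus(inputs[..., n_dims:] + 1.0)
else:
  std = 1.0                              # constant/shared std
std = jnp.broadcast_to(std, mean.shape)  # upsample scalar or shared std to mean's shape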
Example #8
    def forward(self, x):
        """Executes this layer as part of a forward pass through the model.

    Args:
      x: Tensor of same shape and dtype as the input signature used to
        initialize this layer.

    Returns:
      Tensor of same shape and dtype as the input.
    """
        m1, w1, w2, b2 = self.weights
        x_shape = x.shape
        x = jnp.reshape(x,
                        [-1, x_shape[-1]])  # Easier to operate on flattened x.

        # Q: check if we need bias and/or put relu after the m1 dot?
        mask_logits = jnp.dot(x, m1)
        # Softmax.
        mask_logsumexp = fastmath.logsumexp(mask_logits,
                                            axis=-1,
                                            keepdims=True)
        log_mask = mask_logits - mask_logsumexp
        mask = jnp.exp(log_mask)
        # Gumbel-softmax with straight-through discretization.
        # TODO(lukaszkaiser, chowdhery): Extract this block and share
        rng1, rng2 = fastmath.random.split(self.rng, 2)
        u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6,
                                    1.0 - 1e-6)
        g = -jnp.log(-jnp.log(u))
        selected_experts = jnp.argmax(log_mask + g * self._temperature,
                                      axis=-1)
        if self._mode == 'train':
            # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
            quant_mask = tl.one_hot(selected_experts, self._num_experts)
            quant_mask = fastmath.stop_gradient(quant_mask)
            quant_mask += mask - fastmath.stop_gradient(
                mask)  # straight-through
            # We will sometimes (50% of the batches) use the soft-mask instead of
            # the quantized mask to improve training stability (see the paper above).
            # Q: is selecting 50% of batches the best? Other %? Mixed in-batch?
            select = fastmath.random.uniform(rng2, (), jnp.float32, -1.0, 1.0)
            quant_mask = jnp.where(select > 0.0, quant_mask, mask)
        else:
            quant_mask = tl.one_hot(selected_experts, self._num_experts)
        quant_mask = jnp.reshape(quant_mask, [-1, self._num_experts, 1])
        quant_mask_shape = quant_mask.shape
        batch_size = quant_mask.shape[0]

        if self._mode == 'predict' and batch_size == 1:
            # This implementation mimics inference for batch_size 1.
            start_idx = selected_experts[0] * self._n_elements_in_block
            # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block]
            w = fastmath.dynamic_slice(
                w1, [0, start_idx], [w1.shape[0], self._n_elements_in_block])
            mid = jnp.dot(x, w)
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
            # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model]
            v = fastmath.dynamic_slice(
                w2, [start_idx, 0], [self._n_elements_in_block, w2.shape[-1]])
            v = jnp.reshape(v, [self._n_elements_in_block, -1])
            res = jnp.dot(relu, v) + b2
        else:
            expanded_mask = jnp.broadcast_to(
                quant_mask, (quant_mask_shape[0], quant_mask.shape[1],
                             self._n_elements_in_block))
            expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff))
            mid = jnp.dot(x, w1) * expanded_mask  # [joint_batch, d_ff]
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
            res = jnp.dot(relu, w2) + b2

        return jnp.reshape(res, x_shape)  # un-flatten if needed
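
The training branch of Example #8 hinges on a straight-through estimator: the forward pass uses the hard one-hot expert mask while gradients flow through the soft mask. A minimal sketch of just that trick; the helper below is hypothetical, not part of Trax, and mirrors the snippet's use of temperature as a noise scale:

import jax
import jax.numpy as jnp

def straight_through_one_hot(rng, logits, temperature=1.0):
  """Hard one-hot in the forward pass, soft-mask gradients in the backward pass."""
  soft = jax.nn.softmax(logits, axis=-1)
  u = jax.random.uniform(rng, logits.shape, jnp.float32, 1e-6, 1.0 - 1e-6)
  g = -jnp.log(-jnp.log(u))                      # Gumbel noise
  hard = jax.nn.one_hot(
      jnp.argmax(logits + g * temperature, axis=-1), logits.shape[-1])
  # Forward value equals `hard`; the gradient flows through `soft`.
  return jax.lax.stop_gradient(hard) + soft - jax.lax.stop_gradient(soft)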
Example #9
 def representation_mask(mask):
   mask = jnp.amax(mask, axis=tuple(range(2, mask.ndim)))
   return jnp.broadcast_to(
       mask[:, :, None], mask.shape + (serializer.representation_length,)
   )
Example #10
 def significance_weights(mask):
   # (repr,) -> (batch, length, repr)
   significance = serializer.significance_map[None, None]
   return mask * decay ** jnp.broadcast_to(significance, mask.shape)
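
Examples #9 and #10 build masks over serialized representations. A toy illustration with concrete shapes; the serializer attributes here are stand-ins:

import jax.numpy as jnp

representation_length = 4
significance_map = jnp.array([0, 1, 1, 2])   # stand-in for serializer.significance_map
decay = 0.9

obs_mask = jnp.ones((2, 3, 5, 5))            # (batch, length, *observation_shape)
mask = jnp.amax(obs_mask, axis=tuple(range(2, obs_mask.ndim)))   # -> (batch, length)
repr_mask = jnp.broadcast_to(
    mask[:, :, None], mask.shape + (representation_length,))     # -> (batch, length, repr)
weights = repr_mask * decay ** jnp.broadcast_to(
    significance_map[None, None], repr_mask.shape)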