Example #1
0
 def annotate(x):
     """Applies a sharding annotation to `x` with only axis 0 pinned.

     Uses `p` from the enclosing scope. The sharding spec is
     `p.weight_split_dims_mapping.stages` followed by one `None` per
     remaining axis (axes 1..ndim-1). Those remaining axes are also passed
     to `maybe_shard` as unconstrained dims.

     NOTE(review): presumably `stages` is a list describing only the leading
     (stage) axis; `+` with a list requires it to be a list — confirm.
     """
     rank = x.ndim
     free_axes = [axis for axis in range(1, rank)]
     spec = p.weight_split_dims_mapping.stages + [None] * (rank - 1)
     return base_layer.maybe_shard(x, spec, p.mesh_axis_names, free_axes)
Example #2
0
    def get_logits(self, inputs: JTensor) -> JTensor:
        """Returns logits given the inputs, with optional capping.

        Steps:
          1. Scale the inputs by 1/sqrt(input_dims).
          2. Project them onto the transposed embedding table.
          3. Shard the result per the `out` activation mapping.
          4. If `p.soft_cap_logits` is truthy, soft-cap the logits with tanh.
          5. If `p.logits_abs_max` is truthy, hard-clip the logits.

        Args:
          inputs: a single JTensor with shape [..., input_dim].

        Returns:
          logits: with shape [..., num_classes]. Unnormalized softmax logits.
        """
        p = self.params
        ap = p.activation_split_dims_mapping
        # Activations are scaled with 1/sqrt(input_dims). Multiply out of place
        # so a caller-owned mutable array (e.g. a NumPy ndarray) is never
        # modified; for immutable JAX arrays this matches the old `*=`.
        inputs = inputs * (p.input_dims**-0.5)
        # The embedding weight is laid out VH; transpose to HV.
        softmax_var = jnp.transpose(self.embedding.local_theta().w)
        # Compute logits: BLH,HV -> BLV
        logits = linears.project_last_dim(inputs, softmax_var)
        logits = base_layer.maybe_shard(logits, ap.out, p.mesh_axis_names)

        # Soft cap: cap * tanh(logits / cap). For a positive cap this smoothly
        # bounds the logits to (-cap, cap). Skipped when the cap is 0/None.
        if p.soft_cap_logits:
            logits = p.soft_cap_logits * jnp.tanh(logits / p.soft_cap_logits)

        # Hard cap: clip to [-logits_abs_max, logits_abs_max]. Skipped when
        # logits_abs_max is 0/None.
        if p.logits_abs_max:
            logits = jnp.clip(logits, -p.logits_abs_max, p.logits_abs_max)
        return logits
Example #3
0
def reshard_input_based_on_rank_fn(
    mapping_dict: Dict[str, base_layer.SplitDimsMapping],
    mesh_names: Sequence[str],
    x: JTensor,
) -> JTensor:
  """Reshards input based on its rank.

  Args:
    mapping_dict: Dictionary which contains the split mapping for different
      shapes. For n-d shape, it must have an entry f'map_{n}d' which tells us
      how to partition tensors of this dimension. A value of None means the
      tensor is left as is.
    mesh_names: List of mesh axis names.
    x: JTensor which to shard.

  Returns:
    Resharded tensor, or `x` itself when the mapping for its rank is None.

  Raises:
    ValueError: if `mapping_dict` has no 'map_{n}d' entry for the rank of `x`.
  """
  rank = len(x.shape)
  key = f'map_{rank}d'
  if key not in mapping_dict:
    # Adjacent f-strings are concatenated verbatim, so each piece carries its
    # own trailing space; the advertised key is the exact key looked up.
    raise ValueError(f'Split mapping must be provided for {rank}-d '
                     f'in the form of key {key} in '
                     f'{mapping_dict}.')
  split_mapping = mapping_dict[key]
  if split_mapping is None:
    return x
  return base_layer.maybe_shard(x, split_mapping, mesh_names)
Example #4
0
 def emb_lookup(self, ids: JTensor) -> JTensor:
     """Embeds `ids` by multiplying their one-hot encoding with the table.

     Args:
       ids: integer class ids; presumably shaped [B, L] — TODO confirm.

     Returns:
       Embeddings sharded with `emb_out_split_dims_mapping`.
     """
     p = self.params
     act_mapping = p.activation_split_dims_mapping
     # [B, L] -> [B, L, V], encoded in the forward-pass dtype.
     one_hot = jax.nn.one_hot(ids, p.num_classes, dtype=self.fprop_dtype)
     # [B, L, V] x [V, H] -> [B, L, H]
     gathered = linears.project_last_dim(one_hot,
                                         self.embedding.local_theta().w)
     return base_layer.maybe_shard(gathered,
                                   act_mapping.emb_out_split_dims_mapping,
                                   p.mesh_axis_names)
Example #5
0
  def fprop(self, inputs: JTensor) -> JTensor:
    """Projects the last dimension of `inputs` with this layer's weight `w`.

    Args:
      inputs: The inputs JTensor. Shaped [..., input_dims].

    Returns:
      The projection, sharded per `activation_split_dims_mapping.out`.
    """
    p = self.params
    weight = self.local_theta().w
    projected = project_last_dim(inputs, weight)
    return base_layer.maybe_shard(
        projected, p.activation_split_dims_mapping.out, p.mesh_axis_names)
Example #6
0
    def emb_lookup(self, ids: JTensor) -> JTensor:
        """Embeds `ids` using the transposed weight of `logits_ffn.linear`.

        `p.lookup_style` selects the strategy:
          * 'index'  -- gather rows of the table directly.
          * 'matmul' -- multiply a one-hot encoding of `ids` with the table.
        When `p.scale_sqrt_depth` is truthy, the result is multiplied by
        sqrt(input_dims). The output is sharded with
        `emb_out_split_dims_mapping`.

        Raises:
          ValueError: if `p.lookup_style` is neither 'index' nor 'matmul'.
        """
        p = self.params
        act_mapping = p.activation_split_dims_mapping
        # The softmax weight doubles as the embedding table; transposed so
        # that its leading axis is the class axis.
        table = jnp.transpose(self.logits_ffn.linear.local_theta().w)
        style = p.lookup_style
        if style == 'index':
            embs = jnp.asarray(table)[(ids, )]
        elif style == 'matmul':
            # Explicit casting to fprop_dtype needed for bf16.
            one_hot = jax.nn.one_hot(ids, p.num_classes,
                                     dtype=self.fprop_dtype)
            embs = linears.project_last_dim(one_hot, table)
        else:
            raise ValueError('Unknown lookup style.')

        if p.scale_sqrt_depth:
            embs *= p.input_dims**0.5

        return base_layer.maybe_shard(embs,
                                      act_mapping.emb_out_split_dims_mapping,
                                      p.mesh_axis_names)