def annotate(x):
  # Dim 0 follows the per-stage weight split; all remaining dims are left
  # unconstrained so the partitioner may choose their sharding.
  unconstrained_dims = list(range(1, x.ndim))
  dims_mapping = (p.weight_split_dims_mapping.stages + [None] * (x.ndim - 1))
  return base_layer.maybe_shard(x, dims_mapping, p.mesh_axis_names,
                                unconstrained_dims)
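# A minimal, runnable sketch of the mapping the helper above builds, assuming a
# single stage mesh axis named 'stage' (a hypothetical value for
# p.weight_split_dims_mapping.stages): dim 0 follows the stage split, and every
# other dim is listed as unconstrained.
stages = ['stage']   # stand-in for p.weight_split_dims_mapping.stages
ndim = 3             # e.g. a [num_stages, batch, hidden] per-stage tensor
dims_mapping = stages + [None] * (ndim - 1)
unconstrained_dims = list(range(1, ndim))
assert dims_mapping == ['stage', None, None]
assert unconstrained_dims == [1, 2]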
def get_logits(self, inputs: JTensor) -> JTensor:
  """Returns logits given the inputs, with an option to cap them.

  Args:
    inputs: a single JTensor with shape [..., input_dim].

  Returns:
    logits: with shape [..., num_classes]. Unnormalized softmax logits.
  """
  p = self.params
  ap = p.activation_split_dims_mapping
  # Activations are scaled with 1/sqrt(input_dims).
  inputs *= (p.input_dims**-0.5)
  # VH -> HV
  softmax_var = jnp.transpose(self.embedding.local_theta().w)
  # Compute logits: BLH,HV -> BLV
  logits = linears.project_last_dim(inputs, softmax_var)
  logits = base_layer.maybe_shard(logits, ap.out, p.mesh_axis_names)

  # Soft cap logits if applicable.
  if p.soft_cap_logits:
    logits = p.soft_cap_logits * jnp.tanh(logits / p.soft_cap_logits)

  # Hard cap logits to [-logits_abs_max, logits_abs_max] if applicable.
  if p.logits_abs_max:
    logits = jnp.clip(logits, -p.logits_abs_max, p.logits_abs_max)
  return logits
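# A self-contained sketch of the same computation without the Pax layer
# machinery, assuming a tied [vocab, hidden] embedding weight; the function and
# argument names here are illustrative, not part of the library.
import jax.numpy as jnp

def tied_softmax_logits(inputs, emb_w, soft_cap=30.0):
  # Scale activations by 1/sqrt(hidden), matching the 1/sqrt(input_dims) above.
  hidden = inputs.shape[-1]
  inputs = inputs * hidden**-0.5
  # BLH,VH -> BLV: project onto the shared embedding table.
  logits = jnp.einsum('...h,vh->...v', inputs, emb_w)
  if soft_cap:
    # Smoothly cap logits to (-soft_cap, soft_cap) via tanh.
    logits = soft_cap * jnp.tanh(logits / soft_cap)
  return logits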
def reshard_input_based_on_rank_fn(
    mapping_dict: Dict[str, base_layer.SplitDimsMapping],
    mesh_names: Sequence[str],
    x: JTensor,
) -> JTensor:
  """Reshards the input based on its rank.

  Args:
    mapping_dict: Dictionary which contains the split mapping for different
      ranks. For an n-d shape, it must have an entry f'map_{n}d' which tells
      us how to partition tensors of that rank.
    mesh_names: List of mesh axis names.
    x: JTensor to shard.

  Returns:
    The resharded tensor.
  """
  key = f'map_{len(x.shape)}d'
  if key not in mapping_dict:
    raise ValueError(f'Split mapping must be provided for {len(x.shape)}-d '
                     f'inputs in the form of key {key} in {mapping_dict}.')
  if mapping_dict[key] is not None:
    return base_layer.maybe_shard(x, mapping_dict[key], mesh_names)
  else:
    return x
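# A sketch of the rank-keyed convention this helper expects; the mesh axis
# names ('replica', 'data', 'mdl') and the particular splits are illustrative,
# not values taken from the library.
mapping_dict = {
    'map_2d': [('replica', 'data'), None],          # [batch, hidden]
    'map_3d': [('replica', 'data'), None, 'mdl'],   # [batch, seq, hidden]
    'map_4d': None,                                 # rank-4 inputs: left as-is
}
mesh_names = ['replica', 'data', 'mdl']
# x = reshard_input_based_on_rank_fn(mapping_dict, mesh_names, x)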
def emb_lookup(self, ids: JTensor) -> JTensor:
  p = self.params
  ap = p.activation_split_dims_mapping
  # BL -> BLV
  one_hot_ids = jax.nn.one_hot(ids, p.num_classes, dtype=self.fprop_dtype)
  # BLV,VH -> BLH
  embs = linears.project_last_dim(one_hot_ids,
                                  self.embedding.local_theta().w)
  embs = base_layer.maybe_shard(embs, ap.emb_out_split_dims_mapping,
                                p.mesh_axis_names)
  return embs
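# A runnable sketch showing that the one-hot matmul lookup above is equivalent
# to a plain gather; the sizes and the 'emb_w' name (standing in for
# self.embedding.local_theta().w, shape VH) are illustrative.
import jax
import jax.numpy as jnp

vocab, hidden = 8, 4
emb_w = jnp.arange(vocab * hidden, dtype=jnp.float32).reshape(vocab, hidden)
ids = jnp.array([[1, 3], [0, 7]])                              # BL

one_hot_ids = jax.nn.one_hot(ids, vocab, dtype=emb_w.dtype)    # BL -> BLV
embs_matmul = jnp.einsum('...v,vh->...h', one_hot_ids, emb_w)  # BLV,VH -> BLH
embs_index = emb_w[ids]                                        # gather form
assert jnp.allclose(embs_matmul, embs_index)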
def fprop(self, inputs: JTensor) -> JTensor:
  """Apply projection to inputs.

  Args:
    inputs: The inputs JTensor. Shaped [..., input_dims].

  Returns:
    Projected inputs.
  """
  p = self.params
  theta = self.local_theta()
  ap = p.activation_split_dims_mapping
  out = project_last_dim(inputs, theta.w)
  out = base_layer.maybe_shard(out, ap.out, p.mesh_axis_names)
  return out
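# A minimal sketch of the projection assumed above: contract the last dim of
# the input against a [input_dims, output_dims] weight. This is a stand-in for
# project_last_dim, not the library's implementation.
import jax.numpy as jnp

def project_last_dim_sketch(inputs, w):
  return jnp.einsum('...d,dh->...h', inputs, w)

x = jnp.ones((2, 5, 8))    # [..., input_dims]
w = jnp.ones((8, 16))      # [input_dims, output_dims]
assert project_last_dim_sketch(x, w).shape == (2, 5, 16)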
def emb_lookup(self, ids: JTensor) -> JTensor:
  p = self.params
  ap = p.activation_split_dims_mapping
  # HV -> VH: reuse the shared softmax weight as the embedding table.
  emb_var = jnp.transpose(self.logits_ffn.linear.local_theta().w)
  if p.lookup_style == 'index':
    embs = jnp.asarray(emb_var)[(ids,)]
  elif p.lookup_style == 'matmul':
    # Explicit casting to fprop_dtype needed for bf16.
    one_hot_ids = jax.nn.one_hot(ids, p.num_classes, dtype=self.fprop_dtype)
    embs = linears.project_last_dim(one_hot_ids, emb_var)
  else:
    raise ValueError('Unknown lookup style.')
  # Scale with sqrt(embedding dims).
  if p.scale_sqrt_depth:
    embs *= p.input_dims**0.5
  embs = base_layer.maybe_shard(embs, ap.emb_out_split_dims_mapping,
                                p.mesh_axis_names)
  return embs
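# A sketch of how scale_sqrt_depth interacts with the 1/sqrt(input_dims)
# scaling in get_logits: embeddings are scaled up by sqrt(d) at lookup time,
# and get_logits divides the same factor back out before the tied softmax
# matmul. Values are illustrative.
import jax.numpy as jnp

d = 16
emb = jnp.ones((2, 3, d))            # BLH embeddings at unit scale
emb_scaled = emb * d**0.5            # emb_lookup output with scale_sqrt_depth
logits_input = emb_scaled * d**-0.5  # scaling applied inside get_logits
assert jnp.allclose(logits_input, emb)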