def set_sharding_annotations_v1(task_p: InstantiableParams,
                                mesh_shape: Sequence[int]) -> None:
  """Sets the sharding annotations in the task config for the given mesh.

  Args:
    task_p: The task parameters to update with sharding annotations.
    mesh_shape: A 3D sequence representing the mesh shape.
  """
  model_p = task_p.model
  asserts.eq(len(mesh_shape), 3)
  device_count = np.prod(mesh_shape)
  device_ids_mesh = np.arange(device_count).reshape(mesh_shape)
  model_p.device_mesh = device_ids_mesh
  replica_axis = 'replica'
  data_axis = 'data'
  mdl_axis = 'mdl'
  mesh_axis_names = [replica_axis, data_axis, mdl_axis]
  task_p.train.inputs_split_mapping = NestedMap(
      map_1d=((replica_axis, data_axis),),
      map_2d=((replica_axis, data_axis), None))
  model_p.mesh_axis_names = mesh_axis_names
  if hasattr(model_p, 'lm'):
    model_p.lm = model_p.lm.cls.set_sharding_params_v1(
        model_p.lm,
        replica_axis=replica_axis,
        data_axis=data_axis,
        mdl_axis=mdl_axis,
        device_ids_mesh=device_ids_mesh,
        mesh_axis_names=mesh_axis_names)
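As a standalone sketch of the mesh construction above (the mesh shape and device count here are made-up illustration values, and no task params are involved):

import numpy as np

# Hypothetical 3D mesh: 1 replica group x 2 data-parallel x 4 model-parallel.
mesh_shape = (1, 2, 4)
device_ids_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
mesh_axis_names = ['replica', 'data', 'mdl']
# device_ids_mesh has shape (1, 2, 4) and enumerates device ids 0..7.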
def fprop(self,
          inputs: JTensor,
          paddings: Optional[JTensor] = None) -> JTensor:
  """Applies batch normalization.

  Args:
    inputs: The inputs JTensor. Shaped [..., dim].
    paddings: The paddings JTensor. Shaped [..., 1].

  Returns:
    Output after applying batch normalization, with the same shape as
    'inputs'.
  """
  p = self.params
  inputs, paddings = self._cast_to_fprop_dtype((inputs, paddings))
  if paddings is None:
    paddings = self._get_default_paddings(inputs)
  asserts.eq(inputs.ndim, paddings.ndim)
  asserts.eq(paddings.shape[-1], 1)
  norm_mean, norm_variance, beta, gamma = self.compute_and_update_moments(
      inputs, paddings)
  inv = gamma / jnp.sqrt(norm_variance + self._epsilon)
  bn_output = (inputs - norm_mean) * inv + beta
  if p.set_padded_output_to_zero:
    bn_output *= 1.0 - paddings
  return bn_output
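For reference, the core normalization arithmetic can be reproduced standalone with plain jax.numpy; the shapes, epsilon, and parameter values below are illustrative assumptions, not the layer's actual configuration:

import jax.numpy as jnp

x = jnp.ones((2, 3, 8))                      # hypothetical [..., dim] input
mean = jnp.mean(x, axis=(0, 1), keepdims=True)
variance = jnp.mean(jnp.square(x - mean), axis=(0, 1), keepdims=True)
gamma = jnp.ones((1, 1, 8))                  # learned scale (dummy values)
beta = jnp.zeros((1, 1, 8))                  # learned shift (dummy values)
epsilon = 1e-3                               # assumed value for illustration
inv = gamma / jnp.sqrt(variance + epsilon)
bn_output = (x - mean) * inv + beta          # same shape as x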
def fprop(self,
          inputs: JTensor,
          paddings: Optional[JTensor] = None) -> JTensor:
  """Applies group normalization.

  Args:
    inputs: The inputs JTensor. Shaped [batch_size, height, width, channel] if
      p.input_rank == 4, else [batch_size, height, channel].
    paddings: The paddings JTensor. Shaped [batch_size, height]. Intended to be
      used for sequence processing where `height` is `time`.

  Returns:
    Output after applying group normalization, with the same shape as
    'inputs', or an (output, output_paddings) pair if the input paddings are
    not None.
  """
  p = self.params
  inputs, paddings = self._cast_to_fprop_dtype((inputs, paddings))
  asserts.eq(inputs.ndim, p.input_rank)

  x = jnp.reshape(
      inputs, list(inputs.shape[:-1]) + [self.num_groups, self.group_size])
  expanded_rank = p.input_rank + 1
  all_dims = list(range(expanded_rank))
  if paddings is None or not p.cumulative:
    # Skips batch and num_groups.
    reduce_over_dims = all_dims[1:-2] + all_dims[-1:]
  else:
    # Skips batch, seqlen and num_groups.
    reduce_over_dims = all_dims[2:-2] + all_dims[-1:]

  if paddings is None and not p.cumulative:
    group_mean = jnp.mean(x, axis=reduce_over_dims, keepdims=True)
    group_variance = jnp.mean(
        jnp.square(x - jax.lax.stop_gradient(group_mean)),
        axis=reduce_over_dims,
        keepdims=True)
  else:
    expanded_paddings = jnp.reshape(
        paddings, list(inputs.shape[:2]) + [1] * (expanded_rank - 2))
    group_mean, group_variance = compute_moments(
        x,
        expanded_paddings,
        reduce_over_dims,
        cumulative_axis=1,
        enable_cross_replica_sum_on_tpu=p.enable_cross_replica_sum_on_tpu,
        keepdims=True)

  outputs = self._normalize(x, group_mean, group_variance)

  if paddings is None:
    return outputs
  else:
    return outputs, paddings
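The grouping step can be sketched standalone for the rank-3, no-paddings case; the shapes and group counts below are assumptions for illustration only:

import jax.numpy as jnp

batch, height, channel, num_groups = 2, 5, 8, 4
group_size = channel // num_groups
x = jnp.ones((batch, height, channel))
# Split the channel dim into [num_groups, group_size].
xg = jnp.reshape(x, (batch, height, num_groups, group_size))
# Reduce over everything except batch and num_groups.
group_mean = jnp.mean(xg, axis=(1, 3), keepdims=True)
group_variance = jnp.mean(
    jnp.square(xg - group_mean), axis=(1, 3), keepdims=True)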
def compute_moments(
    inputs: JTensor,
    padding: JTensor,
    reduce_over_dims: List[int],
    cumulative_axis: Optional[int] = None,
    enable_cross_replica_sum_on_tpu: bool = False,
    keepdims: bool = False,
) -> Tuple[JTensor, JTensor]:
  """Computes mean and variance over the valid data points in inputs.

  Args:
    inputs: The inputs JTensor.
    padding: The paddings JTensor.
    reduce_over_dims: A sequence of ints for dimensions to reduce `inputs`
      over.
    cumulative_axis: An optional int for the axis on which to compute a
      cumulative sum. If None, no cumulative sum is applied.
    enable_cross_replica_sum_on_tpu: A boolean indicating whether to use an
      all-reduce sum over the 'batch' axis.
    keepdims: A boolean indicating whether the reduced axes should be left in
      the result as dimensions with size one.

  Returns:
    Tuple of (mean, variance).
  """
  asserts.eq(inputs.ndim, padding.ndim)
  rank = inputs.ndim
  for dim in reduce_over_dims:
    asserts.between(dim, 0, rank, left_strict=False, right_strict=True)
  mask = 1.0 - padding
  sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=keepdims)
  count_v = jnp.sum(
      jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=keepdims)
  if cumulative_axis is not None:
    sum_v = jnp.cumsum(sum_v, axis=cumulative_axis)
    count_v = jnp.cumsum(count_v, axis=cumulative_axis)
  if enable_cross_replica_sum_on_tpu:
    # TODO(shafey, yonghui): Fetch axis_name from globals.
    sum_v = jax.lax.psum(sum_v, axis_name='batch')
    count_v = jax.lax.psum(count_v, axis_name='batch')
  count_v = jnp.maximum(count_v, 1.0)
  mean = sum_v / count_v
  sum_vv = jnp.sum(
      (inputs - mean) * (inputs - mean) * mask,
      axis=reduce_over_dims,
      keepdims=keepdims)
  if cumulative_axis is not None:
    sum_vv = jnp.cumsum(sum_vv, axis=cumulative_axis)
  if enable_cross_replica_sum_on_tpu:
    # TODO(shafey, yonghui): Fetch axis_name from globals.
    sum_vv = jax.lax.psum(sum_vv, axis_name='batch')
  variance = sum_vv / count_v
  return mean, variance
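A small worked example (dummy tensors, reduction over the time axis only) showing how the padding mask removes padded positions from both the sum and the count:

import jax.numpy as jnp

inputs = jnp.arange(12.0).reshape((2, 3, 2))      # [batch, time, dim]
padding = jnp.array([[[0.], [0.], [1.]],
                     [[0.], [1.], [1.]]])         # 1.0 marks padded steps
mask = 1.0 - padding
count = jnp.maximum(
    jnp.sum(jnp.ones_like(inputs) * mask, axis=(1,), keepdims=True), 1.0)
mean = jnp.sum(inputs * mask, axis=(1,), keepdims=True) / count
variance = jnp.sum(
    jnp.square(inputs - mean) * mask, axis=(1,), keepdims=True) / count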
def save_checkpoint(
    train_state: train_states.TrainState,
    checkpoint_dir: str,
    overwrite: bool = False,
    unreplicate: bool = True,
    checkpoint_type: CheckpointType = CheckpointType.CHECKPOINT_FLAX,
    state_specs: Optional[train_states.TrainState] = None,
) -> None:
  """Saves a checkpoint into the provided base directory.

  This is typically called on a replicated TrainState instance.

  Args:
    train_state: The TrainState instance to save.
    checkpoint_dir: The base directory where checkpoints are saved.
    overwrite: Whether to overwrite existing checkpoint files if a checkpoint
      at the current or a later step already exists.
    unreplicate: Whether to unreplicate variables (Optional). If using SPMD
      sharding, this should be set to False.
    checkpoint_type: The checkpoint type (implementation) to save. Either
      `CHECKPOINT_FLAX`, `CHECKPOINT_MULTI_HOST_FLAX`, `CHECKPOINT_GDA` or
      `CHECKPOINT_PERSISTENCE`.
    state_specs: Currently unused.

  Raises:
    ValueError: If the global step has an unexpected shape, if `state_specs`
      is not specified for persistence-based checkpointing, or if
      `checkpoint_type` is invalid.
  """
  del state_specs

  if jax.config.jax_parallel_functions_output_gda:
    asserts.eq(checkpoint_type, CheckpointType.CHECKPOINT_GDA)
    step = int(
        jax.device_get(py_utils.maybe_unreplicate_gda(train_state.step)))
    _save_checkpoint_gda(train_state, checkpoint_dir, overwrite, step)
    return

  if train_state.step.ndim == 0:
    step = jax.device_get(train_state.step)
  elif train_state.step.ndim == 1:
    step = jax.device_get(train_state.step[0])
  else:
    raise ValueError(
        f'Expecting a replicated 1D global step (got `{train_state.step.ndim}`).'
    )

  if checkpoint_type in {
      CheckpointType.CHECKPOINT_FLAX, CheckpointType.CHECKPOINT_MULTI_HOST_FLAX
  }:
    use_multi_host = (
        checkpoint_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX)
    _save_checkpoint_flax(train_state, checkpoint_dir, overwrite, unreplicate,
                          step, use_multi_host)
  else:
    raise ValueError(f'Unexpected checkpoint_type `{checkpoint_type}`.')
def retrieve_checkpoint_type(
    multi_host_checkpointing: bool,
    maybe_use_persistence_checkpointing: bool,
    task_p: InstantiableParams) -> CheckpointType:
  """Retrieves the CheckpointType given the input arguments."""
  if jax.config.jax_parallel_functions_output_gda:
    asserts.eq(multi_host_checkpointing, True)
    checkpoint_type = CheckpointType.CHECKPOINT_GDA
  elif (maybe_use_persistence_checkpointing and
        task_p.model.device_mesh is not None):
    asserts.eq(multi_host_checkpointing, False)
    checkpoint_type = CheckpointType.CHECKPOINT_PERSISTENCE
  else:
    # Flax-based checkpointing.
    if multi_host_checkpointing:
      checkpoint_type = CheckpointType.CHECKPOINT_MULTI_HOST_FLAX
    else:
      checkpoint_type = CheckpointType.CHECKPOINT_FLAX
  return checkpoint_type
def __init__(self, params):
  """Initializes GroupNorm layer and checks parameters."""
  super().__init__(params)
  p = self.params
  asserts.not_none(p.name)
  asserts.gt(p.num_groups, 0)
  asserts.gt(p.min_group_size, 0)
  asserts.le(p.min_group_size, p.dim)
  asserts.eq(p.dim % p.min_group_size, 0)

  if p.dim >= p.num_groups:
    asserts.eq(
        p.dim % p.num_groups,
        0,
        msg='p.dim({0}) is not divisible by p.num_groups({1})'.format(
            p.dim, p.num_groups))

  asserts.in_set(p.input_rank, (3, 4))
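For concreteness, a hypothetical parameter setting that satisfies these checks, expressed as plain arithmetic:

dim, num_groups, min_group_size = 8, 4, 2   # hypothetical values
assert min_group_size <= dim and dim % min_group_size == 0
assert dim < num_groups or dim % num_groups == 0
group_size = dim // num_groups              # 2 channels per group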
def restore_checkpoint(
    train_state: Optional[train_states.TrainState],
    checkpoint_dir: str,
    global_mesh: Optional[maps.Mesh] = None,
    checkpoint_type: CheckpointType = CheckpointType.CHECKPOINT_FLAX,
    state_specs: Optional[train_states.TrainState] = None,
    step: Optional[int] = None) -> train_states.TrainState:
  """Restores a checkpoint from the provided base directory.

  This is typically called on an unreplicated TrainState instance.

  Args:
    train_state: The TrainState instance to restore.
    checkpoint_dir: The base directory from which to retrieve checkpoints.
    global_mesh: The global mesh representing devices across multiple
      processes.
    checkpoint_type: The checkpoint type (implementation) to restore. Either
      `CHECKPOINT_FLAX`, `CHECKPOINT_MULTI_HOST_FLAX`, `CHECKPOINT_GDA` or
      `CHECKPOINT_PERSISTENCE`.
    state_specs: If using a GDA-based checkpoint, the partition specs
      corresponding to this TrainState instance to restore.
    step: Step number to load a checkpoint from, or None to load the latest.

  Returns:
    A restored `TrainState` instance.

  Raises:
    ValueError: When a mismatch between the current checkpoint structure and
      the saved checkpoint is detected.
  """
  if jax.config.jax_parallel_functions_output_gda:
    asserts.eq(checkpoint_type, CheckpointType.CHECKPOINT_GDA)
    return _restore_checkpoint_gda(train_state, checkpoint_dir, global_mesh,
                                   state_specs, step)

  if train_state is not None and train_state.step.ndim != 0:
    raise ValueError('Expecting an unreplicated scalar global step (got '
                     f'`{train_state.step.ndim}`).')

  if checkpoint_type in {
      CheckpointType.CHECKPOINT_FLAX, CheckpointType.CHECKPOINT_MULTI_HOST_FLAX
  }:
    return _restore_checkpoint_flax(train_state, checkpoint_dir, step)
  else:
    raise ValueError(f'Unexpected checkpoint_type `{checkpoint_type}`.')
def fprop(self, state0: NestedMap,
          inputs: NestedMap) -> Tuple[NestedMap, NestedMap]:
  """Forward function.

  `_reset_state` is optionally applied if `reset_cell_state` is True. The RNN
  layer should provide `reset_mask` inputs in addition to other inputs.
  `reset_mask` inputs are expected to be 0 at timesteps where state0 should be
  reset to the default (zeros) before running `fprop`, and 1 otherwise. This
  is meant to support use cases like packed inputs, where multiple samples are
  fed in a single input example sequence and need to be masked from each
  other.

  For example, if the two examples packed together are
  ['good', 'day'] -> ['guten-tag'] and ['thanks'] -> ['danke'], producing
  ['good', 'day', 'thanks'] -> ['guten-tag', 'danke'], the source reset_masks
  would be [1, 1, 0] and the target reset_masks would be [1, 0]. These masks
  keep the computations of different examples separate from each other.

  Args:
    state0: The previous recurrent state.
    inputs: The inputs to the cell.

  Returns:
    state1: The next recurrent state.
    extras: Intermediate results to facilitate backprop.
  """
  asserts.instance(inputs.act, list)
  asserts.eq(self.params.inputs_arity, len(inputs.act))

  state0 = self._maybe_reset_state(state0, inputs)

  concat = jnp.concatenate(inputs.act + [state0.m], 1)
  wm = self.local_theta().wm
  xmw = jnp.einsum('bd,dc->bc', concat, wm)

  i_i, i_g, f_g, o_g = self._retrieve_and_split_gates(xmw)
  state1 = self._gates_internal(state0, inputs, i_i, i_g, f_g, o_g)
  return state1, NestedMap()
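The reset-mask convention from the docstring can be illustrated standalone; the tensors below are made up and this is only a sketch of the masking idea, not the layer's `_maybe_reset_state`:

import jax.numpy as jnp

state0_m = jnp.ones((3, 4))                 # previous cell output, [batch, dims]
reset_mask = jnp.array([[1.], [1.], [0.]])  # 0 => this row starts a new packed example
state0_m = state0_m * reset_mask            # zero the state where the mask is 0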
def test_eq_raises(self, value1, value2):
  with self.assertRaisesRegex(
      ValueError, f'`value1={value1}` must be equal to `value2={value2}`.$'):
    asserts.eq(value1, value2)
  with self.assertRaisesRegex(
      ValueError,
      f'`custom_value={value1}` must be equal to `value2={value2}`.$'):
    asserts.eq(value1, value2, value_str1=f'custom_value={value1}')
  with self.assertRaisesRegex(
      ValueError,
      f'`value1={value1}` must be equal to `custom_value={value2}`.$'):
    asserts.eq(value1, value2, value_str2=f'custom_value={value2}')
  custom_error_msg = 'This is a custom error message.'
  with self.assertRaisesRegex(ValueError, f'{custom_error_msg}$'):
    asserts.eq(value1, value2, msg=custom_error_msg)
def test_eq(self, value1):
  value2 = value1
  asserts.eq(value1, value2)
def ctc_loss(logits: JTensor,
             logitpaddings: JTensor,
             labels: JTensor,
             labelpaddings: JTensor,
             blank_id: int = 0) -> Tuple[JTensor, Mapping[str, JTensor]]:
  """Computes CTC loss.

  This function performs forward computation over an FSA with `N * 2` states
  where `N` is the max number of labels. The states are split into two groups:
  phi states and emission states. A phi-state accepts repetition of
  phi (blank)-symbols and transits to an emission state when the correct label
  is observed. An emission state accepts repetition of the label and transits
  to the next phi state at any time (a so-called epsilon-transition).

  Below, `B` denotes the batch size, `T` denotes the number of time frames in
  `logits`, and `N` denotes the max label length in `labels`.

  Args:
    logits: (B, T, K)-array containing log-probabilities of each class.
    logitpaddings: (B, T)-array. Padding indicators for `logits`.
    labels: (B, N)-array containing reference integer labels.
    labelpaddings: (B, N)-array. Padding indicators for `labels`. Currently,
      `labels` must be right-padded, i.e. each row of `labelpaddings` must be a
      repetition of zeroes, followed by a repetition of ones.
    blank_id: Id for the blank token.

  Returns:
    A pair of `(per_seq_loss, aux)`.
    per_seq_loss: (B,)-array containing loss values for each sequence in the
      batch.
    aux: Dictionary containing interim variables used for computing losses.
      aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each
        phi-state corresponding to the n-th label.
      aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each
        emission-state corresponding to the n-th label.
      aux['logprobs_phi']: (T, B, 1)-array. Log-probability of the phi-symbol
        corresponding to each time frame.
      aux['logprobs_emit']: (T, B, N)-array. Log-probability of the n-th label
        corresponding to each time frame.
  """
  batchsize, unused_maxinputlen, num_classes = logits.shape
  batchsize_, maxlabellen = labels.shape
  asserts.eq(batchsize, batchsize_)

  logprobs = jax.nn.log_softmax(logits)
  labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32)

  # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1].
  repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32)
  repeat = jnp.pad(repeat, ((0, 0), (0, 1)))

  logprobs_phi = logprobs[:, :, blank_id:blank_id + 1]  # [B, T, 1]
  logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2))  # [T, B, 1]

  one_hot = jax.nn.one_hot(labels, num_classes=num_classes)  # [B, N, K]
  logprobs_emit = jnp.einsum('btk,bnk->btn', logprobs, one_hot)
  logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2))  # [T, B, N]

  logalpha_phi_init = jnp.ones(
      (batchsize, maxlabellen + 1)) * _LOGEPSILON  # [B, N+1]
  logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0)
  logalpha_emit_init = jnp.ones(
      (batchsize, maxlabellen)) * _LOGEPSILON  # [B, N]

  def loop_body(prev, x):
    prev_phi, prev_emit = prev
    # Emit-to-phi epsilon transition, except if the next label is a repetition.
    prev_phi_orig = prev_phi
    prev_phi = prev_phi.at[:, 1:].set(
        jnp.logaddexp(prev_phi[:, 1:], prev_emit + _LOGEPSILON * repeat))

    logprob_emit, logprob_phi, pad = x

    # Phi-to-emit transition.
    next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit,
                              prev_emit + logprob_emit)
    # Self-loop transition.
    next_phi = prev_phi + logprob_phi
    # Emit-to-phi blank transition only when the next label is a repetition.
    next_phi = next_phi.at[:, 1:].set(
        jnp.logaddexp(next_phi[:, 1:],
                      prev_emit + logprob_phi + _LOGEPSILON * (1.0 - repeat)))

    pad = pad.reshape((batchsize, 1))
    next_emit = pad * prev_emit + (1.0 - pad) * next_emit
    next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi

    return (next_phi, next_emit), (next_phi, next_emit)

  xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0)))
  _, (logalpha_phi, logalpha_emit) = jax.lax.scan(
      loop_body, (logalpha_phi_init, logalpha_emit_init), xs)

  # The last row needs to be updated with the last epsilon transition.
  logalpha_phi_last = logalpha_phi[-1].at[:, 1:].set(
      jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1]))
  logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last)

  # Extract per_seq_loss.
  one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1)  # [B, N+1]
  per_seq_loss = -jnp.einsum('bn,bn->b', logalpha_phi_last, one_hot)

  return per_seq_loss, {
      'logalpha_phi': logalpha_phi,
      'logalpha_emit': logalpha_emit,
      'logprobs_phi': logprobs_phi,
      'logprobs_emit': logprobs_emit,
  }
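A hedged usage sketch of `ctc_loss` with random, right-padded dummy data; the shapes and label values are arbitrary, and this assumes `ctc_loss` (defined above) is in scope:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
batch, time, num_classes, max_label_len = 2, 10, 6, 4
logits = jax.random.normal(key, (batch, time, num_classes))
logitpaddings = jnp.zeros((batch, time))
labels = jnp.array([[1, 2, 3, 4],
                    [2, 2, 5, 0]])
labelpaddings = jnp.array([[0., 0., 0., 0.],
                           [0., 0., 0., 1.]])   # right-padded
per_seq_loss, aux = ctc_loss(logits, logitpaddings, labels, labelpaddings)
# per_seq_loss: shape (batch,); aux['logalpha_phi']: shape (time, batch, max_label_len + 1).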