def get_tensor_and_scale_for_act(act, hparams, get_bounds_params):
  # We check whether activations should be quantized based on 'hparams'. If
  # so, we quantize them. If not, we return them unchanged. In either case, we
  # return a scale factor appropriate for unscaling the result of the
  # lax.dot_general.
  if hparams is not None and hparams.prec is not None:
    quant_op = QuantOps.create_input_ops(
        act, hparams=hparams, get_bounds_params=get_bounds_params)

    scale = quant_op.get_scale_for_aqt(allow_per_channel_scales=False)
    # Since only per-layer scale factors are supported, we assert that the
    # scale factors are scalars.
    shape_utils.assert_shapes_compatible(scale.shape, ())
    # TODO(malmaud): See comment on 'act_op.to_quantized' earlier in this
    # file, which applies here as well.
    act_quantized = quant_op.to_quantized(act, dtype=input_dtype)

    # TODO(shivaniagrawal): See comment in 'dot_general' above on why this
    # logic is duplicated here and in the 'else' block below.
    return lax.cond(
        quant_op.should_quantize(),  #
        lambda _: (act_quantized, scale),  #
        lambda _: (act, jnp.array(1.0, dtype=SCALE_DTYPE)),  #
        None)
  else:
    # To avoid having a separate code path for every possibility of which of
    # the two input tensors is quantized, we implement not quantizing an
    # activation tensor by simply setting its corresponding scale factor to
    # 1.0.
    return act, jnp.array(1.0, dtype=SCALE_DTYPE)
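
# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how a (tensor, scale)
# pair like the one returned above is typically consumed. This is a simplified
# stand-in for the real QuantOps machinery; the helper names, the 8-bit
# choice, and the formulas below are assumptions for illustration only. Uses
# the module-level `jnp` import.
def _toy_scale_and_quantize(act, prec_bits=8):
  """Quantizes 'act' with a single per-layer (scalar) scale."""
  bound = jnp.max(jnp.abs(act))
  scale = (2.0**(prec_bits - 1) - 1) / bound  # scalar per-layer scale
  act_quantized = jnp.round(act * scale)      # values now in [-127, 127]
  return act_quantized, scale


def _toy_aqt_matmul(lhs, rhs):
  """Multiplies scaled integer-valued tensors, then unscales the result."""
  lhs_q, lhs_scale = _toy_scale_and_quantize(lhs)
  rhs_q, rhs_scale = _toy_scale_and_quantize(rhs)
  # The matmul runs on the scaled values; dividing by the product of the
  # scalar scales approximately recovers the unquantized result, which is why
  # a non-quantized input above simply uses scale 1.0.
  return (lhs_q @ rhs_q) / (lhs_scale * rhs_scale)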
def __call__(
    self,
    x,
    *,
    mask,
):
  """Applies a tag to track distributions.

  Args:
    x: the array to compute statistics distributions over.
    mask: boolean array indicating which elements of 'x' should be included
      in the stats calculation ('True' means to include).

  Returns:
    x unchanged. The return value can also be ignored.
  """
  if mask is None:
    mask = jnp.full(x.shape, True)
  shape_utils.assert_shapes_compatible(x.shape, mask.shape)
  mask = jnp.broadcast_to(mask, x.shape)
  channel_axis = self.channel_axis
  if channel_axis is not None:
    if not isinstance(channel_axis, Iterable):
      channel_axis = (channel_axis,)
    channel_axis = normalize_axes(channel_axis, x.ndim)
    x = _take_subset_of_axes(
        x, axis=channel_axis, num_indices_per_ax=self.num_indices_per_ax)
    mask = _take_subset_of_axes(
        mask, axis=channel_axis, num_indices_per_ax=self.num_indices_per_ax)
    reduction_axis = tuple(
        [ax for ax in range(x.ndim) if ax not in channel_axis])
  else:
    reduction_axis = None

  distr_shape = ()
  if channel_axis:
    distr_shape = tuple(d for i, d in enumerate(x.shape) if i in channel_axis)

  # TODO(wanglisa): Consider adding configurability to specify which
  # statistics are collected.
  init_with_zeros = lambda shape: jnp.zeros(shape, dtype=jnp.float32)
  is_initializing = not self.has_variable('stats_tag', 'min_per_ch')
  min_per_ch = self.variable(
      'stats_tag',
      'min_per_ch',
      init_with_zeros,
      distr_shape,
  )
  max_per_ch = self.variable('stats_tag', 'max_per_ch', init_with_zeros,
                             distr_shape)
  mean_per_ch = self.variable(
      'stats_tag',
      'mean_per_ch',
      init_with_zeros,
      distr_shape,
  )
  stddev_per_ch = self.variable(
      'stats_tag',
      'stddev_per_ch',
      init_with_zeros,
      distr_shape,
  )
  absdev_per_ch = self.variable(
      'stats_tag',
      'absdev_per_ch',
      init_with_zeros,
      distr_shape,
  )
  stddev_per_ch_uncentered = self.variable(
      'stats_tag',
      'stddev_per_ch_uncentered',
      init_with_zeros,
      distr_shape,
  )
  absdev_per_ch_uncentered = self.variable(
      'stats_tag',
      'absdev_per_ch_uncentered',
      init_with_zeros,
      distr_shape,
  )
  if self.update_stats and not is_initializing:
    min_per_ch.value = jnp.min(
        jnp.where(mask, x, math.inf), axis=reduction_axis)
    max_per_ch.value = jnp.max(
        jnp.where(mask, x, -math.inf), axis=reduction_axis)
    mean_per_ch_keepdims = stats.masked_mean(
        x, mask=mask, axis=reduction_axis, paxis_name=None, keepdims=True)
    mean_per_ch.value = mean_per_ch_keepdims.squeeze(axis=reduction_axis)
    stddev_per_ch.value = jnp.sqrt(
        stats.masked_mean(
            (x - mean_per_ch_keepdims)**2,
            mask=mask,
            axis=reduction_axis,
            paxis_name=None,
            keepdims=False))
    absdev_per_ch.value = stats.masked_mean(
        jnp.abs(x - mean_per_ch_keepdims),
        mask=mask,
        axis=reduction_axis,
        paxis_name=None,
        keepdims=False)
    stddev_per_ch_uncentered.value = jnp.sqrt(
        stats.masked_mean(
            jnp.square(x),
            mask=mask,
            axis=reduction_axis,
            paxis_name=None,
            keepdims=False))
    absdev_per_ch_uncentered.value = stats.masked_mean(
        jnp.abs(x),
        mask=mask,
        axis=reduction_axis,
        paxis_name=None,
        keepdims=False)
  # Return the input unchanged, as documented above.
  return x
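
# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the masked reduction
# trick used above, shown on a tiny example. Masked-out entries are replaced
# by +inf (for min) or -inf (for max) so they can never win the reduction.
# Uses the module-level `jnp` import; the array values are made up.
def _toy_masked_min_max_per_channel():
  x = jnp.array([[1.0, -2.0], [5.0, 0.5]])          # shape [batch, channel]
  mask = jnp.array([[True, True], [False, True]])   # row 1, channel 0 is padding
  min_per_ch = jnp.min(jnp.where(mask, x, jnp.inf), axis=0)   # -> [1.0, -2.0]
  max_per_ch = jnp.max(jnp.where(mask, x, -jnp.inf), axis=0)  # -> [1.0, 0.5]
  return min_per_ch, max_per_ch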
def __call__(
    self,
    x,
    *,
    bounds_params,
):
  """Compute the input batch statistics.

  Args:
    x: the input to get bounds from using statistics.
    bounds_params: parameters to compute input's statistics and bounds.

  Returns:
    Bound value (same shape as inputs).
  """
  if bounds_params.mask is not None:
    shape_utils.assert_shapes_compatible(x.shape, bounds_params.mask.shape)

  x = jnp.asarray(x, jnp.float32)

  hyper = self.hyper
  is_initializing = not self.has_variable('get_bounds', 'stats')

  if hyper.granularity == quant_config.QuantGranularity.per_tensor:
    # Equivalently, this could be written as
    # quant_axis = tuple(range(x.ndim))
    quant_axis = None
    stats_shape = (1,) * len(x.shape)
  elif hyper.granularity == quant_config.QuantGranularity.per_channel:
    # Quantize by aggregating activation statistics across all dimensions of
    # the activation tensor EXCEPT the last dimension, which we interpret as
    # the channel dimension. For example, in a transformer context, x might
    # have a shape corresponding to [example, token, channel], in which case
    # this aggregates activation statistics separately for each feature, where
    # for each feature it aggregates over all unmasked tokens in all examples.
    quant_axis = tuple(range(x.ndim - 1))
    stats_shape = (1,) * (x.ndim - 1) + (x.shape[-1],)
  else:
    raise ValueError(f'Unknown granularity {hyper.granularity}')

  stats_state = self.variable('get_bounds', 'stats', Stats.stats_initializer,
                              stats_shape)

  def bound_initializer(shape):
    return hyper.initial_bound * jnp.ones(shape)

  bounds = self.variable('get_bounds', 'bounds', bound_initializer,
                         stats_shape)

  if bounds_params.update_stats and not is_initializing:
    stats_state.value = Stats.create_updated_stats(
        stats_state.value,
        x,
        mask=bounds_params.mask,
        axis=quant_axis,
        paxis_name=bounds_params.paxis_name,
        alpha=hyper.ema_coeff,
        exclude_zeros=hyper.exclude_zeros)

  if bounds_params.update_bounds and not is_initializing:
    bounds.value = self._stats_to_bounds(stats_state.value)
    if hyper.reset_stats:
      stats_state.value = Stats.stats_initializer(stats_shape)

  return bounds.value
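
# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how the granularity
# setting above translates into reduction axes and state shapes, following the
# per_tensor / per_channel branches in __call__. The example activation shape
# [example, token, channel] = (8, 16, 512) is made up.
def _toy_bounds_shapes(x_shape=(8, 16, 512), per_channel=True):
  ndim = len(x_shape)
  if per_channel:
    # Aggregate over every axis except the trailing channel axis, keeping one
    # bound per channel.
    quant_axis = tuple(range(ndim - 1))               # (0, 1)
    stats_shape = (1,) * (ndim - 1) + (x_shape[-1],)  # (1, 1, 512)
  else:
    # Aggregate over the whole tensor, keeping a single scalar bound
    # (stored with singleton dims so it broadcasts against x).
    quant_axis = None
    stats_shape = (1,) * ndim                         # (1, 1, 1)
  return quant_axis, stats_shape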
def __call__(self,
             inputs_q,
             inputs_kv,
             *,
             padding_mask,
             key_padding_mask,
             segmentation=None,
             key_segmentation=None):
  """Applies multi-head dot product attention on the input data.

  If weight_prec is not None, scales and quantizes weights to signed int
  with weight_prec bits.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention and projects the results to an output
  vector.

  This can be used for encoder-decoder attention by specifying both
  `inputs_q` and `inputs_kv` or for self-attention by only specifying
  `inputs_q` and setting `inputs_kv` to None.

  Args:
    inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
    inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]` or
      None for self-attention, in which case key/values will be derived from
      inputs_q.
    padding_mask: boolean tensor specifying query tokens that are pad tokens.
    key_padding_mask: boolean tensor specifying key-value tokens that are pad
      tokens.
    segmentation: segment indices for packed inputs_q data.
    key_segmentation: segment indices for packed inputs_kv data.

  Returns:
    output of shape `[bs, dim1, dim2, ..., dimN, features]`.
  """
  batch_size, query_sequence_length, channel_size = inputs_q.shape
  hparams = self.hparams
  if inputs_kv is None:
    inputs_kv = inputs_q
    key_sequence_length = inputs_q.shape[1]
  else:
    key_sequence_length = inputs_kv.shape[1]
    shape_utils.assert_shapes_equal(
        inputs_kv.shape, (batch_size, key_sequence_length, channel_size))

  jax_precision = jax.lax.Precision.DEFAULT

  if padding_mask is not None:
    shape_utils.assert_shapes_equal(
        padding_mask.shape, (batch_size, query_sequence_length, 1))
  if key_padding_mask is None:
    key_padding_mask = padding_mask
  else:
    shape_utils.assert_shapes_equal(
        key_padding_mask.shape, (batch_size, key_sequence_length, 1))
  attention_axis = self.attention_axis
  if attention_axis is None:
    attention_axis = tuple(range(1, inputs_q.ndim - 1))

  qkv_features = self.qkv_features
  qkv_features = qkv_features or inputs_q.shape[-1]

  num_heads = self.num_heads
  assert qkv_features % num_heads == 0, (
      'Memory dimension must be divisible by number of heads.')
  head_dim = qkv_features // num_heads

  paxis_name = self.paxis_name
  train = self.train
  kernel_init = self.kernel_init
  bias_init = self.bias_init
  use_bias = self.use_bias
  dtype = self.dtype

  def multi_batch_dense_aqt(inputs, *, name, padding_mask):
    batch_size, sequence_length, channel_size = inputs.shape
    inputs = inputs.reshape(batch_size * sequence_length, channel_size)
    if padding_mask is not None:
      padding_mask = padding_mask.reshape(batch_size * sequence_length, 1)
    out = flax_layers.DenseAqt(
        name=name,
        features=num_heads * head_dim,
        paxis_name=paxis_name,
        train=train,
        quant_context=self.quant_context,
        hparams=hparams.dense_kqv,
        kernel_init=kernel_init,
        bias_init=bias_init,
        use_bias=use_bias,
        dtype=dtype)(inputs, padding_mask=padding_mask)
    return out.reshape(batch_size, sequence_length, num_heads, head_dim)

  # Project inputs_q to multi-headed q/k/v.
  # Dimensions are then [bs, sequence_length, n_heads, n_features_per_head].
  query = multi_batch_dense_aqt(
      inputs_q, name='query', padding_mask=padding_mask)
  key = multi_batch_dense_aqt(
      inputs_kv, name='key', padding_mask=key_padding_mask)
  value = multi_batch_dense_aqt(
      inputs_kv, name='value', padding_mask=key_padding_mask)
  is_cache_initialized = False
  if self.decode:
    is_cache_initialized = self.has_variable('cache', 'cached_key')
    cached_key = self.variable('cache', 'cached_key', jnp.zeros, key.shape,
                               key.dtype)
    cached_value = self.variable('cache', 'cached_value', jnp.zeros,
                                 value.shape, value.dtype)
    cache_index = self.variable('cache', 'cache_index',
                                lambda: jnp.array(0, dtype=jnp.int32))
    if is_cache_initialized:
      expected_shape = list(cached_key.value.shape[:-2])
      for attn_dim in attention_axis:
        expected_shape[attn_dim] = 1
      expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
      if expected_shape != inputs_q.shape:
        raise ValueError('Invalid shape provided, '
                         'expected shape %s instead got %s.' %
                         (expected_shape, inputs_q.shape))

      cshape = cached_key.value.shape
      indices = [0] * len(cshape)
      i = cache_index.value
      attn_size = onp.prod(onp.take(cshape, attention_axis))

      *batch_dims, max_length, num_heads, depth_per_head = (  # pylint: disable=unused-variable
          cached_key.value.shape)
      indices = (0,) * len(batch_dims) + (i, 0, 0)

      key = lax.dynamic_update_slice(cached_key.value, key, indices)
      value = lax.dynamic_update_slice(cached_value.value, value, indices)
      one = jnp.array(1, jnp.int32)
      cache_index.value = cache_index.value + one
      cached_key.value = key
      cached_value.value = value

      # TODO(levskaya): verify this is still needed in translation decoding.
      key_padding_mask = jnp.broadcast_to(
          (jnp.arange(max_length) < cache_index.value), cshape[:2])
      key_padding_mask = key_padding_mask.astype(jnp.float32)[Ellipsis, None]

  # create attention masks
  mask_components = []
  if self.causal_mask:
    if self.decode and is_cache_initialized:
      bias_pre_shape = (1,) * (key.ndim - 1)
      attn_shape = tuple(onp.take(key.shape, attention_axis))
      attn_size = onp.prod(attn_shape)
      ii = jnp.arange(attn_size, dtype=jnp.int32)
      mask = ii < cache_index.value
      mask_components.append(mask.reshape(bias_pre_shape + attn_shape))
    else:
      mask_components.append(_make_causal_mask(key, attention_axis))
  if padding_mask is not None:
    if key_padding_mask is None:
      key_padding_mask = padding_mask
    attn_padding_mask = make_padding_mask(
        padding_mask_query=padding_mask,
        padding_mask_key=key_padding_mask,
        query_shape=query.shape,
        key_shape=key.shape,
        attention_axis=attention_axis)
    mask_components.append(attn_padding_mask)
  if segmentation is not None:
    if key_segmentation is None:
      key_segmentation = segmentation
    segmentation_mask = make_padding_mask(
        padding_mask_query=segmentation,
        padding_mask_key=key_segmentation,
        query_shape=query.shape,
        key_shape=key.shape,
        attention_axis=attention_axis,
        segmentation_mask=True)
    mask_components.append(segmentation_mask)
  attention_mask = None
  if mask_components:
    attention_mask = mask_components[0]
    for component in mask_components[1:]:
      attention_mask = jnp.logical_and(attention_mask, component)
    attention_mask = attention_mask.astype(jnp.bool_)

    # attention mask in the form of attention bias
    attention_bias = jnp.where(
        attention_mask,
        jnp.full(attention_mask.shape, 0.).astype(dtype),
        jnp.full(attention_mask.shape, -1e10).astype(dtype))
  else:
    attention_bias = None

  # Add an extra dimension to the mask corresponding to the head dimension.
  # E.g., if inputs_q has shape [batch_size, sequence_length, n_features],
  # then padding_mask will have a shape [batch_size, sequence_length, 1] and
  # query will have shape
  # [batch_size, sequence_length, n_heads, n_features_per_head].
  # We create query_padding_mask with shape
  # [batch_size, sequence_length, 1, 1] to be broadcast-compatible with
  # 'query'.
  if padding_mask is not None:
    padding_mask = padding_mask[Ellipsis, None]
    shape_utils.assert_shapes_equal(
        padding_mask.shape, (batch_size, query_sequence_length, 1, 1))

  if key_padding_mask is not None:
    key_padding_mask = key_padding_mask[Ellipsis, None]
    # During prediction, the key padding mask is only going to be
    # broadcast-compatible with the key.
    shape_utils.assert_shapes_compatible(
        key_padding_mask.shape, (batch_size, key_sequence_length, 1, 1))

  # apply attention
  attention_fn = self.attention_fn
  dropout_rate = self.dropout_rate
  broadcast_dropout = self.broadcast_dropout
  deterministic = self.deterministic
  if not deterministic and self.dropout_rate > 0.0:
    dropout_rng = self.make_rng('dropout')
  else:
    dropout_rng = None
  x = attention_fn(  # pylint: disable=redundant-keyword-arg
      query=query,
      key=key,
      value=value,
      hparams=hparams.attn_acts,
      paxis_name=paxis_name,
      train=train,
      quant_context=self.quant_context,
      dtype=dtype,
      axis=attention_axis,
      bias=attention_bias,
      precision=jax_precision,
      dropout_rng=dropout_rng,
      dropout_rate=dropout_rate,
      broadcast_dropout=broadcast_dropout,
      deterministic=deterministic,
      query_padding_mask=padding_mask,
      key_padding_mask=key_padding_mask,
      attn_mask=attention_mask)
  shape_utils.assert_shapes_equal(
      x.shape, (batch_size, query_sequence_length, num_heads, head_dim))
  x = x.reshape(batch_size * query_sequence_length, num_heads * head_dim)
  if padding_mask is not None:
    padding_mask = padding_mask.reshape(batch_size * query_sequence_length, 1)

  # back to the original inputs dimensions
  out = flax_layers.DenseAqt(
      features=channel_size,
      hparams=hparams.dense_out,
      quant_context=self.quant_context,
      paxis_name=paxis_name,
      train=train,
      kernel_init=kernel_init,
      bias_init=bias_init,
      use_bias=use_bias,
      dtype=dtype,
      name='dense_out')(x, padding_mask=padding_mask)
  shape_utils.assert_shapes_equal(
      out.shape, (batch_size * query_sequence_length, channel_size))
  out = out.reshape(batch_size, query_sequence_length, channel_size)
  return out
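
# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the reshape
# arithmetic used by multi_batch_dense_aqt and the output projection above.
# The concrete sizes (bs=2, seq=3, heads=4, head_dim=5) are made up for
# illustration; uses the module-level `jnp` import.
def _toy_head_split_and_merge():
  bs, seq, num_heads, head_dim = 2, 3, 4, 5
  features = num_heads * head_dim
  x = jnp.ones((bs, seq, features))
  # Flatten batch and sequence so the dense layer sees a 2D matmul, then
  # split the feature axis into (heads, head_dim).
  split = x.reshape(bs * seq, features).reshape(bs, seq, num_heads, head_dim)
  # After attention, merge the heads back before the final dense projection.
  merged = split.reshape(bs * seq, num_heads * head_dim)
  return split.shape, merged.shape  # ((2, 3, 4, 5), (6, 20))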
def dot_product_attention(query,
                          key,
                          value,
                          hparams,
                          quant_context,
                          paxis_name,
                          train,
                          key_padding_mask,
                          query_padding_mask,
                          attn_mask,
                          dtype=jnp.float32,
                          bias=None,
                          axis=None,
                          broadcast_dropout=True,
                          dropout_rng=None,
                          dropout_rate=0.,
                          deterministic=False,
                          precision=None):
  """Computes dot-product attention given query, key, and value.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights. This
  function supports multi-dimensional inputs.

  Args:
    query: queries for calculating attention with shape of `[batch_size,
      sequence_length, num_heads, mem_channels]`.
    key: keys for calculating attention with shape of `[batch_size,
      sequence_length, num_heads, mem_channels]`.
    value: values to be used in attention with shape of `[batch_size,
      sequence_length, num_heads, value_channels]`.
    hparams: hyperparameters used for quantization.
    quant_context: context for quantization.
    paxis_name: axis_name to which a user `pmaps` the parent module (model),
      refer to jax.pmap() for more documentation. This arg is used for
      get_bounds acts quantization (QuantOps.create_input_fake_quant).
    train: whether the model is training.
    key_padding_mask: boolean mask indicating which elements in 'key' and
      'value' are padding. Must have a shape compatible with 'key' and
      'value'.
    query_padding_mask: boolean mask indicating which elements in `query` are
      padding (True means not padding).
    attn_mask: boolean mask indicating which elements of the calculated
      attention weight matrix should be used for collecting activation
      statistics. Should have a shape broadcast-compatible with
      '[bs, sequence_length, sequence_length]'.
    dtype: the dtype of the computation (default: float32).
    bias: bias for the attention weights. This can be used for incorporating
      an autoregressive mask, padding mask, or proximity bias.
    axis: axes over which the attention is applied.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey to be used for dropout.
    dropout_rate: dropout rate.
    deterministic: bool, deterministic or not (to apply dropout).
    precision: numerical precision of the computation; see `jax.lax.Precision`
      for details.

  Returns:
    Output of shape `[bs, sequence_length, num_heads, value_channels]`.
""" batch_size, query_sequence_length, num_heads, channel_size = query.shape key_sequence_length = key.shape[1] shape_utils.assert_shapes_equal( key.shape, (batch_size, key_sequence_length, num_heads, channel_size)) shape_utils.assert_shapes_equal( value.shape, (batch_size, key_sequence_length, num_heads, channel_size)) if key_padding_mask is not None: shape_utils.assert_shapes_equal( key_padding_mask.shape, (batch_size, key_sequence_length, 1, 1)) if query_padding_mask is not None: shape_utils.assert_shapes_equal( query_padding_mask.shape, (batch_size, query_sequence_length, 1, 1)) if attn_mask is not None: shape_utils.assert_shapes_compatible( attn_mask.shape, (batch_size, 1, query_sequence_length, key_sequence_length)) if axis is None: axis = tuple(range(1, key.ndim - 2)) if not isinstance(axis, Iterable): axis = (axis, ) for ax in axis: if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2): raise ValueError('Attention axis must be between the batch ' 'axis and the last-two axes.') depth = query.shape[-1] n = key.ndim # batch_dims is <bs, <non-attention dims>, num_heads> batch_dims = tuple(onp.delete(range(n), axis + (n - 1, ))) # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels) qk_perm = batch_dims + axis + (n - 1, ) key = key.transpose(qk_perm) shape_utils.assert_shapes_equal( key.shape, (batch_size, num_heads, key_sequence_length, channel_size)) key_padding_mask_transposed = None query_padding_mask_transposed = None if key_padding_mask is not None: key_padding_mask_transposed = key_padding_mask.transpose(qk_perm) shape_utils.assert_shapes_equal( key_padding_mask_transposed.shape, (batch_size, 1, key_sequence_length, 1)) if quant_context.collect_acts_stats: stats_tag.StatsTag(channel_axis=None, name='attn_act_k', update_stats=train)( key, mask=key_padding_mask_transposed) if query_padding_mask is not None: query_padding_mask_transposed = query_padding_mask.transpose(qk_perm) shape_utils.assert_shapes_equal( query_padding_mask_transposed.shape, (batch_size, 1, query_sequence_length, 1)) key_get_bounds_params = get_bounds.GetBounds.Params( update_bounds=quant_context.update_bounds, update_stats=train, paxis_name=paxis_name, mask=key_padding_mask_transposed, module_name='K') # v -> (bs, <non-attention dims>, num_heads, channels, <attention dims>) v_perm = batch_dims + (n - 1, ) + axis value = value.transpose(v_perm) shape_utils.assert_shapes_equal( value.shape, (batch_size, num_heads, channel_size, key_sequence_length)) value_padding_mask_transposed = None if key_padding_mask is not None: value_padding_mask_transposed = key_padding_mask.transpose(v_perm) shape_utils.assert_shapes_equal( value_padding_mask_transposed.shape, (batch_size, 1, 1, key_sequence_length)) if quant_context.collect_acts_stats: stats_tag.StatsTag(channel_axis=None, name='attn_act_v', update_stats=train)( value, mask=value_padding_mask_transposed) value_get_bounds_params = get_bounds.GetBounds.Params( update_bounds=quant_context.update_bounds, update_stats=train, paxis_name=paxis_name, mask=value_padding_mask_transposed, module_name='V') query = query / jnp.sqrt(depth).astype(dtype) query = query.transpose(qk_perm) shape_utils.assert_shapes_equal( query.shape, (batch_size, num_heads, query_sequence_length, channel_size)) if quant_context.collect_acts_stats: stats_tag.StatsTag(channel_axis=None, name='attn_act_q', update_stats=train)( query, mask=query_padding_mask_transposed) query_get_bounds_params = get_bounds.GetBounds.Params( update_bounds=quant_context.update_bounds, update_stats=train, 
      paxis_name=paxis_name,
      mask=query_padding_mask_transposed,
      module_name='Q')

  batch_dims_t = tuple(range(len(batch_dims)))
  attn_weights = quantized_dynamic_dot_general(
      lhs_act=query,
      rhs_act=key,
      dot_dimension_numbers=(((n - 1,), (n - 1,)),
                             (batch_dims_t, batch_dims_t)),
      dot_precision=precision,
      quant_type=hparams.quant_type,
      lhs_act_hparams=hparams.attn_act_q,
      lhs_get_bounds_params=query_get_bounds_params,
      rhs_act_hparams=hparams.attn_act_k,
      rhs_get_bounds_params=key_get_bounds_params,
  )
  # NOTE(shivaniagrawal): we do per-layer quantization here since that's the
  # only way for activation*activation matmuls to be AQT-compatible, since we
  # use static scaling factors for activations.
  shape_utils.assert_shapes_equal(
      attn_weights.shape,
      (batch_size, num_heads, query_sequence_length, key_sequence_length))

  # apply attention bias: masking, dropout, proximity bias, etc.
  if bias is not None:
    attn_weights = attn_weights + bias

  # normalize the attention weights
  norm_dims = tuple(range(attn_weights.ndim - len(axis), attn_weights.ndim))
  attn_weights = softmax(
      attn_weights,
      norm_dims,
      dtype,
      hparams.softmax,
      quant_context=quant_context)

  # apply dropout
  if not deterministic and dropout_rate > 0.0:
    if dropout_rng is None:
      raise ValueError('dropout_rng cannot be None if dropout is requested.')
    keep_prob = jax.lax.tie_in(attn_weights, 1.0 - dropout_rate)
    if broadcast_dropout:
      # dropout is broadcast across the batch+head+non-attention dimensions
      dropout_dims = attn_weights.shape[-(2 * len(axis)):]
      dropout_shape = (tuple([1] * len(batch_dims_t)) + dropout_dims)
      keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
    else:
      keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
    multiplier = (keep.astype(attn_weights.dtype) /
                  jnp.asarray(keep_prob, dtype=dtype))
    attn_weights = attn_weights * multiplier

  if quant_context.collect_acts_stats:
    stats_tag.StatsTag(
        channel_axis=None, name='attn_act_probs', update_stats=train)(
            attn_weights, mask=attn_mask)

  if hparams.attn_act_probs is not None:
    assert hparams.attn_act_probs.bounds == 1.0, (
        'act quantization bounds should be set to the fixed value 1.0 to '
        'match the softmax range.')
  probs_get_bounds_params = get_bounds.GetBounds.Params(
      update_bounds=quant_context.update_bounds,
      update_stats=train,
      paxis_name=paxis_name,
      mask=attn_mask,
      module_name='attn_probs')

  # compute the new values given the attention weights
  wv_contracting_dims = (norm_dims,
                         range(value.ndim - len(axis), value.ndim))
  y = quantized_dynamic_dot_general(
      lhs_act=attn_weights,
      rhs_act=value,
      dot_dimension_numbers=(wv_contracting_dims,
                             (batch_dims_t, batch_dims_t)),
      dot_precision=precision,
      quant_type=hparams.quant_type,
      lhs_act_hparams=hparams.attn_act_probs,
      lhs_get_bounds_params=probs_get_bounds_params,
      rhs_act_hparams=hparams.attn_act_v,
      rhs_get_bounds_params=value_get_bounds_params,
  )
  # NOTE(shivaniagrawal): we do per-layer quantization here since that's the
  # only way for activation*activation matmuls to be AQT-compatible, since we
  # use static scaling factors for activations.
  shape_utils.assert_shapes_equal(
      y.shape, (batch_size, num_heads, query_sequence_length, channel_size))
  # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
  perm_inv = _invert_perm(qk_perm)
  y = y.transpose(perm_inv)
  shape_utils.assert_shapes_equal(
      y.shape, (batch_size, query_sequence_length, num_heads, channel_size))
  return y
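
# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the axis bookkeeping
# above for the common 4D case [bs, seq, heads, channels]. With the default
# axis=(1,), the permutations reduce to familiar transposes. Uses the
# module-level `onp` (numpy) import; the values are worked out by hand.
def _toy_attention_permutations():
  n = 4                      # query/key ndim: [bs, seq, heads, channels]
  axis = (1,)                # the single sequence (attention) axis
  batch_dims = tuple(onp.delete(range(n), axis + (n - 1,)))  # (0, 2)
  qk_perm = batch_dims + axis + (n - 1,)  # (0, 2, 1, 3): [bs, heads, seq, ch]
  v_perm = batch_dims + (n - 1,) + axis   # (0, 2, 3, 1): [bs, heads, ch, seq]
  return qk_perm, v_perm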
def create_updated_stats(cls,
                         stats,
                         samples,
                         *,
                         axis=None,
                         paxis_name=None,
                         alpha=None,
                         mask=None,
                         exclude_zeros=False):
  """Create a new Stats instance that represents the updated statistics.

  Since flax.struct.dataclass objects are frozen, this method creates a new
  instance of Stats with updated stats and returns it.

  Args:
    stats: A Stats dataclass object to be updated.
    samples: An array to update the current statistics with.
    axis: axis to average input samples over, e.g. to calculate stats per
      channel.
    paxis_name: the axis name used to combine batch statistics from multiple
      devices. See `jax.pmap` for a description of axis names.
    alpha: Smoothing parameter to use for moving average. If None, will use
      1/n, where n is the stat count.
    mask: Optional boolean tensor of the same shape as 'samples' specifying
      which values of 'samples' to use as part of the bounds calculation.
      'True' indicates the corresponding value from 'samples' should be used.
      If None, all values are used.
    exclude_zeros: Whether to exclude zeros in samples when computing
      statistics, e.g. when calculating mean absolute values.

  Returns:
    A new Stats instance with updated stats and count.
  """
  if mask is None:
    mask = jnp.full(samples.shape, True)
  shape_utils.assert_shapes_compatible(samples.shape, mask.shape)
  mask = jnp.broadcast_to(mask, samples.shape)
  if exclude_zeros:
    # Where samples are zero, set mask to False. This way they won't be
    # included in statistics.
    mask = mask & (samples != 0)

  def _moving_avg(old_avg, new_val, masked_reduction_fn):
    masked_new_val_reduced = masked_reduction_fn(
        new_val, mask=mask, axis=axis, paxis_name=paxis_name, keepdims=True)
    valid_mask = jnp.isfinite(masked_new_val_reduced)
    # Only update the average where the means are valid, so set deltas
    # corresponding to invalid entries to 0.
    delta = jnp.where(valid_mask, masked_new_val_reduced - old_avg, 0)
    # TODO(lew): This is slightly incorrect, alpha should be proportional to
    # the mask size.
    new_avg = old_avg + alpha * delta
    return new_avg

  new_n = stats.n + 1
  if alpha is None:
    alpha = 1. / new_n

  new_mean = _moving_avg(stats.mean, samples, masked_reduction_fn=masked_mean)
  new_mean_abs = _moving_avg(
      stats.mean_abs, jnp.abs(samples), masked_reduction_fn=masked_mean)
  new_mean_sq = _moving_avg(
      stats.mean_sq, jnp.square(samples), masked_reduction_fn=masked_mean)
  new_mean_batch_minimum = _moving_avg(
      stats.mean_batch_minimum,
      samples,
      masked_reduction_fn=masked_mean_of_min)
  new_mean_batch_maximum = _moving_avg(
      stats.mean_batch_maximum,
      samples,
      masked_reduction_fn=masked_mean_of_max)
  return cls(
      n=new_n,
      mean=new_mean,
      mean_abs=new_mean_abs,
      mean_sq=new_mean_sq,
      mean_batch_minimum=new_mean_batch_minimum,
      mean_batch_maximum=new_mean_batch_maximum)
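
# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the moving-average
# rule used by _moving_avg above. With alpha = 1/n the recurrence reproduces
# the ordinary running mean; with a fixed alpha it becomes an exponential
# moving average. The sample values are made up.
def _toy_moving_average():
  batch_means = [2.0, 4.0, 6.0]
  avg, n = 0.0, 0
  for m in batch_means:
    n += 1
    alpha = 1.0 / n                 # the alpha=None case in create_updated_stats
    avg = avg + alpha * (m - avg)   # new_avg = old_avg + alpha * delta
  return avg                        # 4.0 == mean([2.0, 4.0, 6.0])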