def __call__(
    self,
    inputs,
    *,
    padding_mask,
):
    """Projects `inputs` through a dense kernel, with optional quantization.

    If `hparams.weight_prec` is not None, the weights are scaled and
    quantized to signed integers with `weight_prec` bits.

    Args:
      inputs: The nd-array to be transformed. Its last dimension is
        contracted against the first dimension of the kernel.
      padding_mask: boolean tensor specifying which values of `inputs` take
        part in the bounds calculation; `True` means the corresponding value
        is used. When given, its shape must equal `(inputs.shape[0], 1)`.
        If None, all values are used.

    Returns:
      The transformed input.

    Raises:
      NotImplementedError: if `hparams.weight_prec` is an int larger than 8.
      ValueError: if `hparams.weight_quant_granularity` is neither
        `per_channel` nor `per_tensor`.
    """
    leading_dim = inputs.shape[0]
    if padding_mask is not None:
        shape_utils.assert_shapes_equal(padding_mask.shape, (leading_dim, 1))

    # TODO(wanglisa): Replace fake quant with AQT.
    quant_context = self.quant_context
    if quant_context.collect_acts_stats:
        tag_inputs = stats_tag.StatsTag(
            channel_axis=-1, name='inputs', update_stats=self.train)
        tag_inputs(inputs, mask=padding_mask)

    hparams = self.hparams
    weight_prec = hparams.weight_prec
    # `isinstance(..., int)` is already False for None, so no separate
    # None check is needed.
    if isinstance(weight_prec, int) and weight_prec > 8:
        raise NotImplementedError(
            'If you want to use more than 8bits for quantization, please revisit '
            'jax.lax.Precision.DEFAULT to determine whether it is still sufficient.'
        )

    # The kernel parameter is requested before the bias so that parameter
    # creation order stays the same.
    in_features = inputs.shape[-1]
    kernel = self.param('kernel', self.kernel_init,
                        (in_features, self.features))
    inputs = jnp.asarray(inputs, self.dtype)
    kernel = jnp.asarray(kernel, self.dtype)

    bounds_params = get_bounds.GetBounds.Params(
        update_bounds=quant_context.update_bounds,
        update_stats=self.train,
        paxis_name=self.paxis_name,
        mask=padding_mask)

    # kernel.shape = (channels_in, channels_out)
    granularity = hparams.weight_quant_granularity
    if granularity == quant_config.QuantGranularity.per_channel:
        # Reduce over the rows of the weight matrix: one scale factor per
        # column, i.e. one per output channel.
        scale_shape, reduce_axis = (1, self.features), (0,)
    elif granularity == quant_config.QuantGranularity.per_tensor:
        # A single scale factor for the entire weight matrix.
        scale_shape, reduce_axis = (1, 1), None
    else:
        raise ValueError(f'Invalid quantization granularity {granularity}.')

    weight_params = QuantOps.WeightParams(
        prec=weight_prec,
        half_shift=hparams.weight_half_shift,
        axis=reduce_axis,
        expected_scale_shape=scale_shape)

    # TODO(wanglisa): add option to control when scale is being recomputed

    # matmul
    # dimension_numbers is
    # `((lhs_contracting_dims, rhs_contracting_dims),
    #   (lhs_batch_dims, rhs_batch_dims))`
    dimension_numbers = (((inputs.ndim - 1,), (0,)), ((), ()))
    output = quantization.quantized_dot_general(
        act=inputs,
        w=kernel,
        quant_type=hparams.quant_type,
        weight_params=weight_params,
        act_hparams=hparams.quant_act,
        get_bounds_params=bounds_params,
        dimension_numbers=dimension_numbers,
        dot_precision=self.precision,
        prefer_int8_to_int32_dot=quant_context.prefer_int8_to_int32_dot)

    if not self.use_bias:
        return output

    # bias
    bias = self.param('bias', self.bias_init, (self.features,))
    # (batch_size, features)
    return output + bias[jnp.newaxis, :]
def dot_product_attention(query,
                          key,
                          value,
                          hparams,
                          quant_context,
                          paxis_name,
                          train,
                          key_padding_mask,
                          query_padding_mask,
                          attn_mask,
                          dtype=jnp.float32,
                          bias=None,
                          axis=None,
                          broadcast_dropout=True,
                          dropout_rng=None,
                          dropout_rate=0.,
                          deterministic=False,
                          precision=None):
    """Computes dot-product attention given query, key, and value.

    This is the core function for applying attention based on
    https://arxiv.org/abs/1706.03762. It calculates the attention weights
    given query and key and combines the values using the attention weights.
    Both matmuls go through `quantized_dynamic_dot_general`, so activations
    may be quantized according to `hparams`.

    The shape checks in this function require 4-D inputs.

    Args:
      query: queries for calculating attention with shape of `[batch_size,
        query_sequence_length, num_heads, mem_channels]`.
      key: keys for calculating attention with shape of `[batch_size,
        key_sequence_length, num_heads, mem_channels]`.
      value: values to be used in attention. The shape checks require the
        same shape as `key`.
      hparams: hyperparameters used for quantization. This function reads
        `quant_type`, `attn_act_q`, `attn_act_k`, `attn_act_v`,
        `attn_act_probs` and `softmax`.
      quant_context: context for quantization. This function reads
        `collect_acts_stats` and `update_bounds`.
      paxis_name: axis_name to which a user `pmaps` the parent module (model),
        refer to jax.pmap() for more documentation. This arg is used for
        get_bounds acts quantization (QuantOps.create_input_fake_quant)
      train: Whether model is training.
      key_padding_mask: boolean mask indicating which elements in 'key' and
        'value' are padding, or None. When given, its shape must be
        `[batch_size, key_sequence_length, 1, 1]`.
      query_padding_mask: boolean mask indicating which elements in `query`
        are padding (True means not padding), or None. When given, its shape
        must be `[batch_size, query_sequence_length, 1, 1]`.
      attn_mask: boolean mask indicating which elements of the calculated
        attention weight matrix should be used for collecting activation
        statistics, or None. When given, its shape must be compatible with
        `[batch_size, 1, query_sequence_length, key_sequence_length]`.
      dtype: the dtype of the computation (default: float32)
      bias: bias for the attention weights. This can be used for incorporating
        autoregressive mask, padding mask, proximity bias.
      axis: axes over which the attention is applied. May be a single int or
        any iterable of ints; defaults to all axes between the batch axis and
        the last two axes.
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rng: JAX PRNGKey: to be used for dropout.
      dropout_rate: dropout rate
      deterministic: bool, deterministic or not (to apply dropout)
      precision: numerical precision of the computation see
        `jax.lax.Precision` for details.

    Returns:
      Output of shape `[batch_size, query_sequence_length, num_heads,
      mem_channels]`.

    Raises:
      ValueError: if an attention axis is not between the batch axis and the
        last two axes, or if dropout is requested without `dropout_rng`.
      AssertionError: if `hparams.attn_act_probs` is set and its `bounds` is
        not 1.0.
    """
    batch_size, query_sequence_length, num_heads, channel_size = query.shape
    key_sequence_length = key.shape[1]
    shape_utils.assert_shapes_equal(
        key.shape, (batch_size, key_sequence_length, num_heads, channel_size))
    shape_utils.assert_shapes_equal(
        value.shape,
        (batch_size, key_sequence_length, num_heads, channel_size))
    if key_padding_mask is not None:
        shape_utils.assert_shapes_equal(
            key_padding_mask.shape, (batch_size, key_sequence_length, 1, 1))
    if query_padding_mask is not None:
        shape_utils.assert_shapes_equal(
            query_padding_mask.shape,
            (batch_size, query_sequence_length, 1, 1))
    if attn_mask is not None:
        shape_utils.assert_shapes_compatible(
            attn_mask.shape,
            (batch_size, 1, query_sequence_length, key_sequence_length))

    if axis is None:
        axis = tuple(range(1, key.ndim - 2))
    if not isinstance(axis, Iterable):
        axis = (axis,)
    # `axis` is concatenated with tuples below, so lists and other
    # non-tuple iterables are coerced to a tuple here.
    axis = tuple(axis)
    for ax in axis:
        if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
            raise ValueError('Attention axis must be between the batch '
                             'axis and the last-two axes.')
    depth = query.shape[-1]
    n = key.ndim
    # batch_dims is <bs, <non-attention dims>, num_heads>
    batch_dims = tuple(onp.delete(range(n), axis + (n - 1,)))
    # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
    qk_perm = batch_dims + axis + (n - 1,)
    key = key.transpose(qk_perm)
    shape_utils.assert_shapes_equal(
        key.shape, (batch_size, num_heads, key_sequence_length, channel_size))

    key_padding_mask_transposed = None
    query_padding_mask_transposed = None
    if key_padding_mask is not None:
        key_padding_mask_transposed = key_padding_mask.transpose(qk_perm)
        shape_utils.assert_shapes_equal(
            key_padding_mask_transposed.shape,
            (batch_size, 1, key_sequence_length, 1))
    if quant_context.collect_acts_stats:
        stats_tag.StatsTag(
            channel_axis=None, name='attn_act_k', update_stats=train)(
                key, mask=key_padding_mask_transposed)
    if query_padding_mask is not None:
        query_padding_mask_transposed = query_padding_mask.transpose(qk_perm)
        shape_utils.assert_shapes_equal(
            query_padding_mask_transposed.shape,
            (batch_size, 1, query_sequence_length, 1))

    key_get_bounds_params = get_bounds.GetBounds.Params(
        update_bounds=quant_context.update_bounds,
        update_stats=train,
        paxis_name=paxis_name,
        mask=key_padding_mask_transposed,
        module_name='K')

    # v -> (bs, <non-attention dims>, num_heads, channels, <attention dims>)
    v_perm = batch_dims + (n - 1,) + axis
    value = value.transpose(v_perm)
    shape_utils.assert_shapes_equal(
        value.shape,
        (batch_size, num_heads, channel_size, key_sequence_length))
    value_padding_mask_transposed = None
    if key_padding_mask is not None:
        # The key padding mask also masks the values; it is permuted the
        # same way as `value`.
        value_padding_mask_transposed = key_padding_mask.transpose(v_perm)
        shape_utils.assert_shapes_equal(
            value_padding_mask_transposed.shape,
            (batch_size, 1, 1, key_sequence_length))

    if quant_context.collect_acts_stats:
        stats_tag.StatsTag(
            channel_axis=None, name='attn_act_v', update_stats=train)(
                value, mask=value_padding_mask_transposed)

    value_get_bounds_params = get_bounds.GetBounds.Params(
        update_bounds=quant_context.update_bounds,
        update_stats=train,
        paxis_name=paxis_name,
        mask=value_padding_mask_transposed,
        module_name='V')

    # Scale the query by 1/sqrt(channels) before the transpose and matmul.
    query = query / jnp.sqrt(depth).astype(dtype)
    query = query.transpose(qk_perm)
    shape_utils.assert_shapes_equal(
        query.shape,
        (batch_size, num_heads, query_sequence_length, channel_size))

    if quant_context.collect_acts_stats:
        stats_tag.StatsTag(
            channel_axis=None, name='attn_act_q', update_stats=train)(
                query, mask=query_padding_mask_transposed)

    query_get_bounds_params = get_bounds.GetBounds.Params(
        update_bounds=quant_context.update_bounds,
        update_stats=train,
        paxis_name=paxis_name,
        mask=query_padding_mask_transposed,
        module_name='Q')

    batch_dims_t = tuple(range(len(batch_dims)))
    attn_weights = quantized_dynamic_dot_general(
        lhs_act=query,
        rhs_act=key,
        dot_dimension_numbers=(((n - 1,), (n - 1,)),
                               (batch_dims_t, batch_dims_t)),
        dot_precision=precision,
        quant_type=hparams.quant_type,
        lhs_act_hparams=hparams.attn_act_q,
        lhs_get_bounds_params=query_get_bounds_params,
        rhs_act_hparams=hparams.attn_act_k,
        rhs_get_bounds_params=key_get_bounds_params,
    )
    # NOTE(shivaniagrawal): we do per-layer quantization here since that's the
    # only way for activation*activation matmuls to be aqt compatible since we
    # use static scaling factors for activations.

    shape_utils.assert_shapes_equal(
        attn_weights.shape,
        (batch_size, num_heads, query_sequence_length, key_sequence_length))

    # apply attention bias: masking, dropout, proximity bias, etc.
    if bias is not None:
        attn_weights = attn_weights + bias

    # normalize the attention weights
    norm_dims = tuple(range(attn_weights.ndim - len(axis), attn_weights.ndim))
    attn_weights = softmax(
        attn_weights,
        norm_dims,
        dtype,
        hparams.softmax,
        quant_context=quant_context)

    # apply dropout
    if not deterministic and dropout_rate > 0.0:
        if dropout_rng is None:
            raise ValueError(
                'dropout_rng cannot be None if dropout is requested.')
        # `jax.lax.tie_in` is a deprecated no-op since JAX v0.2.0, so the
        # keep probability is used directly.
        keep_prob = 1.0 - dropout_rate
        if broadcast_dropout:
            # dropout is broadcast across the batch+head+non-attention
            # dimension
            dropout_dims = attn_weights.shape[-(2 * len(axis)):]
            dropout_shape = (tuple([1] * len(batch_dims_t)) + dropout_dims)
            keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
        else:
            keep = random.bernoulli(dropout_rng, keep_prob,
                                    attn_weights.shape)
        # Rescale kept entries by 1/keep_prob.
        multiplier = (keep.astype(attn_weights.dtype) /
                      jnp.asarray(keep_prob, dtype=dtype))
        attn_weights = attn_weights * multiplier

    if quant_context.collect_acts_stats:
        stats_tag.StatsTag(
            channel_axis=None, name='attn_act_probs', update_stats=train)(
                attn_weights, mask=attn_mask)

    if hparams.attn_act_probs is not None:
        # Kept as an `assert` so the exception type seen by callers does not
        # change.
        assert hparams.attn_act_probs.bounds == 1.0, (
            'act quantization bounds should '
            'be set to fix value 1.0 to '
            'match Softmax range.')
    probs_get_bounds_params = get_bounds.GetBounds.Params(
        update_bounds=quant_context.update_bounds,
        update_stats=train,
        paxis_name=paxis_name,
        mask=attn_mask,
        module_name='attn_probs')

    # compute the new values given the attention weights
    # Both contracting-dim entries are concrete tuples rather than a lazy
    # `range`, so the dimension numbers are plain nested tuples.
    wv_contracting_dims = (
        norm_dims, tuple(range(value.ndim - len(axis), value.ndim)))
    y = quantized_dynamic_dot_general(
        lhs_act=attn_weights,
        rhs_act=value,
        dot_dimension_numbers=(wv_contracting_dims,
                               (batch_dims_t, batch_dims_t)),
        dot_precision=precision,
        quant_type=hparams.quant_type,
        lhs_act_hparams=hparams.attn_act_probs,
        lhs_get_bounds_params=probs_get_bounds_params,
        rhs_act_hparams=hparams.attn_act_v,
        rhs_get_bounds_params=value_get_bounds_params,
    )
    # NOTE(shivaniagrawal): we do per-layer quantization here since that's the
    # only way for activation*activation matmuls to be aqt compatible since we
    # use static scaling factors for activations.
    shape_utils.assert_shapes_equal(
        y.shape,
        (batch_size, num_heads, query_sequence_length, channel_size))

    # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
    perm_inv = _invert_perm(qk_perm)
    y = y.transpose(perm_inv)
    shape_utils.assert_shapes_equal(
        y.shape,
        (batch_size, query_sequence_length, num_heads, channel_size))
    return y