def test_lax_dot_has_integer_inputs_in_dynamic_dot_general(
    self, mock_dot_general, lhs_distribution, rhs_distribution):
  lhs_params = QuantOps.ActHParams(
      input_distribution=lhs_distribution, bounds=2.0, prec=4)
  rhs_params = QuantOps.ActHParams(
      input_distribution=rhs_distribution, bounds=1.5, prec=4)
  lhs_act = self.lhs
  if lhs_distribution == 'positive':
    lhs_act = jnp.abs(lhs_act)
  rhs_act = self.rhs
  if rhs_distribution == 'positive':
    rhs_act = jnp.abs(rhs_act)
  quantization.quantized_dynamic_dot_general(
      lhs_act=lhs_act,
      rhs_act=rhs_act,
      lhs_act_hparams=lhs_params,
      rhs_act_hparams=rhs_params,
      lhs_get_bounds_params=None,
      rhs_get_bounds_params=None,
      dot_dimension_numbers=(((1,), (0,)), ((), ())),
      quant_type=QuantType.aqt)
  lhs_inputs, rhs_inputs = mock_dot_general.call_args[0]
  self.assert_is_integer_in_range(
      lhs_inputs, prec=4, distribution=lhs_distribution)
  self.assert_is_integer_in_range(
      rhs_inputs, prec=4, distribution=rhs_distribution)
def __call__(self, lhs_act, rhs_act, lhs_prec, rhs_prec):
  get_bounds_hyper = get_bounds.GetBounds.Hyper(
      initial_bound=10.0,
      stddev_coeff=0,
      absdev_coeff=0,
      mix_coeff=0,
      granularity=quant_config.QuantGranularity.per_tensor)
  lhs_act_hparams = QuantOps.ActHParams(
      input_distribution='symmetric',
      bounds=get_bounds_hyper,
      prec=lhs_prec,
      half_shift=False)
  rhs_act_hparams = QuantOps.ActHParams(
      input_distribution='symmetric',
      bounds=get_bounds_hyper,
      prec=rhs_prec,
      half_shift=False)
  lhs_get_bounds_params = get_bounds.GetBounds.Params(
      update_stats=False, update_bounds=False, module_name='lhs')
  rhs_get_bounds_params = get_bounds.GetBounds.Params(
      update_stats=False, update_bounds=False, module_name='rhs')
  output = quantization.quantized_dynamic_dot_general(
      lhs_act=lhs_act,
      rhs_act=rhs_act,
      lhs_act_hparams=lhs_act_hparams,
      rhs_act_hparams=rhs_act_hparams,
      dot_dimension_numbers=(((1,), (0,)), ((), ())),
      quant_type=QuantType.aqt,
      lhs_get_bounds_params=lhs_get_bounds_params,
      rhs_get_bounds_params=rhs_get_bounds_params)
  return output
def test_quantized_dynamic_dot_general_should_call_inputs_quantization(
    self,
    mock_act_fq,
    lhs_act_prec,
    rhs_act_prec,
    strategy=QuantType.fake_quant):
  mock_act_fq.side_effect = lambda inputs, hparams, get_bounds_params: inputs
  # pylint: disable=g-long-ternary
  lhs_act_hparams = QuantOps.ActHParams(
      bounds=6.,
      prec=lhs_act_prec,
      input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
      half_shift=False) if lhs_act_prec else None
  rhs_act_hparams = QuantOps.ActHParams(
      bounds=6.,
      prec=rhs_act_prec,
      input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
      half_shift=False) if rhs_act_prec else None
  # pylint: enable=g-long-ternary
  get_bounds_params = GetBounds.Params(
      update_stats=False, update_bounds=False)
  quantization.quantized_dynamic_dot_general(
      lhs_act=self.lhs_act,
      rhs_act=self.rhs_act,
      quant_type=strategy,
      dot_dimension_numbers=self.dimension_numbers,
      lhs_act_hparams=lhs_act_hparams,
      lhs_get_bounds_params=get_bounds_params,
      rhs_act_hparams=rhs_act_hparams,
      rhs_get_bounds_params=get_bounds_params,
  )
  calls = []
  for prec in [lhs_act_prec, rhs_act_prec]:
    if prec is not None:
      act_hparams = QuantOps.ActHParams(
          bounds=6., prec=prec, input_distribution=mock.ANY, half_shift=False)
      calls.append(
          mock.call(
              mock.ANY,
              hparams=act_hparams,
              get_bounds_params=get_bounds_params))
  self.assertLen(calls, mock_act_fq.call_count)
  mock_act_fq.assert_has_calls(calls, any_order=True)
def test_dynamic_quantized_dot_general_raises_with_mixed_dtype(self):
  lhs_params = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=2.0, prec=4, half_shift=False)
  rhs_params = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=1.5, prec=4, half_shift=False)
  lhs_act = self.lhs.astype(jnp.bfloat16)
  rhs_act = self.rhs.astype(jnp.float32)
  with self.assertRaises(TypeError):
    quantization.quantized_dynamic_dot_general(
        lhs_act=lhs_act,
        rhs_act=rhs_act,
        lhs_act_hparams=lhs_params,
        rhs_act_hparams=rhs_params,
        lhs_get_bounds_params=None,
        rhs_get_bounds_params=None,
        dot_dimension_numbers=(((1,), (0,)), ((), ())),
        quant_type=QuantType.aqt)
def quantized_matmul(quant_type):
  return quantization.quantized_dynamic_dot_general(
      lhs_act=self.lhs,
      rhs_act=self.rhs,
      lhs_act_hparams=lhs_params,
      rhs_act_hparams=rhs_params,
      lhs_get_bounds_params=None,
      rhs_get_bounds_params=None,
      dot_dimension_numbers=(((1,), (0,)), ((), ())),
      quant_type=quant_type)
def test_quantized_dynamic_dot_general_no_quant(self):
  # Exercises the no-quantization path: the result should exactly match an
  # unquantized matmul of the inputs.
  act_hparams = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=-1.0, prec=4, half_shift=False)
  lhs_act = jnp.array([[-5.0]])
  rhs_act = jnp.array([[-4.99]])
  res = quantization.quantized_dynamic_dot_general(
      lhs_act=lhs_act,
      rhs_act=rhs_act,
      quant_type=quantization.QuantType.aqt,
      lhs_act_hparams=act_hparams,
      rhs_act_hparams=act_hparams,
      lhs_get_bounds_params=None,
      rhs_get_bounds_params=None,
      dot_dimension_numbers=(((1,), (0,)), ((), ())))
  onp.testing.assert_allclose(res, lhs_act * rhs_act)
def test_dynamic_quantized_dot_general_has_correct_dtype(
    self, input_dtype, act_prec, quant_type):
  lhs_params = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=2.0, prec=act_prec)
  rhs_params = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=1.5, prec=act_prec)
  lhs_act = self.lhs.astype(input_dtype)
  rhs_act = self.rhs.astype(input_dtype)
  output = quantization.quantized_dynamic_dot_general(
      lhs_act=lhs_act,
      rhs_act=rhs_act,
      lhs_act_hparams=lhs_params,
      rhs_act_hparams=rhs_params,
      lhs_get_bounds_params=None,
      rhs_get_bounds_params=None,
      dot_dimension_numbers=(((1,), (0,)), ((), ())),
      quant_type=quant_type)
  self.assertEqual(output.dtype, input_dtype)
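# The tests above exercise AQT-style dynamic quantization of both operands of
# a matmul. As a rough mental model (not the library's implementation), the
# symmetric case behaves like the self-contained sketch below: each operand is
# scaled onto an integer grid derived from its bound and precision, the
# contraction runs on the rounded values, and the output is rescaled. The
# helper names (`_symmetric_scale`, `reference_aqt_matmul`) and the clipping
# details are illustrative assumptions, not part of the quantization module's
# API. This is why the mocked lax.dot_general above sees integer-valued inputs
# and why the output dtype follows the input dtype.

import jax.numpy as jnp
from jax import lax


def _symmetric_scale(bounds, prec):
  # Map [-bounds, bounds] onto the signed grid [-(2^(prec-1) - 1), 2^(prec-1) - 1].
  return (2.0**(prec - 1) - 1) / bounds


def reference_aqt_matmul(lhs, rhs, lhs_bounds, rhs_bounds, prec=4):
  lhs_scale = _symmetric_scale(lhs_bounds, prec)
  rhs_scale = _symmetric_scale(rhs_bounds, prec)
  max_int = 2.0**(prec - 1) - 1
  # Scale, clip, and round each operand so the contraction runs on
  # integer-valued tensors (still stored in the input floating-point dtype).
  lhs_q = jnp.round(jnp.clip(lhs * lhs_scale, -max_int, max_int))
  rhs_q = jnp.round(jnp.clip(rhs * rhs_scale, -max_int, max_int))
  out = lax.dot_general(lhs_q, rhs_q, (((1,), (0,)), ((), ())))
  # Undo both scales so the result approximates the unquantized matmul.
  return out / (lhs_scale * rhs_scale)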
def dot_product_attention(query,
                          key,
                          value,
                          hparams,
                          quant_context,
                          paxis_name,
                          train,
                          key_padding_mask,
                          query_padding_mask,
                          attn_mask,
                          dtype=jnp.float32,
                          bias=None,
                          axis=None,
                          broadcast_dropout=True,
                          dropout_rng=None,
                          dropout_rate=0.,
                          deterministic=False,
                          precision=None):
  """Computes dot-product attention given query, key, and value.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights. This
  function supports multi-dimensional inputs.

  Args:
    query: queries for calculating attention with shape of `[batch_size,
      sequence_length, num_heads, mem_channels]`.
    key: keys for calculating attention with shape of `[batch_size,
      sequence_length, num_heads, mem_channels]`.
    value: values to be used in attention with shape of `[batch_size,
      sequence_length, num_heads, value_channels]`.
    hparams: hyperparameters used for quantization.
    quant_context: context for quantization.
    paxis_name: axis_name to which a user `pmaps` the parent module (model);
      refer to jax.pmap() for more documentation. This arg is used for
      get_bounds activation quantization (QuantOps.create_input_fake_quant).
    train: whether the model is training.
    key_padding_mask: boolean mask indicating which elements in 'key' and
      'value' are padding. Must have a shape compatible with 'key' and
      'value'.
    query_padding_mask: boolean mask indicating which elements in `query` are
      padding (True means not padding).
    attn_mask: boolean mask indicating which elements of the calculated
      attention weight matrix should be used for collecting activation
      statistics. Should have a shape broadcast-compatible with
      '[bs, sequence_length, sequence_length]'. Must have a shape
      broadcast-compatible with 'query'.
    dtype: the dtype of the computation (default: float32).
    bias: bias for the attention weights. This can be used for incorporating
      an autoregressive mask, padding mask, or proximity bias.
    axis: axes over which the attention is applied.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey to be used for dropout.
    dropout_rate: dropout rate.
    deterministic: bool, deterministic or not (to apply dropout).
    precision: numerical precision of the computation; see `jax.lax.Precision`
      for details.

  Returns:
    Output of shape `[bs, sequence_length, num_heads, value_channels]`.
  """
  batch_size, query_sequence_length, num_heads, channel_size = query.shape
  key_sequence_length = key.shape[1]
  shape_utils.assert_shapes_equal(
      key.shape, (batch_size, key_sequence_length, num_heads, channel_size))
  shape_utils.assert_shapes_equal(
      value.shape, (batch_size, key_sequence_length, num_heads, channel_size))
  if key_padding_mask is not None:
    shape_utils.assert_shapes_equal(key_padding_mask.shape,
                                    (batch_size, key_sequence_length, 1, 1))
  if query_padding_mask is not None:
    shape_utils.assert_shapes_equal(query_padding_mask.shape,
                                    (batch_size, query_sequence_length, 1, 1))
  if attn_mask is not None:
    shape_utils.assert_shapes_compatible(
        attn_mask.shape,
        (batch_size, 1, query_sequence_length, key_sequence_length))

  if axis is None:
    axis = tuple(range(1, key.ndim - 2))
  if not isinstance(axis, Iterable):
    axis = (axis,)
  for ax in axis:
    if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
      raise ValueError('Attention axis must be between the batch '
                       'axis and the last-two axes.')
  depth = query.shape[-1]
  n = key.ndim
  # batch_dims is <bs, <non-attention dims>, num_heads>
  batch_dims = tuple(onp.delete(range(n), axis + (n - 1,)))
  # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
  qk_perm = batch_dims + axis + (n - 1,)
  key = key.transpose(qk_perm)
  shape_utils.assert_shapes_equal(
      key.shape, (batch_size, num_heads, key_sequence_length, channel_size))

  key_padding_mask_transposed = None
  query_padding_mask_transposed = None
  if key_padding_mask is not None:
    key_padding_mask_transposed = key_padding_mask.transpose(qk_perm)
    shape_utils.assert_shapes_equal(key_padding_mask_transposed.shape,
                                    (batch_size, 1, key_sequence_length, 1))

  if quant_context.collect_acts_stats:
    stats_tag.StatsTag(
        channel_axis=None, name='attn_act_k', update_stats=train)(
            key, mask=key_padding_mask_transposed)

  if query_padding_mask is not None:
    query_padding_mask_transposed = query_padding_mask.transpose(qk_perm)
    shape_utils.assert_shapes_equal(query_padding_mask_transposed.shape,
                                    (batch_size, 1, query_sequence_length, 1))

  key_get_bounds_params = get_bounds.GetBounds.Params(
      update_bounds=quant_context.update_bounds,
      update_stats=train,
      paxis_name=paxis_name,
      mask=key_padding_mask_transposed,
      module_name='K')

  # v -> (bs, <non-attention dims>, num_heads, channels, <attention dims>)
  v_perm = batch_dims + (n - 1,) + axis
  value = value.transpose(v_perm)
  shape_utils.assert_shapes_equal(
      value.shape, (batch_size, num_heads, channel_size, key_sequence_length))

  value_padding_mask_transposed = None
  if key_padding_mask is not None:
    value_padding_mask_transposed = key_padding_mask.transpose(v_perm)
    shape_utils.assert_shapes_equal(value_padding_mask_transposed.shape,
                                    (batch_size, 1, 1, key_sequence_length))

  if quant_context.collect_acts_stats:
    stats_tag.StatsTag(
        channel_axis=None, name='attn_act_v', update_stats=train)(
            value, mask=value_padding_mask_transposed)

  value_get_bounds_params = get_bounds.GetBounds.Params(
      update_bounds=quant_context.update_bounds,
      update_stats=train,
      paxis_name=paxis_name,
      mask=value_padding_mask_transposed,
      module_name='V')

  query = query / jnp.sqrt(depth).astype(dtype)
  query = query.transpose(qk_perm)
  shape_utils.assert_shapes_equal(
      query.shape,
      (batch_size, num_heads, query_sequence_length, channel_size))

  if quant_context.collect_acts_stats:
    stats_tag.StatsTag(
        channel_axis=None, name='attn_act_q', update_stats=train)(
            query, mask=query_padding_mask_transposed)

  query_get_bounds_params = get_bounds.GetBounds.Params(
      update_bounds=quant_context.update_bounds,
      update_stats=train,
      paxis_name=paxis_name,
      mask=query_padding_mask_transposed,
      module_name='Q')

  batch_dims_t = tuple(range(len(batch_dims)))
  attn_weights = quantized_dynamic_dot_general(
      lhs_act=query,
      rhs_act=key,
      dot_dimension_numbers=(((n - 1,), (n - 1,)),
                             (batch_dims_t, batch_dims_t)),
      dot_precision=precision,
      quant_type=hparams.quant_type,
      lhs_act_hparams=hparams.attn_act_q,
      lhs_get_bounds_params=query_get_bounds_params,
      rhs_act_hparams=hparams.attn_act_k,
      rhs_get_bounds_params=key_get_bounds_params,
  )
  # NOTE(shivaniagrawal): we do per-layer quantization here since that's the
  # only way for activation*activation matmuls to be aqt compatible, since we
  # use static scaling factors for activations.

  shape_utils.assert_shapes_equal(
      attn_weights.shape,
      (batch_size, num_heads, query_sequence_length, key_sequence_length))

  # apply attention bias: masking, dropout, proximity bias, etc.
  if bias is not None:
    attn_weights = attn_weights + bias

  # normalize the attention weights
  norm_dims = tuple(range(attn_weights.ndim - len(axis), attn_weights.ndim))
  attn_weights = softmax(
      attn_weights,
      norm_dims,
      dtype,
      hparams.softmax,
      quant_context=quant_context)

  # apply dropout
  if not deterministic and dropout_rate > 0.0:
    if dropout_rng is None:
      raise ValueError('dropout_rng cannot be None if dropout is requested.')
    keep_prob = jax.lax.tie_in(attn_weights, 1.0 - dropout_rate)
    if broadcast_dropout:
      # dropout is broadcast across the batch+head+non-attention dimensions
      dropout_dims = attn_weights.shape[-(2 * len(axis)):]
      dropout_shape = (tuple([1] * len(batch_dims_t)) + dropout_dims)
      keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
    else:
      keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
    multiplier = (
        keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype))
    attn_weights = attn_weights * multiplier

  if quant_context.collect_acts_stats:
    stats_tag.StatsTag(
        channel_axis=None, name='attn_act_probs', update_stats=train)(
            attn_weights, mask=attn_mask)

  if hparams.attn_act_probs is not None:
    assert hparams.attn_act_probs.bounds == 1.0, (
        'act quantization bounds should be set to the fixed value 1.0 to '
        'match the softmax range.')
  probs_get_bounds_params = get_bounds.GetBounds.Params(
      update_bounds=quant_context.update_bounds,
      update_stats=train,
      paxis_name=paxis_name,
      mask=attn_mask,
      module_name='attn_probs')

  # compute the new values given the attention weights
  wv_contracting_dims = (norm_dims,
                         range(value.ndim - len(axis), value.ndim))
  y = quantized_dynamic_dot_general(
      lhs_act=attn_weights,
      rhs_act=value,
      dot_dimension_numbers=(wv_contracting_dims,
                             (batch_dims_t, batch_dims_t)),
      dot_precision=precision,
      quant_type=hparams.quant_type,
      lhs_act_hparams=hparams.attn_act_probs,
      lhs_get_bounds_params=probs_get_bounds_params,
      rhs_act_hparams=hparams.attn_act_v,
      rhs_get_bounds_params=value_get_bounds_params,
  )
  # NOTE(shivaniagrawal): we do per-layer quantization here since that's the
  # only way for activation*activation matmuls to be aqt compatible, since we
  # use static scaling factors for activations.

  shape_utils.assert_shapes_equal(
      y.shape, (batch_size, num_heads, query_sequence_length, channel_size))
  # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
  perm_inv = _invert_perm(qk_perm)
  y = y.transpose(perm_inv)
  shape_utils.assert_shapes_equal(
      y.shape, (batch_size, query_sequence_length, num_heads, channel_size))
  return y
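# The sketch below restates the tensor flow of dot_product_attention above for
# the common 4D case (axis=(1,)), without quantization: plain lax.dot_general
# and jax.nn.softmax stand in for the quantized ops and the module's softmax,
# and padding masks, bounds collection, bias, and dropout are omitted. The name
# `reference_attention` is illustrative only; it exists to make the permutations
# and dot_general dimension numbers used above easier to follow.

import jax
import jax.numpy as jnp
from jax import lax


def reference_attention(query, key, value):
  # query, key, value: [batch, seq, heads, channels]
  batch, q_len, heads, channels = query.shape
  qk_perm = (0, 2, 1, 3)  # q, k -> [batch, heads, seq, channels]
  v_perm = (0, 2, 3, 1)   # v    -> [batch, heads, channels, seq]
  q = (query / jnp.sqrt(channels)).transpose(qk_perm)
  k = key.transpose(qk_perm)
  v = value.transpose(v_perm)
  # Contract the channel axis of q and k, batching over (batch, heads):
  # weights has shape [batch, heads, q_len, k_len].
  weights = lax.dot_general(q, k, (((3,), (3,)), ((0, 1), (0, 1))))
  weights = jax.nn.softmax(weights, axis=-1)
  # Contract k_len of the weights with the last axis of v:
  # y has shape [batch, heads, q_len, channels].
  y = lax.dot_general(weights, v, (((3,), (3,)), ((0, 1), (0, 1))))
  # Back to [batch, seq, heads, channels].
  return y.transpose(0, 2, 1, 3)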