def test_attributes_create_weights_ops(self, weight_range, weight_shape,
                                       prec):
  weights = jnp.array(
      fp32(
          onp.random.uniform(
              weight_range[0], weight_range[1], size=weight_shape)))
  axis = 0 if weight_shape[1] != 1 else None
  weights_quant = QuantOps.create_weights_ops(
      w=weights,
      weight_params=QuantOps.WeightParams(
          prec=prec, axis=axis, half_shift=False))
  max_weight = onp.max(abs(weights), axis=0)
  onp.testing.assert_array_equal(
      jnp.squeeze(weights_quant._scale), (2**(prec - 1) - 1) / max_weight)
  self.assertEqual(weights_quant._symmetric, True)
  self.assertEqual(weights_quant._prec, prec)
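
# --- Illustrative sketch (not part of the test suite) ---
# A minimal, hedged illustration of the scale the test above expects from
# QuantOps.create_weights_ops for symmetric integer quantization:
# scale = (2**(prec - 1) - 1) / max|w|, computed per column when axis=0, so
# the largest-magnitude weight in each column maps to the largest
# representable signed integer. 'symmetric_int_scale' is a hypothetical
# standalone helper, not a library function.
import numpy as onp


def symmetric_int_scale(w, prec, axis=0):
  """Per-axis scale mapping max|w| to the max signed value 2**(prec-1) - 1."""
  max_weight = onp.max(onp.abs(w), axis=axis)
  return (2**(prec - 1) - 1) / max_weight


w = onp.array([[0.5, -2.0], [1.0, 4.0]])
# For prec=8, the per-column scales are 127 / 1.0 and 127 / 4.0.
onp.testing.assert_allclose(
    symmetric_int_scale(w, prec=8), onp.array([127.0, 31.75]))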
def test_full_range_int_weight_quantization(self, prec):
  # Integer weights in the full range [-max_signed_int, max_signed_int]
  # quantize correctly.
  minval = -2**(prec - 1) + 1
  maxval = 2**(prec - 1) - 1
  weights = random.randint(random.PRNGKey(0), (10, 1), minval, maxval + 1)
  weights = weights.at[0, :].set(maxval)
  weight_quant = QuantOps.create_weights_ops(
      w=weights,
      weight_params=QuantOps.WeightParams(
          prec=prec, axis=None, half_shift=False))
  quantized_weights = weight_quant.to_quantized(weights, dtype=SCALE_DTYPE)
  onp.testing.assert_array_equal(quantized_weights[0],
                                 (2**(prec - 1.0) - 1.0))
  rescaled_weights = weight_quant.from_quantized(
      quantized_weights, dtype=jnp.float32)
  onp.testing.assert_array_equal(weights, rescaled_weights)
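
# --- Illustrative sketch (not part of the test suite) ---
# A hedged sketch of the round-trip property the test above checks: integers
# already in [-(2**(prec - 1) - 1), 2**(prec - 1) - 1] survive a symmetric
# quantize/dequantize exactly, because the scale places them directly on the
# integer grid. Plain NumPy, independent of QuantOps; 'quant_dequant' is a
# hypothetical helper for illustration only.
import numpy as onp


def quant_dequant(w, prec):
  scale = (2**(prec - 1) - 1) / onp.max(onp.abs(w))
  quantized = onp.round(w * scale)  # maps onto the signed-int grid
  return quantized / scale          # rescales back to the original values


w = onp.array([-3.0, -1.0, 0.0, 2.0, 3.0])  # full range for prec=3
onp.testing.assert_array_equal(quant_dequant(w, prec=3), w)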
def test_attributes_create_weights_op_fp(
    self,
    weight_range,
    weight_shape,
    fp_quant,
):
  weights = jnp.array(
      fp32(onp.random.uniform(*weight_range, size=weight_shape)))
  axis = None if weight_shape[1] == 1 else 0
  weights_quant_op = QuantOps.create_weights_ops(
      w=weights,
      weight_params=QuantOps.WeightParams(
          prec=fp_quant, axis=axis, half_shift=False))
  max_weight = onp.max(abs(weights), axis=0)
  onp.testing.assert_array_equal(
      jnp.squeeze(weights_quant_op._scale),
      jnp.exp2(-jnp.floor(jnp.log2(max_weight))))
  self.assertEqual(weights_quant_op._symmetric, True)
  self.assertIs(weights_quant_op._prec, fp_quant)
  weights_scaled = (weights * weights_quant_op._scale).astype(weights.dtype)
  weights_quant_expected = fp_cast.downcast_sat_ftz(
      weights_scaled,
      fp_quant.fp_spec.exp_min,
      fp_quant.fp_spec.exp_max,
      fp_quant.fp_spec.sig_bits,
  )
  weights_quant_calculated = weights_quant_op.to_quantized(
      weights, dtype=SCALE_DTYPE)
  onp.testing.assert_array_equal(weights_quant_expected,
                                 weights_quant_calculated)
  # Test that the lower (23 - fp_quant.fp_spec.sig_bits) bits of the
  # calculated quantized weights are zero.
  sig_mask = jnp.int32((1 << (23 - fp_quant.fp_spec.sig_bits)) - 1)
  onp.testing.assert_array_equal(
      weights_quant_calculated.view(jnp.int32) & sig_mask,
      jnp.zeros_like(weights))
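
# --- Illustrative sketch (not part of the test suite) ---
# A hedged sketch of the two facts the floating-point test above relies on.
# First, the power-of-two scale 2**(-floor(log2(max|w|))) normalizes the
# largest-magnitude weight into [1, 2). Second, a float32 value whose
# significand has been truncated to sig_bits bits has its low
# (23 - sig_bits) bits equal to zero, which is what the sig_mask check
# verifies. 'truncate_significand' is a hypothetical helper showing only the
# truncation; the library's fp_cast.downcast_sat_ftz additionally saturates
# out-of-range exponents and flushes subnormals to zero.
import jax.numpy as jnp


def po2_scale(w):
  """Power-of-two scale that normalizes max|w| into [1, 2)."""
  return jnp.exp2(-jnp.floor(jnp.log2(jnp.max(jnp.abs(w)))))


def truncate_significand(x, sig_bits):
  """Truncates a float32 significand to sig_bits bits (round toward zero)."""
  keep_mask = jnp.int32(-1) << (23 - sig_bits)
  return (x.view(jnp.int32) & keep_mask).view(jnp.float32)


w = jnp.array([0.3, -1.4, 0.9], dtype=jnp.float32)
truncated = truncate_significand(w * po2_scale(w), sig_bits=4)
low_bits = truncated.view(jnp.int32) & jnp.int32((1 << (23 - 4)) - 1)
assert (low_bits == 0).all()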
def __call__(
    self,
    inputs,
):
  """Embeds the inputs along the last dimension.

  Args:
    inputs: input data, all dimensions are considered batch dimensions.

  Returns:
    Output which is embedded input data. The output shape follows the input,
    with an additional `features` dimension appended.
  """
  batch_size, sequence_length = inputs.shape
  if inputs.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]:
    raise ValueError('Input type must be an integer or unsigned integer.')
  embedding = self.embedding
  embedding = jnp.asarray(embedding, self.dtype)
  hparams = self.hparams
  # Initialize state for stats and bounds; this is required for the logits
  # computed in the following method, `attend`.
  if hparams.quant_act is not None and isinstance(
      hparams.quant_act.bounds, get_bounds.GetBounds.Hyper):
    self.get_bounds_logits(
        inputs,
        bounds_params=get_bounds.GetBounds.Params(
            update_stats=False, update_bounds=False, paxis_name=None),
    )
  weight_prec = hparams.weight_prec
  weight_half_shift = hparams.weight_half_shift
  if weight_prec is not None:
    quantized_type = hparams.quant_type.to_jax_type()
    # In contrast to all other scale-factor calculations in this module, we
    # compute per-row instead of per-column (i.e., per-output-channel) scale
    # factors here. This is because the embedding matrix might be shared
    # with the output (logit) layer of the transformer, in which case the
    # *transpose* of the embedding matrix will be used as the weight matrix
    # in a matmul. The per-row scale factors used here would thus correspond
    # to the per-column (because of the transpose) scale factors used by the
    # weight matrix in the logits layer, which is what we need for AQT.
    embedding_quant_ops = QuantOps.create_weights_ops(
        embedding,
        weight_params=QuantOps.WeightParams(
            prec=weight_prec, axis=(1,), half_shift=weight_half_shift))
    embedding_quant_ops.assert_scale_shape_is(
        shape=(self.num_embeddings, 1))
    quantized_embedding = embedding_quant_ops.to_quantized(
        embedding, dtype=quantized_type)
    quantized_embedded_inputs = quantized_embedding[inputs]
    # Since the embedding matrix 'quantized_embedding' is gathered based on
    # 'inputs' to produce the embedding tensor, we apply the same gathering
    # to the per-row scale factors of the embedding matrix so the scale
    # factors will broadcast appropriately in the subsequent rescaling.
    # TODO(malmaud): As part of the quantization.py refactor, change
    # 'get_scale_for_aqt' to cleanly support this and hence avoid the need
    # to directly access a protected member of QuantOps.
    scale = embedding_quant_ops._scale[inputs]  # pylint: disable=protected-access
    shape_utils.assert_shapes_equal(scale.shape,
                                    (batch_size, sequence_length, 1))
    shape_utils.assert_shapes_equal(
        quantized_embedded_inputs.shape,
        (batch_size, sequence_length, self.features))
    embedded_inputs = (quantized_embedded_inputs / scale).astype(self.dtype)
  else:
    embedded_inputs = embedding[inputs]
  shape_utils.assert_shapes_equal(
      embedded_inputs.shape, (batch_size, sequence_length, self.features))
  return embedded_inputs
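
# --- Illustrative sketch (not part of the module) ---
# A hedged, self-contained sketch of the gather-then-rescale pattern used in
# __call__ above: quantize the embedding table with one scale per row,
# gather rows by token id, then divide by the gathered per-row scales.
# Because each row carries its own scale, dividing after the gather yields
# the same result as dequantizing the whole table first. Plain jax.numpy,
# independent of the flax module and QuantOps; 'quantized_embed' is a
# hypothetical helper for illustration only (it omits details such as
# half_shift and the chosen rounding mode).
import jax.numpy as jnp


def quantized_embed(embedding, inputs, prec):
  # Per-row symmetric scale, shape (num_embeddings, 1).
  scale = (2**(prec - 1) - 1) / jnp.max(
      jnp.abs(embedding), axis=1, keepdims=True)
  quantized = jnp.round(embedding * scale)
  gathered = quantized[inputs]  # (batch, seq, features)
  row_scale = scale[inputs]     # (batch, seq, 1), broadcasts in the division
  return gathered / row_scale


table = jnp.arange(12.0).reshape(4, 3) + 1.0
ids = jnp.array([[0, 2], [3, 1]])
out = quantized_embed(table, ids, prec=8)
assert out.shape == (2, 2, 3)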