  def test_attributes_create_weights_ops(
      self, weight_range, weight_shape, prec):
    weights = jnp.array(
        fp32(
            onp.random.uniform(
                weight_range[0], weight_range[1], size=weight_shape)))
    axis = 0 if weight_shape[1] != 1 else None
    weights_quant = QuantOps.create_weights_ops(
        w=weights,
        weight_params=QuantOps.WeightParams(prec=prec, axis=axis))
    max_weight = onp.max(abs(weights), axis=0)
    onp.testing.assert_array_equal(
        jnp.squeeze(weights_quant._scale),
        (2**(prec - 1) - 1) / max_weight)
    self.assertEqual(weights_quant._symmetric, True)
    self.assertEqual(weights_quant._prec, prec)
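
  # Worked example of the scale check above: for symmetric integer
  # quantization with `prec` bits, the scale maps the per-column maximum
  # absolute weight onto the largest representable signed value,
  # 2**(prec - 1) - 1. With prec=8 and max|w| = 0.5 the scale is
  # 127 / 0.5 = 254, so scaled weights land in [-127, 127] before rounding.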

  def test_full_range_int_weight_quantization(self, prec):
    # Integer weights spanning the full symmetric range
    # [-(2**(prec - 1) - 1), 2**(prec - 1) - 1] quantize exactly.
    minval = -2**(prec - 1) + 1
    maxval = 2**(prec - 1) - 1
    weights = random.randint(random.PRNGKey(0), (10, 1), minval, maxval + 1)
    weights = weights.at[0, :].set(maxval)
    weight_quant = QuantOps.create_weights_ops(
        w=weights,
        weight_params=QuantOps.WeightParams(
            prec=prec, axis=None, half_shift=False))
    quantized_weights = weight_quant.to_quantized(weights, dtype=SCALE_DTYPE)
    onp.testing.assert_array_equal(
        quantized_weights[0], (2**(prec - 1.0) - 1.0))
    rescaled_weights = weight_quant.from_quantized(
        quantized_weights, dtype=jnp.float32)
    onp.testing.assert_array_equal(weights, rescaled_weights)
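
  # Why the round trip above is exact: row 0 is pinned to maxval, so the
  # global maximum absolute weight equals 2**(prec - 1) - 1 and the computed
  # scale is exactly 1. to_quantized then rounds values that are already
  # integers, and from_quantized divides by 1, recovering the weights exactly.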

  def test_attributes_create_weights_op_fp(
      self,
      weight_range,
      weight_shape,
      fp_quant,
  ):
    weights = jnp.array(
        fp32(onp.random.uniform(*weight_range, size=weight_shape)))
    axis = None if weight_shape[1] == 1 else 0
    weights_quant_op = QuantOps.create_weights_ops(
        w=weights,
        weight_params=QuantOps.WeightParams(
            prec=fp_quant, axis=axis, half_shift=False))
    max_weight = onp.max(abs(weights), axis=0)
    onp.testing.assert_array_equal(
        jnp.squeeze(weights_quant_op._scale),
        jnp.exp2(-jnp.floor(jnp.log2(max_weight))))
    self.assertEqual(weights_quant_op._symmetric, True)
    self.assertIs(weights_quant_op._prec, fp_quant)
    weights_scaled = (weights * weights_quant_op._scale).astype(
        weights.dtype)
    weights_quant_expected = fp_cast.downcast_sat_ftz(
        weights_scaled,
        fp_quant.fp_spec.exp_min,
        fp_quant.fp_spec.exp_max,
        fp_quant.fp_spec.sig_bits,
    )
    weights_quant_calculated = weights_quant_op.to_quantized(
        weights, dtype=SCALE_DTYPE)
    onp.testing.assert_array_equal(
        weights_quant_expected, weights_quant_calculated)
    # Check that the lower (23 - fp_quant.fp_spec.sig_bits) mantissa bits of
    # the calculated quantized weights are all zero.
    sig_mask = jnp.int32((1 << (23 - fp_quant.fp_spec.sig_bits)) - 1)
    onp.testing.assert_array_equal(
        weights_quant_calculated.view(jnp.int32) & sig_mask,
        jnp.zeros_like(weights))
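
  # Rationale for the mask used above: a float32 value carries 23 explicit
  # significand bits, so if downcast_sat_ftz keeps only
  # fp_quant.fp_spec.sig_bits of them, the low (23 - sig_bits) bits of every
  # quantized weight must be zero. With sig_bits=10, the mask is
  # (1 << 13) - 1 = 0x1fff.
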
    def __call__(
        self,
        inputs,
    ):
        """Embeds the inputs along the last dimension.

    Args:
      inputs: input data, all dimensions are considered batch dimensions.

    Returns:
      Output which is embedded input data.  The output shape follows the input,
      with an additional `features` dimension appended.
    """
        batch_size, sequence_length = inputs.shape
        if inputs.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]:
            raise ValueError(
                'Input type must be an integer or unsigned integer.')
        embedding = self.embedding

        embedding = jnp.asarray(embedding, self.dtype)

        hparams = self.hparams
        # Initialize the state for stats and bounds; this is required for the
        # logits computation in the `attend` method that follows.
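        # With update_stats=False and update_bounds=False, the call below is
        # expected only to create those state variables, not to modify them.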
        if hparams.quant_act is not None and isinstance(
                hparams.quant_act.bounds, get_bounds.GetBounds.Hyper):
            self.get_bounds_logits(
                inputs,
                bounds_params=get_bounds.GetBounds.Params(update_stats=False,
                                                          update_bounds=False,
                                                          paxis_name=None),
            )

        weight_prec = hparams.weight_prec
        weight_half_shift = hparams.weight_half_shift
        if weight_prec is not None:
            quantized_type = hparams.quant_type.to_jax_type()
            # In contrast to all other scale factor calculations in this
            # module, we compute per-row instead of per-column (i.e.,
            # per-output-channel) scale factors here. This is because the
            # embedding matrix might be shared with the output (logit) layer
            # of the transformer, in which case the *transpose* of the
            # embedding matrix is used as the weight matrix in a matmul. The
            # per-row scale factors used here then correspond to per-column
            # (because of the transpose) scale factors in the logits layer,
            # which is what we need for AQT.
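            # Concretely, `embedding` has shape (num_embeddings, features);
            # reducing over axis=(1,) gives one scale per vocabulary row,
            # with shape (num_embeddings, 1). When the transposed embedding,
            # of shape (features, num_embeddings), is reused for the logits
            # matmul, those same scales act per output channel there.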
            embedding_quant_ops = QuantOps.create_weights_ops(
                embedding,
                weight_params=QuantOps.WeightParams(
                    prec=weight_prec, axis=(1, ),
                    half_shift=weight_half_shift))
            embedding_quant_ops.assert_scale_shape_is(
                shape=(self.num_embeddings, 1))

            quantized_embedding = embedding_quant_ops.to_quantized(
                embedding, dtype=quantized_type)
            quantized_embedded_inputs = quantized_embedding[inputs]
            # Since the embedding matrix 'quantized_embedding' is gathered
            # based on 'inputs' to produce the embedding tensor, we apply the
            # same gathering to the per-row scale factors of the embedding
            # matrix so that the scale factors broadcast correctly in the
            # dequantizing division by 'scale' below.
            # TODO(malmaud): As part of the quantization.py refactor, change
            # 'get_scale_for_aqt' to cleanly support this and hence avoid the
            # need to directly access a protected member of QuantOps.
            scale = embedding_quant_ops._scale[inputs]  # pylint: disable=protected-access
            shape_utils.assert_shapes_equal(scale.shape,
                                            (batch_size, sequence_length, 1))
            shape_utils.assert_shapes_equal(
                quantized_embedded_inputs.shape,
                (batch_size, sequence_length, self.features))
            embedded_inputs = (quantized_embedded_inputs / scale).astype(
                self.dtype)
        else:
            embedded_inputs = embedding[inputs]
        shape_utils.assert_shapes_equal(
            embedded_inputs.shape,
            (batch_size, sequence_length, self.features))
        return embedded_inputs
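
    # A minimal usage sketch. The constructor fields below (num_embeddings,
    # features, hparams, dtype) are inferred from the attributes referenced
    # above and assume this is a Flax linen module; the real class name and
    # signature may differ.
    #
    #   embed = EmbedModule(num_embeddings=vocab_size, features=emb_dim,
    #                       hparams=hparams, dtype=jnp.float32)
    #   tokens = jnp.zeros((batch_size, sequence_length), dtype=jnp.int32)
    #   out, variables = embed.init_with_output(random.PRNGKey(0), tokens)
    #   # out.shape == (batch_size, sequence_length, emb_dim)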