def test_return_type(self):
    """Checks downcast_sat_ftz returns the same dtype as its input.

    Exercised once with a bfloat16 input and once with a float32 input.
    """
    x_bf16 = jnp.array(1.0, dtype=jnp.bfloat16)
    y_bf16 = fp_cast.downcast_sat_ftz(
        x_bf16, exp_min=-11, exp_max=4, sig_bits=3)
    self.assertEqual(x_bf16.dtype, y_bf16.dtype)

    # Must be float32 (not bfloat16) so this half actually covers the
    # float32 input path rather than repeating the bfloat16 check above.
    xf32 = jnp.array(1.0, dtype=jnp.float32)
    yf32 = fp_cast.downcast_sat_ftz(
        xf32, exp_min=-11, exp_max=4, sig_bits=3)
    self.assertEqual(xf32.dtype, yf32.dtype)
示例#2
0
    def to_quantized(self, x, *, dtype):
        """Quantizes the argument to the target format held in `self._prec`.

        Floating-point precision (`_FloatQuant`): when the precision is
        scaled, `x` is first multiplied by `self._scale` and cast back to
        `x`'s original dtype. It is then passed through
        `fp_cast.downcast_sat_ftz` with the precision's `fp_spec`
        (exp_min, exp_max, sig_bits). `dtype` is not used on this path; the
        result dtype is whatever `downcast_sat_ftz` returns.

        Integer precision: `x` is multiplied by `self._scale`. It is then
        rounded and clipped to a signed int when `self._symmetric` is true,
        otherwise floored and clipped to an unsigned int.

        Args:
          x: Argument to be quantized.
          dtype: Type of the returned quantized value (used by the integer
            path). If quantized x is an input to a matmul, we might want to
            set it to jnp.int8. The same applies if quantized x is weights
            stored in memory. In fake_quant style we might prefer to set
            dtype=SCALE_DTYPE, since quantized x might get constant folded
            with the rescale op (`from_quantized`). Please take a look at the
            comment on SCALE_DTYPE.

        Returns:
          Quantized value of x.
        """
        prec = self._prec

        if isinstance(prec, _FloatQuant):
            if prec.is_scaled:
                # Cast back to x's original dtype in case multiplying by
                # self._scale promoted the result to a different dtype.
                x = jnp.multiply(x, self._scale).astype(x.dtype)
            spec = prec.fp_spec
            return fp_cast.downcast_sat_ftz(
                x, spec.exp_min, spec.exp_max, spec.sig_bits)

        # Integer target: symmetric -> round to signed, else floor to unsigned.
        quantize = (primitives.round_and_clip_to_signed_int
                    if self._symmetric
                    else primitives.floor_and_clip_to_unsigned_int)
        return quantize(jnp.multiply(x, self._scale), prec=prec, dtype=dtype)
 def test_downcast_sat_ftz(self, dtype, argument_result_values):
     """Checks downcast_sat_ftz against a table of expected results.

     The table is built with `dtype`. Column 0 holds the arguments and
     column 1 the expected results. Both sides are compared after
     conversion to float32.
     """
     table = jnp.array(argument_result_values, dtype=dtype)
     inputs = table[:, 0]
     expected = table[:, 1]
     actual = fp_cast.downcast_sat_ftz(
         inputs, exp_min=-11, exp_max=4, sig_bits=3)
     expected_f32 = onp.array(expected, dtype=onp.float32)
     actual_f32 = onp.array(actual, dtype=onp.float32)
     onp.testing.assert_equal(expected_f32, actual_f32)
 def test_attributes_create_weights_op_fp(
     self,
     weight_range,
     weight_shape,
     fp_quant,
 ):
     """Checks attributes and output of a floating-point weights QuantOps.

     Verifies that:
     - the op's scale equals 2**-floor(log2(max |w|)) taken along axis 0;
     - the op is symmetric and keeps `fp_quant` as its precision;
     - `to_quantized` matches scaling followed by `fp_cast.downcast_sat_ftz`;
     - the low (23 - sig_bits) float32 mantissa bits of the result are zero.
     """
     raw = onp.random.uniform(*weight_range, size=weight_shape)
     weights = jnp.array(fp32(raw))
     # axis=None when the second dimension is 1, otherwise axis 0.
     axis = 0 if weight_shape[1] != 1 else None
     quant_op = QuantOps.create_weights_ops(
         w=weights,
         weight_params=QuantOps.WeightParams(prec=fp_quant, axis=axis))

     max_abs = onp.max(abs(weights), axis=0)
     expected_scale = jnp.exp2(-jnp.floor(jnp.log2(max_abs)))
     onp.testing.assert_array_equal(
         jnp.squeeze(quant_op._scale), expected_scale)
     self.assertEqual(quant_op._symmetric, True)
     self.assertIs(quant_op._prec, fp_quant)

     spec = fp_quant.fp_spec
     scaled = (weights * quant_op._scale).astype(weights.dtype)
     expected = fp_cast.downcast_sat_ftz(
         scaled, spec.exp_min, spec.exp_max, spec.sig_bits)
     actual = quant_op.to_quantized(weights, dtype=SCALE_DTYPE)
     onp.testing.assert_array_equal(expected, actual)

     # float32 has 23 mantissa bits; after quantization only the top
     # `sig_bits` of them may be set, so the remaining low bits must be zero.
     low_bits = jnp.int32((1 << (23 - spec.sig_bits)) - 1)
     onp.testing.assert_array_equal(
         actual.view(jnp.int32) & low_bits, jnp.zeros_like(weights))
 def test_invalid_argument_type(self):
     """Checks that downcast_sat_ftz raises ValueError for an int8 input."""
     int8_scalar = jnp.array(1, dtype=jnp.int8)
     self.assertRaises(
         ValueError,
         fp_cast.downcast_sat_ftz,
         int8_scalar,
         exp_min=-11,
         exp_max=4,
         sig_bits=3,
     )
示例#6
0
 def downcast_and_sum(x):
     """Downcasts `x` with fp_cast.downcast_sat_ftz and returns its sum.

     `sig_bits`, `exp_min` and `exp_max` are free variables read from the
     enclosing scope, which is not visible here.
     """
     downcast = fp_cast.downcast_sat_ftz(
         x, exp_min=exp_min, exp_max=exp_max, sig_bits=sig_bits)
     return jnp.sum(downcast)
示例#7
0
 def test_sig_bits_zero(self):
     """Checks that downcasting 2.11111 with sig_bits=0 yields 2.0."""
     result = fp_cast.downcast_sat_ftz(
         jnp.array(2.11111), sig_bits=0, exp_min=-11, exp_max=4)
     self.assertEqual(result.item(), 2.0)