def test_return_type(self):
  x_bf16 = jnp.array(1.0, dtype=jnp.bfloat16)
  y_bf16 = fp_cast.downcast_sat_ftz(
      x_bf16, exp_min=-11, exp_max=4, sig_bits=3)
  self.assertEqual(x_bf16.dtype, y_bf16.dtype)

  xf32 = jnp.array(1.0, dtype=jnp.float32)
  yf32 = fp_cast.downcast_sat_ftz(
      xf32, exp_min=-11, exp_max=4, sig_bits=3)
  self.assertEqual(xf32.dtype, yf32.dtype)
def to_quantized(self, x, *, dtype):
  """Quantizes the argument to the target format.

  integer: "upscales", rounds or floors and clips.
  floating-point: optionally upscales, then downcasts to target precision.

  Args:
    x: Argument to be quantized.
    dtype: Type of returned quantized value of x. If quantized x is an
      input to a matmul, we might want to set it to jnp.int8. If quantized
      x is weights stored in memory, the same applies. In fake_quant style
      we might prefer to set dtype=SCALE_DTYPE, since quantized x might get
      constant folded with the rescale op (`from_quantized`). Please take a
      look at the comment on SCALE_DTYPE.

  Returns:
    Quantized value of x.
  """
  if isinstance(self._prec, _FloatQuant):
    if self._prec.is_scaled:
      x = jnp.multiply(x, self._scale).astype(x.dtype)
    fp_spec = self._prec.fp_spec
    return fp_cast.downcast_sat_ftz(
        x,
        fp_spec.exp_min,
        fp_spec.exp_max,
        fp_spec.sig_bits,
    )
  else:
    if self._symmetric:
      quantize = primitives.round_and_clip_to_signed_int
    else:
      quantize = primitives.floor_and_clip_to_unsigned_int
    scaled_x = jnp.multiply(x, self._scale)
    return quantize(scaled_x, prec=self._prec, dtype=dtype)
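# A minimal standalone sketch of the floating-point branch above (hypothetical
# helper and values, not part of the class; `jnp` and `fp_cast` assumed
# imported as in this module). The scale is a power of two, matching what
# create_weights_ops derives in the tests below, so the multiply only shifts
# exponents and the downcast alone controls precision.
def _to_quantized_fp_sketch():
  x = jnp.array([0.3, 1.7, -2.4], dtype=jnp.float32)
  # Power-of-two scale that brings max(|x|) into [1, 2).
  scale = jnp.exp2(-jnp.floor(jnp.log2(jnp.max(jnp.abs(x)))))
  x_scaled = (x * scale).astype(x.dtype)
  # Same call the floating-point branch makes, with an assumed
  # fp_spec of (exp_min=-11, exp_max=4, sig_bits=3).
  return fp_cast.downcast_sat_ftz(
      x_scaled, exp_min=-11, exp_max=4, sig_bits=3)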
def test_downcast_sat_ftz(self, dtype, argument_result_values):
  argument_result = jnp.array(
      argument_result_values,
      dtype=dtype,
  )
  y = fp_cast.downcast_sat_ftz(
      argument_result[:, 0],
      exp_min=-11,
      exp_max=4,
      sig_bits=3,
  )
  onp.testing.assert_equal(
      onp.array(argument_result[:, 1], dtype=onp.float32),
      onp.array(y, dtype=onp.float32),
  )
def test_attributes_create_weights_op_fp(
    self,
    weight_range,
    weight_shape,
    fp_quant,
):
  weights = jnp.array(
      fp32(onp.random.uniform(*weight_range, size=weight_shape)))
  axis = None if weight_shape[1] == 1 else 0
  weights_quant_op = QuantOps.create_weights_ops(
      w=weights,
      weight_params=QuantOps.WeightParams(prec=fp_quant, axis=axis))
  max_weight = onp.max(abs(weights), axis=0)
  onp.testing.assert_array_equal(
      jnp.squeeze(weights_quant_op._scale),
      jnp.exp2(-jnp.floor(jnp.log2(max_weight))))
  self.assertEqual(weights_quant_op._symmetric, True)
  self.assertIs(weights_quant_op._prec, fp_quant)
  weights_scaled = (weights * weights_quant_op._scale).astype(weights.dtype)
  weights_quant_expected = fp_cast.downcast_sat_ftz(
      weights_scaled,
      fp_quant.fp_spec.exp_min,
      fp_quant.fp_spec.exp_max,
      fp_quant.fp_spec.sig_bits,
  )
  weights_quant_calculated = weights_quant_op.to_quantized(
      weights, dtype=SCALE_DTYPE)
  onp.testing.assert_array_equal(weights_quant_expected,
                                 weights_quant_calculated)
  # Test that the lower (23 - fp_quant.fp_spec.sig_bits) bits of the
  # calculated quantized weights are zero.
  sig_mask = jnp.int32((1 << (23 - fp_quant.fp_spec.sig_bits)) - 1)
  onp.testing.assert_array_equal(
      weights_quant_calculated.view(jnp.int32) & sig_mask,
      jnp.zeros_like(weights))
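# Hedged worked example of the mask check above, assuming the IEEE-754
# float32 layout (1 sign, 8 exponent, 23 mantissa bits): keeping sig_bits=3
# explicit significand bits must leave the low 23 - 3 = 20 mantissa bits
# zero. The inputs here (1.125 = 1.001b, -3.5 = -1.11b * 2^1) are already
# representable with 3 significand bits, so they pass through unchanged.
def _sig_mask_sketch():
  sig_bits = 3
  y = fp_cast.downcast_sat_ftz(
      jnp.array([1.125, -3.5], dtype=jnp.float32),
      exp_min=-11, exp_max=4, sig_bits=sig_bits)
  sig_mask = jnp.int32((1 << (23 - sig_bits)) - 1)
  # Reinterpret the float bits as int32 and check the dropped mantissa bits.
  assert not bool(jnp.any(y.view(jnp.int32) & sig_mask))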
def test_invalid_argument_type(self):
  x_s8 = jnp.array(1, dtype=jnp.int8)
  with self.assertRaises(ValueError):
    fp_cast.downcast_sat_ftz(x_s8, exp_min=-11, exp_max=4, sig_bits=3)
def downcast_and_sum(x):
  return jnp.sum(
      fp_cast.downcast_sat_ftz(
          x, sig_bits=sig_bits, exp_min=exp_min, exp_max=exp_max))
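# A sketch of the intended use of the helper above (an assumption from
# context, not confirmed here): reducing to a scalar so jax.grad can
# differentiate through the downcast, assuming downcast_sat_ftz is traceable
# by JAX autodiff. The format parameters, which the helper takes from its
# enclosing scope, are bound locally here instead.
def _downcast_and_sum_grad_sketch():
  import jax
  sig_bits, exp_min, exp_max = 3, -11, 4  # example format parameters

  def downcast_and_sum(x):
    return jnp.sum(
        fp_cast.downcast_sat_ftz(
            x, sig_bits=sig_bits, exp_min=exp_min, exp_max=exp_max))

  return jax.grad(downcast_and_sum)(
      jnp.array([0.5, 1.25], dtype=jnp.float32))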
def test_sig_bits_zero(self):
  x = jnp.array(2.11111)
  y = fp_cast.downcast_sat_ftz(x, exp_min=-11, exp_max=4, sig_bits=0)
  self.assertEqual(y.item(), 2.0)
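# Hedged note on the sig_bits=0 case above: with zero explicit significand
# bits, the representable magnitudes are exactly the powers of two in
# [2**exp_min, 2**exp_max], so the downcast rounds the significand and
# 2.11111 (= 1.055... * 2^1) lands on 2.0. A small sketch under that
# assumption:
def _sig_bits_zero_sketch():
  x = jnp.array([2.1, 3.9, 0.6], dtype=jnp.float32)
  # Expected under round-to-nearest on the significand: 2.1 -> 2.0,
  # 3.9 -> 4.0 (significand 1.95 rounds up into the next exponent),
  # 0.6 -> 0.5.
  return fp_cast.downcast_sat_ftz(x, exp_min=-11, exp_max=4, sig_bits=0)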