コード例 #1
0
    def test_return_type(self):
        """downcast_sat_ftz returns an array with the same dtype as its input."""
        # Previously the "float32" half of this test built a bfloat16 array,
        # so float32 dtype preservation was never exercised. Check both dtypes.
        for dtype in (jnp.bfloat16, jnp.float32):
            x = jnp.array(1.0, dtype=dtype)
            y = fp_cast.downcast_sat_ftz(x,
                                         exp_min=-11,
                                         exp_max=4,
                                         sig_bits=3)
            self.assertEqual(x.dtype, y.dtype)
コード例 #2
0
    def to_quantized(self, x, *, dtype):
        """Quantizes the argument to the target format.

        integer: "upscales", rounds or floors and clips.
        floating-point: optionally upscales, then downcasts to target precision.

        Args:
          x: Argument to be quantized.
          dtype: Type of returned quantized value of x (integer quantization
            only). If quantized x is an input to a matmul, we might want to set
            it to jnp.int8. If quantized x is weights stored in memory, the same
            applies. In fake_quant style we might prefer to set
            dtype=SCALE_DTYPE, since quantized x might get constant folded with
            the rescale op (`from_quantized`). Please take a look at the comment
            on SCALE_DTYPE. NOTE: this argument is not used on the
            floating-point path; there the result is whatever dtype
            `fp_cast.downcast_sat_ftz` returns for x.

        Returns:
          Quantized value of x.
        """
        if isinstance(self._prec, _FloatQuant):
            if self._prec.is_scaled:
                # Cast back so a scale of a different dtype does not promote x.
                x = jnp.multiply(x, self._scale).astype(x.dtype)
            fp_spec = self._prec.fp_spec
            # Keyword arguments match the other downcast_sat_ftz call sites and
            # make the call robust to parameter-order mistakes.
            return fp_cast.downcast_sat_ftz(
                x,
                exp_min=fp_spec.exp_min,
                exp_max=fp_spec.exp_max,
                sig_bits=fp_spec.sig_bits,
            )
        else:
            # Symmetric ranges round-and-clip to a signed int; asymmetric ones
            # floor-and-clip to an unsigned int.
            if self._symmetric:
                quantize = primitives.round_and_clip_to_signed_int
            else:
                quantize = primitives.floor_and_clip_to_unsigned_int
            scaled_x = jnp.multiply(x, self._scale)
            return quantize(scaled_x, prec=self._prec, dtype=dtype)
コード例 #3
0
  def test_attributes_create_acts_op_fp(
      self,
      act_distribution,
      use_hparams_bounds,
  ):
    """Float activation quantization equals scaling followed by downcasting."""
    inputs = jnp.array(fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 4))))
    fp_quant = QuantOps.FloatQuant(
        is_scaled=True,
        fp_spec=QuantOps.FloatQuant.FloatPrec(
            exp_min=-15,
            exp_max=15,
            sig_bits=2,
        ),
    )
    # Exercise both bound flavors: statistics-driven hyperparameters or a
    # fixed scalar bound.
    if use_hparams_bounds:
      bounds = get_bounds.GetBounds.Hyper(
          initial_bound=6.0,
          stddev_coeff=1,
          absdev_coeff=0,
          mix_coeff=1,
          reset_stats=True,
          ema_coeff=None,
          use_cams=False,
          granularity=quant_config.QuantGranularity.per_tensor)
    else:
      bounds = 6.0

    hparams = QuantOps.ActHParams(
        input_distribution=act_distribution, bounds=bounds, prec=fp_quant,
        half_shift=False)

    class TestModule(nn.Module):
      hparams: QuantOps.ActHParams

      @nn.compact
      def __call__(self, inputs):
        # Use the module's own field rather than the enclosing test's local
        # `hparams`; otherwise the declared field is dead and the module would
        # silently ignore whatever hparams it was constructed with.
        return QuantOps.create_input_ops(
            inputs,
            hparams=self.hparams,
            get_bounds_params=GetBounds.Params(
                update_stats=False,
                update_bounds=False))

    test_module = TestModule(hparams=hparams)
    state = test_module.init(jax.random.PRNGKey(0), inputs=inputs)
    act_quant_op = test_module.apply(state, inputs=inputs)

    # Reference: scale in the input dtype, then downcast to the fp spec.
    act_scaled = (inputs * act_quant_op._scale).astype(inputs.dtype)
    act_quant_expected = fp_cast.downcast_sat_ftz(
        act_scaled,
        fp_quant.fp_spec.exp_min,
        fp_quant.fp_spec.exp_max,
        fp_quant.fp_spec.sig_bits,
    )
    act_quant_calculated = act_quant_op.to_quantized(inputs, dtype=SCALE_DTYPE)
    onp.testing.assert_array_equal(act_quant_expected, act_quant_calculated)
コード例 #4
0
 def test_downcast_sat_ftz(self, dtype, argument_result_values):
     """Each row of `argument_result_values` is an (input, expected) pair."""
     pairs = jnp.array(argument_result_values, dtype=dtype)
     arguments = pairs[:, 0]
     expected = pairs[:, 1]
     actual = fp_cast.downcast_sat_ftz(
         arguments, exp_min=-11, exp_max=4, sig_bits=3)
     # Compare as float32 so the check does not depend on the tested dtype.
     expected_f32 = onp.array(expected, dtype=onp.float32)
     actual_f32 = onp.array(actual, dtype=onp.float32)
     onp.testing.assert_equal(expected_f32, actual_f32)
コード例 #5
0
 def test_attributes_create_weights_op_fp(
     self,
     weight_range,
     weight_shape,
     fp_quant,
 ):
     """Float weight quantization: power-of-two scale, then fp downcast."""
     spec = fp_quant.fp_spec
     weights = jnp.array(
         fp32(onp.random.uniform(*weight_range, size=weight_shape)))
     # axis=None for a single-column weight, otherwise axis 0.
     axis = 0 if weight_shape[1] != 1 else None
     params = QuantOps.WeightParams(prec=fp_quant, axis=axis, half_shift=False)
     op = QuantOps.create_weights_ops(w=weights, weight_params=params)

     # Expected scale: the power of two that maps each column max |w| into
     # [1, 2).
     column_max = onp.max(abs(weights), axis=0)
     expected_scale = jnp.exp2(-jnp.floor(jnp.log2(column_max)))
     onp.testing.assert_array_equal(jnp.squeeze(op._scale), expected_scale)
     self.assertEqual(op._symmetric, True)
     self.assertIs(op._prec, fp_quant)

     scaled = (weights * op._scale).astype(weights.dtype)
     expected = fp_cast.downcast_sat_ftz(
         scaled,
         spec.exp_min,
         spec.exp_max,
         spec.sig_bits,
     )
     calculated = op.to_quantized(weights, dtype=SCALE_DTYPE)
     onp.testing.assert_array_equal(expected, calculated)

     # float32 carries 23 explicit mantissa bits; after keeping only
     # `sig_bits` of them, the remaining low bits must all be zero.
     f32_mantissa_bits = 23
     dropped_bits_mask = jnp.int32(
         (1 << (f32_mantissa_bits - spec.sig_bits)) - 1)
     onp.testing.assert_array_equal(
         calculated.view(jnp.int32) & dropped_bits_mask,
         jnp.zeros_like(weights))
コード例 #6
0
 def downcast_and_sum(x):
     """Downcasts `x` using the enclosing sig_bits/exp_min/exp_max, then sums."""
     downcast = fp_cast.downcast_sat_ftz(
         x, exp_min=exp_min, exp_max=exp_max, sig_bits=sig_bits)
     return jnp.sum(downcast)
コード例 #7
0
 def test_sig_bits_zero(self):
     """With sig_bits=0, 2.11111 is expected to downcast to exactly 2.0."""
     result = fp_cast.downcast_sat_ftz(
         jnp.array(2.11111), exp_min=-11, exp_max=4, sig_bits=0)
     self.assertEqual(2.0, result.item())
コード例 #8
0
 def test_invalid_argument_type(self):
     """An int8 input is rejected with ValueError."""
     int8_input = jnp.array(1, dtype=jnp.int8)
     self.assertRaises(
         ValueError,
         fp_cast.downcast_sat_ftz,
         int8_input,
         exp_min=-11,
         exp_max=4,
         sig_bits=3,
     )