def quantized_layernorm(x):
            """Layer-normalize `x` over its last axis, quantizing every intermediate.

            Every intermediate is rounded via `quant_ops.to_quantized`, using an
            unscaled symmetric floating-point format built from
            `hparams.quant_hparams.prec`. The two reductions (sum of `x` and sum
            of squared deviations) use `quantization.quantized_sum` at
            `hparams.quant_hparams.reduction_prec`. Scale and bias are applied
            only when `self.use_scale` / `self.use_bias` are set.

            NOTE(review): `hparams`, `self`, `dtype`, `num_features`,
            `scale_param` and `bias_param` are read from the enclosing scope,
            which is not visible here.

            Args:
              x: Input array; statistics are reduced over axis -1.

            Returns:
              The normalized (optionally scaled/shifted) array cast to
              `self.dtype`.
            """
            prec = hparams.quant_hparams.prec
            fp_quant = QuantOps.FloatQuant(is_scaled=False, fp_spec=prec)
            quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant,
                                                     bounds=None)

            def to_quantized(x):
                """Quantize `x` with `quant_ops`, requesting `dtype` as output dtype."""
                return quant_ops.to_quantized(x, dtype=dtype)

            # If epsilon is too small to represent in the quantized format (below
            # the flush-to-zero bound), raise it to that bound to avoid the
            # possibility of dividing by zero in the rsqrt below.
            fp_bounds = quantization.fp_cast.get_bounds(
                prec.exp_min, prec.exp_max, prec.sig_bits)
            epsilon = max(self.epsilon, fp_bounds.flush_to_zero_bound)
            quantized_epsilon = to_quantized(jnp.array(epsilon, dtype=dtype))

            # If the reciprocal of the quantized number of features quantizes to
            # zero, the mean and variance would be trivially 0. In that case we
            # substitute `quantized_epsilon` (which is at least the flush-to-zero
            # bound) instead. `lax.cond` keeps the choice traceable, presumably
            # because this runs under jit.
            num_features_quantized = to_quantized(
                jnp.array(num_features, dtype=dtype))
            num_features_recip_quantized = to_quantized(
                jnp.reciprocal(num_features_quantized))
            num_features_recip_quantized = jax.lax.cond(
                jax.lax.eq(num_features_recip_quantized,
                           0.0), lambda _: quantized_epsilon,
                lambda _: num_features_recip_quantized, None)

            # mean = quantized(sum(x) * (1 / num_features)).
            x_quantized = to_quantized(x)
            x_sum_quantized_reduction = quantization.quantized_sum(
                x_quantized,
                axis=-1,
                keepdims=True,
                prec=hparams.quant_hparams.reduction_prec)
            x_sum = to_quantized(x_sum_quantized_reduction)
            mean = to_quantized(x_sum * num_features_recip_quantized)
            # Center the *quantized* input (the same value that was summed
            # above). This keeps the full-precision `x` out of the simulated
            # low-precision datapath.
            x_minus_mean = to_quantized(x_quantized - mean)

            # var = quantized(sum((x - mean)^2) * (1 / num_features)).
            x_sq = to_quantized(lax.square(x_minus_mean))
            x_sq_sum_quantized_reduction = quantization.quantized_sum(
                x_sq,
                axis=-1,
                keepdims=True,
                prec=hparams.quant_hparams.reduction_prec)
            x_sq_sum = to_quantized(x_sq_sum_quantized_reduction)
            var = to_quantized(x_sq_sum * num_features_recip_quantized)

            # Prevent division by zero.
            var_plus_epsilon = to_quantized(var + quantized_epsilon)
            mul = to_quantized(lax.rsqrt(var_plus_epsilon))
            if self.use_scale:
                # Fold the scale into the multiplier before applying it to x.
                quantized_scale_param = to_quantized(scale_param)
                mul = to_quantized(mul * quantized_scale_param)
            y = to_quantized(x_minus_mean * mul)
            if self.use_bias:
                quantized_bias_param = to_quantized(bias_param)
                y = to_quantized(y + quantized_bias_param)
            return y.astype(self.dtype)
Example #2
0
    def quantized_softmax(a):
        """Softmax over `norm_dims` with every intermediate value quantized.

        Evaluates exp(a - max(a)) / sum(exp(a - max(a))) directly in the linear
        domain (not a log-domain formulation). Each intermediate is rounded into
        the unscaled symmetric floating-point format described by
        `softmax_hparams.quant_hparams.prec`. The exponent sum is reduced with
        `quantization.quantized_sum` at `reduction_prec`.

        NOTE(review): `softmax_hparams`, `dtype` and `norm_dims` come from the
        enclosing scope, which is not visible here.

        Args:
          a: Input logits.

        Returns:
          The quantized softmax of `a`, cast to `dtype`.
        """
        q_hparams = softmax_hparams.quant_hparams
        fp_config = QuantOps.FloatQuant(is_scaled=False,
                                        fp_spec=q_hparams.prec)
        ops = QuantOps.create_symmetric_fp(fp_quant=fp_config, bounds=None)

        def quantize(value):
            return ops.to_quantized(value, dtype=dtype)

        logits = quantize(a)
        # max() returns an element that already exists in `logits`, so its
        # result is already in the quantized format and needs no extra
        # rounding step.
        peak = jnp.max(logits, axis=norm_dims, keepdims=True)
        shifted = quantize(logits - peak)
        exps = quantize(jnp.exp(shifted))

        denom = quantize(
            quantization.quantized_sum(exps,
                                       axis=norm_dims,
                                       keepdims=True,
                                       prec=q_hparams.reduction_prec))
        recip = quantize(jnp.reciprocal(denom))
        probs = quantize(exps * recip)
        return probs.astype(dtype)
 def test_keepdims_and_axis(self, keepdims, axis, expected_shape):
     """quantized_sum's output shape follows the given `axis` and `keepdims`."""
     # Positional args presumably (exp_min, exp_max, sig_bits); a wide format,
     # since only the output shape is checked here.
     wide_prec = QuantOps.FloatQuant.FloatPrec(-2**7, 2**7, 23)
     inputs = jnp.arange(6).reshape((3, 2)).astype(jnp.float32)
     reduced = quantization.quantized_sum(
         inputs, keepdims=keepdims, axis=axis, prec=wide_prec)
     self.assertEqual(reduced.shape, expected_shape)