    def test_scale_invariance_weight_quantization(self, prec):
        # Scaling the weights by a power of 2 should scale the quantized output
        # by the same factor.
        weights = random.uniform(random.PRNGKey(0), (10, 1))
        weight_scale = 16
        scaled_weights = weights * weight_scale

        weights = QuantOps.create_weights_fake_quant(
            w=weights,
            weight_params=QuantOps.WeightParams(prec=prec, axis=None))

        scaled_weights = QuantOps.create_weights_fake_quant(
            w=scaled_weights,
            weight_params=QuantOps.WeightParams(prec=prec, axis=None))

        onp.testing.assert_array_equal(weights * weight_scale, scaled_weights)
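
# A minimal standalone sketch (toy quantizer, not the library's QuantOps
# implementation) of why the test above holds exactly: with symmetric fake
# quantization the scale is derived from max(|w|), so multiplying the weights by
# a power of two scales the quantization step by the same factor and leaves every
# rounding decision unchanged.
import numpy as onp


def _toy_fake_quant(w, prec):
    # Map max(|w|) onto the largest signed prec-bit value, round, and rescale back.
    bound = 2.0 ** (prec - 1) - 1
    scale = bound / onp.max(onp.abs(w))
    return onp.round(w * scale) / scale


w = onp.random.uniform(-1.0, 1.0, size=(10, 1))
onp.testing.assert_array_equal(
    _toy_fake_quant(w, prec=4) * 16, _toy_fake_quant(w * 16, prec=4))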
    def test_float_weights_quantization(self, prec):
        # Tests that quantized and rescaled float weights are close to original
        # weights.
        weights = jnp.array(
            fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 1))))
        rescaled_weights = QuantOps.create_weights_fake_quant(
            w=weights,
            weight_params=QuantOps.WeightParams(prec=prec, axis=None))
        test_utils.assert_all_close_prec(weights, rescaled_weights, prec=prec)
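
# Back-of-envelope sketch (again a toy quantizer, not QuantOps) of the tolerance
# used above: rounding to the nearest multiple of the quantization step bounds the
# per-element round-trip error by half a step, so the error shrinks as prec grows.
import numpy as onp

prec = 8
w = onp.random.uniform(0, 2.0, size=(10, 1))
step = onp.max(onp.abs(w)) / (2.0 ** (prec - 1) - 1)
w_q = onp.round(w / step) * step
assert onp.max(onp.abs(w - w_q)) <= step / 2 + 1e-9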
    def test_weight_scale_shape_is_expected(self, axis):
        # Tests that the scale shape is as expected for weight quantization.

        num_features = 4
        expected_scale_shape = (1, 1) if axis is None else (1, num_features)

        # Weight Quantization
        weights = jnp.array(
            fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, num_features))))
        _ = QuantOps.create_weights_fake_quant(
            w=weights,
            weight_params=QuantOps.WeightParams(
                prec=8.0, axis=axis,
                expected_scale_shape=expected_scale_shape))
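
# Hedged sketch (not QuantOps internals) of where the expected scale shapes above
# come from: with axis=None the maximum is taken over the whole weight tensor,
# giving one per-tensor scale of shape (1, 1); with axis=0 it is taken per output
# feature, giving a per-channel scale of shape (1, num_features).
import numpy as onp

w = onp.random.uniform(0, 1.0, size=(10, 4))
per_tensor_scale = (2.0 ** 7 - 1) / onp.max(onp.abs(w), keepdims=True)
per_channel_scale = (2.0 ** 7 - 1) / onp.max(onp.abs(w), axis=0, keepdims=True)
assert per_tensor_scale.shape == (1, 1)
assert per_channel_scale.shape == (1, 4)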
    def __call__(self, inputs):
        """Applies a convolution to the inputs with optional quantization.

    Args:
      inputs: input data with dimensions (batch, spatial_dims..., features).

    Returns:
      The convolved data.
    """
        hparams = self.hparams
        if hparams.weight_prec is not None and hparams.weight_prec > 8:
            raise NotImplementedError(
                'If you want to use more than 8bits for quantization, please revisit '
                'jax.lax.Precision.DEFAULT to determine whether it is still sufficient.'
            )
        jax_precision = jax.lax.Precision.DEFAULT

        if self.strides is None:
            strides = (1, ) * (inputs.ndim - 2)
        else:
            strides = self.strides

        in_features = inputs.shape[-1]
        assert in_features % self.feature_group_count == 0
        kernel_shape = self.kernel_size + (
            in_features // self.feature_group_count, self.features)
        kernel = self.param('kernel', self.kernel_init, kernel_shape)

        inputs = jnp.asarray(inputs, self.dtype)
        kernel = jnp.asarray(kernel, self.dtype)

        # Activation quantization
        if hparams.quant_act is not None:
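            # Fake-quantize the activations; whether the stored bounds and their
            # running statistics are updated is driven by the quantization context
            # and the train flag.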
            inputs = QuantOps.create_inputs_fake_quant(
                inputs=inputs,
                hparams=hparams.quant_act,
                get_bounds_params=get_bounds.GetBounds.Params(
                    update_bounds=self.quant_context.update_bounds,
                    update_stats=self.train,
                    paxis_name=self.paxis_name))

        # Weight quantization
        if hparams.weight_prec is not None:
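            # One scale per output channel: reduce over every kernel dimension
            # except the last (output features) axis.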
            kernel_reduction_axis = tuple(range(kernel.ndim - 1))
            expected_scale_shape = (1, ) * (kernel.ndim - 1) + (
                self.features, )
            assert hparams.quant_type == QuantType.fake_quant, (
                'we only support fake_quant style of aqt for ConvAqt.')
            quantized_type = hparams.quant_type.to_jax_type()
            kernel = QuantOps.create_weights_fake_quant(
                kernel,
                weight_params=QuantOps.WeightParams(
                    prec=hparams.weight_prec,
                    half_shift=hparams.weight_half_shift,
                    axis=kernel_reduction_axis,
                    expected_scale_shape=expected_scale_shape),
                quantized_type=quantized_type)

        # Convolution
        dimension_numbers = flax.nn.linear._conv_dimension_numbers(
            inputs.shape)  # pylint: disable=protected-access
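        # contextlib.suppress() with no arguments is a no-op context manager; it is
        # used as the default when op-metadata annotation is disabled.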
        metadata_context = contextlib.suppress()
        # Use metadata context to annotate op metadata with quantization info
        act_prec = None if hparams.quant_act is None else hparams.quant_act.prec

        if flags.FLAGS.metadata_enabled:
            metadata_context = compute_cost_utils.ConvMetadataMonkeyPatch(
                weight_prec=hparams.weight_prec, act_prec=act_prec)
        with metadata_context:
            y = lax.conv_general_dilated(
                inputs,
                kernel,
                strides,
                self.padding,
                lhs_dilation=self.input_dilation,
                rhs_dilation=self.kernel_dilation,
                dimension_numbers=dimension_numbers,
                feature_group_count=self.feature_group_count,
                precision=jax_precision)
        # TODO(shivaniagrawal): create quantized conv general dilated.

        # bias
        if self.use_bias:
            bias = self.param('bias', self.bias_init, (self.features, ))
            bias = jnp.asarray(bias, self.dtype)
            # The inputs can have an arbitrary number of spatial dims, so we broadcast
            # the bias to match: (batch_size, spatial_dims..., features); see the
            # standalone sketch after this method.
            # TODO(shivaniagrawal): Consider making ConvAqt rank static (e.g. 2D)
            # or maybe add error checking (e.g. expect inputs to have rank N, but this
            # may already be checked by lax.conv_general_dilated).
            bias = utils.broadcast_rank(bias, inputs)
            y = y + bias
        return y
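
# Hedged standalone sketch of the bias broadcast above. _broadcast_rank_sketch is a
# stand-in for utils.broadcast_rank (whose implementation is not shown here): it
# prepends singleton dimensions so the (features,) bias broadcasts across the batch
# and spatial dimensions of the convolution output.
import jax.numpy as jnp


def _broadcast_rank_sketch(bias, inputs):
    # (features,) -> (1, ..., 1, features), matching the rank of `inputs`.
    return bias.reshape((1,) * (inputs.ndim - 1) + bias.shape)


y = jnp.zeros((2, 8, 8, 16))  # (batch, spatial..., features)
bias = jnp.ones((16,))
assert (y + _broadcast_rank_sketch(bias, y)).shape == (2, 8, 8, 16)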