def test_check_value_inside_and_outside_of_context_conv_general(
            self, weight_prec):
        original_op_name = 'conv_general_dilated'
        # The primitive's 'name' should change inside the metadata context used
        # by 'flax_layers' (when that context is enabled), and revert outside it.
        self.assertEqual(original_op_name, lax.conv_general_dilated_p.name)

        with compute_cost_utils.ConvMetadataMonkeyPatch(
                weight_prec=weight_prec, act_prec=None):
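            # While the patch is active, the conv primitive's name is rewritten
            # (annotated with the quantization info).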
            self.assertNotEqual(original_op_name,
                                lax.conv_general_dilated_p.name)
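        # Exiting the context must restore the original primitive name.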
        self.assertEqual(original_op_name, lax.conv_general_dilated_p.name)
Example #2
    def __call__(self, inputs):
        """Applies a convolution to the inputs with optional quantization.

    Args:
      inputs: input data with dimensions (batch, spatial_dims..., features).

    Returns:
      The convolved data.
    """
        hparams = self.hparams
        if hparams.weight_prec is not None and hparams.weight_prec > 8:
            raise NotImplementedError(
                'If you want to use more than 8 bits for quantization, please revisit '
                'jax.lax.Precision.DEFAULT to determine whether it is still sufficient.'
            )
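        # For quantization to at most 8 bits, DEFAULT precision is assumed to
        # be sufficient; the check above guards the higher-precision case.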
        jax_precision = jax.lax.Precision.DEFAULT

        if self.strides is None:
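            # Default to a unit stride for each spatial dim (ndim minus the
            # batch and feature axes).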
            strides = (1, ) * (inputs.ndim - 2)
        else:
            strides = self.strides

        in_features = inputs.shape[-1]
        assert in_features % self.feature_group_count == 0
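        # Kernel layout: (spatial_dims..., in_features // groups, features),
        # matching the dimension numbers used for the convolution below.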
        kernel_shape = self.kernel_size + (
            in_features // self.feature_group_count, self.features)
        kernel = self.param('kernel', self.kernel_init, kernel_shape)

        inputs = jnp.asarray(inputs, self.dtype)
        kernel = jnp.asarray(kernel, self.dtype)

        # Activation quantization
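        # Fake-quantize the inputs: GetBounds tracks activation statistics
        # (updated only while training), while the quant context decides when
        # the clipping bounds themselves are updated.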
        if hparams.quant_act is not None:
            inputs = QuantOps.create_inputs_fake_quant(
                inputs=inputs,
                hparams=hparams.quant_act,
                get_bounds_params=get_bounds.GetBounds.Params(
                    update_bounds=self.quant_context.update_bounds,
                    update_stats=self.train,
                    paxis_name=self.paxis_name))

        # Weight quantization
        if hparams.weight_prec is not None:
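            # Per-output-channel quantization: reduce over every axis except
            # the trailing feature axis, so each output channel gets its own
            # scale with shape (1, ..., 1, features).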
            kernel_reduction_axis = tuple(range(kernel.ndim - 1))
            expected_scale_shape = (1, ) * (kernel.ndim - 1) + (
                self.features, )
            assert hparams.quant_type == QuantType.fake_quant, (
                'Only fake_quant-style AQT is supported in ConvAqt.')
            quantized_type = hparams.quant_type.to_jax_type()
            kernel = QuantOps.create_weights_fake_quant(
                kernel,
                weight_params=QuantOps.WeightParams(
                    prec=hparams.weight_prec,
                    half_shift=hparams.weight_half_shift,
                    axis=kernel_reduction_axis,
                    expected_scale_shape=expected_scale_shape),
                quantized_type=quantized_type)

        # Convolution
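        # Flax's _conv_dimension_numbers builds NHWC/HWIO-style dimension
        # numbers (channel-last inputs, (spatial..., in, out) kernel) for an
        # arbitrary number of spatial dims.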
        dimension_numbers = flax.nn.linear._conv_dimension_numbers(
            inputs.shape)  # pylint: disable=protected-access
        metadata_context = contextlib.suppress()
        # Use metadata context to annotate op metadata with quantization info
        act_prec = None if hparams.quant_act is None else hparams.quant_act.prec

        if flags.FLAGS.metadata_enabled:
            metadata_context = compute_cost_utils.ConvMetadataMonkeyPatch(
                weight_prec=hparams.weight_prec, act_prec=act_prec)
        with metadata_context:
            y = lax.conv_general_dilated(
                inputs,
                kernel,
                strides,
                self.padding,
                lhs_dilation=self.input_dilation,
                rhs_dilation=self.kernel_dilation,
                dimension_numbers=dimension_numbers,
                feature_group_count=self.feature_group_count,
                precision=jax_precision)
        # TODO(shivaniagrawal): create quantized conv general dilated.

        # bias
        if self.use_bias:
            bias = self.param('bias', self.bias_init, (self.features, ))
            bias = jnp.asarray(bias, self.dtype)
            # The inputs can have an arbitrary number of spatial dims, so we
            # broadcast the bias to match: (batch_size, spatial_dims..., features).
            # TODO(shivaniagrawal): Consider making ConvAqt rank static (e.g. 2D)
            # or add error checking (e.g. expect inputs to have rank N; this may
            # already be checked by lax.conv_general_dilated).
            bias = utils.broadcast_rank(bias, inputs)
            y = y + bias
        return y
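

# A minimal standalone sketch (not the library's QuantOps implementation) of
# the per-output-channel fake quantization applied to the kernel above.
# Everything below is illustrative: `fake_quant_kernel` is a hypothetical
# helper, not an AQT API.
import jax.numpy as jnp


def fake_quant_kernel(kernel, prec):
    """Symmetric fake quantization with one scale per output channel."""
    # Reduce over all axes except the trailing feature axis, mirroring
    # `kernel_reduction_axis` and `expected_scale_shape` above.
    reduction_axes = tuple(range(kernel.ndim - 1))
    abs_max = jnp.max(jnp.abs(kernel), axis=reduction_axes, keepdims=True)
    # Largest representable magnitude for a signed `prec`-bit grid,
    # e.g. 127 for prec=8.
    bound = 2.0 ** (prec - 1) - 1
    scale = bound / jnp.maximum(abs_max, 1e-6)
    # Round onto the integer grid, then immediately rescale back: the values
    # stay floating point but only take on quantized levels ("fake" quant).
    return jnp.round(kernel * scale) / scale


# Example: fake-quantize a (3, 3, 8, 16) HWIO kernel to 8 bits.
# quantized_kernel = fake_quant_kernel(kernel, prec=8)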