def test_check_value_inside_and_outside_of_context_conv_general(
    self, weight_prec):
  original_op_name = 'conv_general_dilated'
  # The primitive's 'name' should change inside the ConvMetadataMonkeyPatch
  # context used in 'flax_layers', and revert once the context exits.
  self.assertEqual(original_op_name, lax.conv_general_dilated_p.name)
  with compute_cost_utils.ConvMetadataMonkeyPatch(
      weight_prec=weight_prec, act_prec=None):
    self.assertNotEqual(original_op_name, lax.conv_general_dilated_p.name)
  self.assertEqual(original_op_name, lax.conv_general_dilated_p.name)
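
# For context, ConvMetadataMonkeyPatch (exercised above) only needs to swap
# the primitive's `name` on entry and restore it on exit. The sketch below is
# a hypothetical illustration of that enter/exit pattern, NOT the actual
# compute_cost_utils implementation; the annotated-name format is an
# assumption made for this example.
import contextlib

from jax import lax


@contextlib.contextmanager
def _annotated_conv_name(weight_prec, act_prec):
  """Hypothetical sketch: temporarily rename the conv primitive."""
  original_name = lax.conv_general_dilated_p.name
  # Encode the quantization precisions into the op name so they show up in
  # op metadata (the exact format here is an assumption).
  lax.conv_general_dilated_p.name = f'{original_name}_w{weight_prec}_a{act_prec}'
  try:
    yield
  finally:
    # Restore the original primitive name even if the body raises.
    lax.conv_general_dilated_p.name = original_name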
def __call__(self, inputs):
  """Applies a convolution to the inputs with optional quantization.

  Args:
    inputs: Input data with dimensions (batch, spatial_dims..., features).

  Returns:
    The convolved data.
  """
  hparams = self.hparams
  if hparams.weight_prec is not None and hparams.weight_prec > 8:
    raise NotImplementedError(
        'If you want to use more than 8 bits for quantization, please revisit '
        'jax.lax.Precision.DEFAULT to determine whether it is still '
        'sufficient.')
  jax_precision = jax.lax.Precision.DEFAULT

  if self.strides is None:
    strides = (1,) * (inputs.ndim - 2)
  else:
    strides = self.strides

  in_features = inputs.shape[-1]
  assert in_features % self.feature_group_count == 0
  kernel_shape = self.kernel_size + (
      in_features // self.feature_group_count, self.features)
  kernel = self.param('kernel', self.kernel_init, kernel_shape)

  inputs = jnp.asarray(inputs, self.dtype)
  kernel = jnp.asarray(kernel, self.dtype)

  # Activation quantization
  if hparams.quant_act is not None:
    inputs = QuantOps.create_inputs_fake_quant(
        inputs=inputs,
        hparams=hparams.quant_act,
        get_bounds_params=get_bounds.GetBounds.Params(
            update_bounds=self.quant_context.update_bounds,
            update_stats=self.train,
            paxis_name=self.paxis_name))

  # Weight quantization
  if hparams.weight_prec is not None:
    kernel_reduction_axis = tuple(range(kernel.ndim - 1))
    expected_scale_shape = (1,) * (kernel.ndim - 1) + (self.features,)
    assert hparams.quant_type == QuantType.fake_quant, (
        'Only the fake_quant style of AQT is supported for ConvAqt.')
    quantized_type = hparams.quant_type.to_jax_type()
    kernel = QuantOps.create_weights_fake_quant(
        kernel,
        weight_params=QuantOps.WeightParams(
            prec=hparams.weight_prec,
            half_shift=hparams.weight_half_shift,
            axis=kernel_reduction_axis,
            expected_scale_shape=expected_scale_shape),
        quantized_type=quantized_type)

  # Convolution
  dimension_numbers = flax.nn.linear._conv_dimension_numbers(
      inputs.shape)  # pylint: disable=protected-access
  metadata_context = contextlib.suppress()
  # Use a metadata context to annotate op metadata with quantization info.
  act_prec = None if hparams.quant_act is None else hparams.quant_act.prec
  if flags.FLAGS.metadata_enabled:
    metadata_context = compute_cost_utils.ConvMetadataMonkeyPatch(
        weight_prec=hparams.weight_prec, act_prec=act_prec)
  with metadata_context:
    y = lax.conv_general_dilated(
        inputs,
        kernel,
        strides,
        self.padding,
        lhs_dilation=self.input_dilation,
        rhs_dilation=self.kernel_dilation,
        dimension_numbers=dimension_numbers,
        feature_group_count=self.feature_group_count,
        precision=jax_precision)
  # TODO(shivaniagrawal): create quantized conv general dilated.

  # Bias
  if self.use_bias:
    bias = self.param('bias', self.bias_init, (self.features,))
    bias = jnp.asarray(bias, self.dtype)
    # The inputs can have an arbitrary number of spatial dims, so we broadcast
    # the bias to match: (batch_size, spatial_dims..., features).
    # TODO(shivaniagrawal): Consider making ConvAqt rank static (e.g. 2D), or
    # maybe add error checking (e.g. expect inputs to have rank N, but this
    # may already be checked by lax.conv_general_dilated).
    bias = utils.broadcast_rank(bias, inputs)
    y = y + bias
  return y
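
# To make the weight-quantization branch above concrete, here is a minimal,
# self-contained sketch of symmetric per-output-channel fake quantization.
# It illustrates the general technique only; the real
# QuantOps.create_weights_fake_quant additionally handles half_shift, scale
# shape checking, and a straight-through estimator so gradients flow through
# the non-differentiable round().
import jax.numpy as jnp


def _fake_quant_weights_sketch(kernel, prec, axis):
  """Quantize-dequantize `kernel` to a signed `prec`-bit grid (illustrative).

  Args:
    kernel: Weight tensor, e.g. shape (*spatial, in_features, out_features).
    prec: Bit width, e.g. 8.
    axis: Axes reduced when computing per-channel scales; for a conv kernel
      this is every axis except the output-feature axis, matching
      `kernel_reduction_axis` above.

  Returns:
    A float tensor of the same shape whose values lie on the quantization
    grid.
  """
  # Largest representable signed magnitude for `prec` bits, e.g. 127 for 8.
  bound = 2.0**(prec - 1) - 1
  # Per-output-channel scale; keepdims yields the (1, ..., 1, out_features)
  # shape that `expected_scale_shape` asserts above.
  max_abs = jnp.max(jnp.abs(kernel), axis=axis, keepdims=True)
  scale = bound / jnp.maximum(max_abs, 1e-8)
  # "Fake" quantization: round onto the integer grid, then rescale back to
  # floats so downstream ops still run in floating point.
  return jnp.round(kernel * scale) / scale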