def test_scale_invariance_signed_activation_quantization(self, prec):
  # Scaling the activations by a power of 2, and the bounds by the same
  # factor, should scale the quantized output by that factor.
  activations = random.uniform(random.PRNGKey(0), (10, 1))
  act_scale = 8.
  scaled_activations = activations * act_scale

  bounds = 6.

  activations = QuantOps.create_inputs_fake_quant(
      inputs=activations,
      get_bounds_params=GetBounds.Params(
          update_stats=False, update_bounds=False),
      hparams=QuantOps.ActHParams(
          input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
          bounds=bounds,
          prec=prec))

  scaled_activations = QuantOps.create_inputs_fake_quant(
      inputs=scaled_activations,
      get_bounds_params=GetBounds.Params(
          update_stats=False, update_bounds=False),
      hparams=QuantOps.ActHParams(
          input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
          bounds=bounds * act_scale,
          prec=prec))
  onp.testing.assert_array_equal(activations * act_scale, scaled_activations)
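# A minimal sketch (not part of the test suite) of why the invariance above
# holds, assuming a symmetric fake quantizer of the form
# round(x * scale) / scale with scale = (2**(prec - 1) - 1) / bounds; the
# actual formula lives in QuantOps and may differ in details. Because the
# activation/bounds factor is a power of two, every multiplication and
# division by it is exact in floating point, so both paths round identically.
def _toy_scale_invariance_sketch(prec=8):
  x = onp.random.uniform(-1.0, 1.0, size=(10, 1)).astype(onp.float32)
  act_scale = 8.  # power of two, so scaling is exact in float32
  bounds = 6.

  def toy_fake_quant(x, bounds):
    # Hypothetical symmetric fake quantization; stands in for QuantOps.
    scale = (2**(prec - 1) - 1) / bounds
    return onp.round(x * scale) / scale

  onp.testing.assert_array_equal(
      toy_fake_quant(x, bounds) * act_scale,
      toy_fake_quant(x * act_scale, bounds * act_scale))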
def test_int_symmetric_act_quantization(self, prec):
  # Integer activations within bounds quantize exactly when
  # abs(bounds) == 2^(prec - 1) - 1.
  bounds = 2**(prec - 1) - 1
  activations = random.randint(random.PRNGKey(0), (10, 1), -bounds, bounds)
  rescaled_activations = QuantOps.create_inputs_fake_quant(
      inputs=activations,
      get_bounds_params=GetBounds.Params(
          update_stats=False, update_bounds=False),
      hparams=QuantOps.ActHParams(
          input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
          bounds=bounds,
          prec=prec))
  onp.testing.assert_array_equal(activations, rescaled_activations)
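# A worked instance (an illustration, not part of the tests) of why the
# symmetric case above is exact, assuming the same scale formula as the
# sketch following test_scale_invariance_signed_activation_quantization:
# when bounds == 2**(prec - 1) - 1 the scale works out to exactly 1.0, so
# rounding leaves integer activations untouched.
def _toy_symmetric_identity_sketch(prec=4):
  bounds = 2**(prec - 1) - 1  # 7 for prec == 4
  scale = (2**(prec - 1) - 1) / bounds  # assumed formula; exactly 1.0 here
  assert scale == 1.0
  k = onp.arange(-bounds, bounds)
  onp.testing.assert_array_equal(onp.round(k * scale) / scale, k)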
def test_int_positive_act_quantization(self, prec):
  # Integer activations in [0, upper_bound) quantize exactly when
  # upper_bound == 2^i for some i < prec.
  upper_bound = 2**(prec - 3)
  activations = random.randint(random.PRNGKey(0), (10, 1), 0, upper_bound)
  rescaled_activations = QuantOps.create_inputs_fake_quant(
      inputs=activations,
      get_bounds_params=GetBounds.Params(
          update_stats=False, update_bounds=False),
      hparams=QuantOps.ActHParams(
          input_distribution=QuantOps.ActHParams.InputDistribution.positive,
          bounds=upper_bound,
          prec=prec))
  onp.testing.assert_array_equal(activations, rescaled_activations)
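# A toy check (hypothetical, not part of the tests) of the power-of-two
# intuition behind the positive-range case: with upper_bound == 2**i the
# effective scale is itself a power of two, so integers below the bound
# survive the round trip exactly. The real positive-range scale formula is
# defined inside QuantOps; a power-of-two scale is assumed here purely for
# illustration.
def _toy_positive_identity_sketch(prec=8):
  upper_bound = 2**(prec - 3)    # 32
  scale = 2**prec / upper_bound  # 8.0 -- assumed power-of-two scale
  k = onp.arange(upper_bound)
  onp.testing.assert_array_equal(onp.round(k * scale) / scale, k)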
def test_inputs_scale_shape_is_expected(self):
  # Inputs quantization should produce a scalar (per-tensor) scale.
  inputs = jnp.array(fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 4))))
  bounds = 6.0
  expected_inputs_scale_shape = ()
  _ = QuantOps.create_inputs_fake_quant(
      inputs=inputs,
      hparams=QuantOps.ActHParams(
          input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
          bounds=bounds,
          prec=8.0),
      get_bounds_params=GetBounds.Params(
          update_stats=False,
          update_bounds=False,
          expected_bounds_shape=expected_inputs_scale_shape))
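# A small illustration (not part of the tests) of what the empty scale shape
# above means: a shape-() scale applies one value to the whole tensor
# (per-tensor quantization), whereas a per-feature scheme would carry one
# scale per trailing feature, e.g. shape (1, 4) for (10, 4) inputs. The names
# below are illustrative only.
def _scale_shape_sketch():
  x = onp.zeros((10, 4), dtype=onp.float32)
  per_tensor_scale = onp.float32(1.0)   # shape () -- broadcasts everywhere
  per_feature_scale = onp.ones((1, 4))  # one scale per feature column
  assert (x * per_tensor_scale).shape == (10, 4)
  assert (x * per_feature_scale).shape == (10, 4)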
def __call__(self, inputs):
  """Applies a convolution to the inputs with optional quantization.

  Args:
    inputs: input data with dimensions (batch, spatial_dims..., features).

  Returns:
    The convolved data.
  """
  hparams = self.hparams
  if hparams.weight_prec is not None and hparams.weight_prec > 8:
    raise NotImplementedError(
        'If you want to use more than 8 bits for quantization, please revisit '
        'jax.lax.Precision.DEFAULT to determine whether it is still '
        'sufficient.')
  jax_precision = jax.lax.Precision.DEFAULT

  if self.strides is None:
    strides = (1,) * (inputs.ndim - 2)
  else:
    strides = self.strides

  in_features = inputs.shape[-1]
  assert in_features % self.feature_group_count == 0
  kernel_shape = self.kernel_size + (
      in_features // self.feature_group_count, self.features)
  kernel = self.param('kernel', self.kernel_init, kernel_shape)

  inputs = jnp.asarray(inputs, self.dtype)
  kernel = jnp.asarray(kernel, self.dtype)

  # Activation quantization
  if hparams.quant_act is not None:
    inputs = QuantOps.create_inputs_fake_quant(
        inputs=inputs,
        hparams=hparams.quant_act,
        get_bounds_params=get_bounds.GetBounds.Params(
            update_bounds=self.quant_context.update_bounds,
            update_stats=self.train,
            paxis_name=self.paxis_name))

  # Weight quantization
  if hparams.weight_prec is not None:
    kernel_reduction_axis = tuple(range(kernel.ndim - 1))
    expected_scale_shape = (1,) * (kernel.ndim - 1) + (self.features,)
    assert hparams.quant_type == QuantType.fake_quant, (
        'We only support the fake_quant style of AQT for ConvAqt.')
    quantized_type = hparams.quant_type.to_jax_type()
    kernel = QuantOps.create_weights_fake_quant(
        kernel,
        weight_params=QuantOps.WeightParams(
            prec=hparams.weight_prec,
            half_shift=hparams.weight_half_shift,
            axis=kernel_reduction_axis,
            expected_scale_shape=expected_scale_shape),
        quantized_type=quantized_type)

  # Convolution
  dimension_numbers = flax.nn.linear._conv_dimension_numbers(inputs.shape)  # pylint: disable=protected-access
  metadata_context = contextlib.suppress()
  # Use metadata context to annotate op metadata with quantization info
  act_prec = None if hparams.quant_act is None else hparams.quant_act.prec

  if flags.FLAGS.metadata_enabled:
    metadata_context = compute_cost_utils.ConvMetadataMonkeyPatch(
        weight_prec=hparams.weight_prec, act_prec=act_prec)
  with metadata_context:
    y = lax.conv_general_dilated(
        inputs,
        kernel,
        strides,
        self.padding,
        lhs_dilation=self.input_dilation,
        rhs_dilation=self.kernel_dilation,
        dimension_numbers=dimension_numbers,
        feature_group_count=self.feature_group_count,
        precision=jax_precision)
  # TODO(shivaniagrawal): create quantized conv general dilated.

  # Bias
  if self.use_bias:
    bias = self.param('bias', self.bias_init, (self.features,))
    bias = jnp.asarray(bias, self.dtype)
    # The inputs can have an arbitrary number of spatial dims, so we broadcast
    # the bias to match: (batch_size, spatial_dims..., features).
    # TODO(shivaniagrawal): Consider making ConvAqt rank static (e.g. 2D), or
    # add error checking (e.g. expect inputs to have rank N; this may already
    # be checked by lax.conv_general_dilated).
    bias = utils.broadcast_rank(bias, inputs)
    y = y + bias
  return y
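# A worked shape example (illustrative, not part of the module) for the
# kernel and per-output-channel weight-scale shapes computed in __call__
# above, assuming NHWC inputs, a 3x3 kernel, 16 input features, 32 output
# features, and feature_group_count == 1.
def _conv_kernel_shape_sketch():
  kernel_size = (3, 3)
  in_features, features, feature_group_count = 16, 32, 1
  kernel_shape = kernel_size + (
      in_features // feature_group_count, features)
  assert kernel_shape == (3, 3, 16, 32)
  # One scale per output channel, broadcastable against the kernel:
  expected_scale_shape = (1,) * (len(kernel_shape) - 1) + (features,)
  assert expected_scale_shape == (1, 1, 1, 32)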