def test_embed(self, weight_prec):
  # Since the dummy embedding matrix has a row of all zeros, we need
  # 'epsilon' to be added to it before calculating scale factors.
  quantization.DISABLE_EPSILON_IN_SCALE_FUN_FOR_TESTING = False
  rng = random.PRNGKey(0)
  x = jnp.arange(4)[None]
  dummy_embedding = jnp.broadcast_to(jnp.arange(4)[Ellipsis, None],
                                     (4, 3)).astype(jnp.float32)
  embed_module = flax_layers.EmbedAqt(
      num_embeddings=4,
      features=3,
      dtype=jnp.float32,
      hparams=flax_layers.EmbedAqt.HParams(
          weight_prec=weight_prec,
          quant_act=None,
          quant_type=QuantType.fake_quant,
          weight_half_shift=False),
      embedding_init=lambda _rng, _shape: dummy_embedding,
      train=False,
      paxis_name=None,
      quant_context=quant_config.QuantContext(update_bounds=False),
  )
  y, state = embed_module.init_with_output(rng, x)
  test_utils.assert_all_close_prec(dummy_embedding[None], y, weight_prec)
  z = embed_module.apply(
      state,
      jnp.ones((1, 3)),
      padding_mask=None,
      method=embed_module.attend)
  test_utils.assert_all_close_prec(3. * jnp.arange(4), z[0, Ellipsis],
                                   weight_prec)
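# The comment in test_embed refers to the scale computation inside
# quantization.py. A minimal sketch of the idea (assumed semantics; the name
# `_sketch_row_scale` and the epsilon value are illustrative, not part of the
# library): with a row of all zeros, max(abs(row)) is 0 and a symmetric scale
# would divide by zero, so a small epsilon is added to the maximum first.
def _sketch_row_scale(row, prec, epsilon=2.**-20):
  # Largest representable signed integer at `prec` bits, divided by the
  # epsilon-padded row maximum.
  return (2.**(prec - 1) - 1) / (jnp.max(jnp.abs(row)) + epsilon)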
def test_float_weights_quantization(self, prec):
  # Tests that quantized and rescaled float weights are close to original
  # weights.
  weights = jnp.array(fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 1))))
  rescaled_weights = QuantOps.create_weights_fake_quant(
      w=weights, weight_params=QuantOps.WeightParams(prec=prec, axis=None))
  test_utils.assert_all_close_prec(weights, rescaled_weights, prec=prec)
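# For orientation, a minimal sketch of symmetric per-tensor fake
# quantization; this is an assumption about what create_weights_fake_quant
# does for axis=None, and the rounding and clipping details in QuantOps may
# differ. The quantize-then-rescale round-trip error shrinks as `prec`
# grows, which is the prec-dependent tolerance assert_all_close_prec checks.
def _sketch_fake_quant(w, prec):
  scale = (2.**(prec - 1) - 1) / jnp.max(jnp.abs(w))  # per-tensor scale
  return jnp.round(w * scale) / scale  # quantize, then rescale back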
def test_float_weights_should_give_close_output(self, weight_prec):
  inputs = random.uniform(self.rng_key, shape=(2, 3))
  model, state = self.init_model_with_1_layer(
      inputs, num_features=4, weight_prec=weight_prec)
  float_weights = jnp.linspace(-1 / 3, 1 / 3, num=12).reshape((3, 4))
  exp_output_without_quant = jnp.matmul(inputs, float_weights)
  state = state.unfreeze()
  state['params']['kernel'] = float_weights
  state = flax.core.freeze(state)
  outputs_with_quant = model.apply(state, inputs, padding_mask=None)
  onp.testing.assert_raises(AssertionError, onp.testing.assert_array_equal,
                            outputs_with_quant, exp_output_without_quant)
  test_utils.assert_all_close_prec(exp_output_without_quant,
                                   outputs_with_quant, weight_prec)
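# The two assertions above pin down both sides of fake quantization: the
# quantized output must not equal the float output exactly, yet must match
# it within a prec-dependent tolerance. A hypothetical model of the
# quantized forward pass, assuming per-output-channel (axis=0) symmetric
# fake quantization of the kernel; `_sketch_quantized_matmul` is
# illustrative, not the layer's actual implementation.
def _sketch_quantized_matmul(inputs, kernel, prec):
  scale = (2.**(prec - 1) - 1) / jnp.max(jnp.abs(kernel), axis=0)
  kernel_q = jnp.round(kernel * scale) / scale  # fake-quantized kernel
  return jnp.matmul(inputs, kernel_q)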
def test_group_conv(self, weight_prec=None):
  x = jnp.ones((1, 8, 8, 4))
  conv_module = flax_layers.ConvAqt(
      features=4,
      kernel_size=(3, 3),
      feature_group_count=2,
      padding='VALID',
      paxis_name='batch',
      quant_context=quant_config.QuantContext(update_bounds=False),
      train=False,
      hparams=flax_layers.ConvAqt.HParams(
          weight_prec=weight_prec,
          quant_act=None,
          quant_type=QuantType.fake_quant),
      kernel_init=initializers.ones,
      bias_init=initializers.ones,
      dtype=jnp.float32)
  y, state = conv_module.init_with_output(self.rng_key, x)
  self.assertEqual(state['params']['kernel'].shape, (3, 3, 2, 4))
  test_utils.assert_all_close_prec(y, onp.full((1, 6, 6, 4), 19.),
                                   weight_prec)
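# Sanity check for the constants in test_group_conv: feature_group_count=2
# splits the 4 input channels into 2 groups, so the kernel holds 4 / 2 = 2
# input channels per output feature, giving shape (3, 3, 2, 4); each output
# element of the all-ones convolution is 3 * 3 * 2 = 18 plus the
# ones-initialized bias, hence 19. A sketch of the same grouped convolution
# with raw lax (assumes `import jax` at module scope; the helper name is
# illustrative, not part of the test):
def _sketch_group_conv(x):
  kernel = jnp.ones((3, 3, 2, 4))  # (H, W, in_channels // groups, features)
  y = jax.lax.conv_general_dilated(
      x, kernel, window_strides=(1, 1), padding='VALID',
      dimension_numbers=('NHWC', 'HWIO', 'NHWC'), feature_group_count=2)
  return y + 1.  # ones-initialized bias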