Example #1
  def test_embed(self, weight_prec):
    # Since the dummy embedding matrix has a row of all zeros, we need 'epsilon'
    # to be added to it before calculating scale factors.
    quantization.DISABLE_EPSILON_IN_SCALE_FUN_FOR_TESTING = False
    rng = random.PRNGKey(0)
    x = jnp.arange(4)[None]
    dummy_embedding = jnp.broadcast_to(jnp.arange(4)[Ellipsis, None],
                                       (4, 3)).astype(jnp.float32)
    embed_module = flax_layers.EmbedAqt(
        num_embeddings=4,
        features=3,
        dtype=jnp.float32,
        hparams=flax_layers.EmbedAqt.HParams(
            weight_prec=weight_prec,
            quant_act=None,
            quant_type=QuantType.fake_quant,
            weight_half_shift=False),
        embedding_init=lambda _rng, _shape: dummy_embedding,
        train=False,
        paxis_name=None,
        quant_context=quant_config.QuantContext(update_bounds=False),
    )
    y, state = embed_module.init_with_output(rng, x)
    test_utils.assert_all_close_prec(dummy_embedding[None], y, weight_prec)

    z = embed_module.apply(
        state, jnp.ones((1, 3)), padding_mask=None, method=embed_module.attend)
    test_utils.assert_all_close_prec(3. * jnp.arange(4), z[0, Ellipsis], weight_prec)
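
The epsilon mentioned in the comment guards the per-row scale against division by zero: the dummy embedding's first row is all zeros, so its maximum absolute value is 0. A minimal sketch of that idea, assuming symmetric per-row scaling (an illustration only, not the library's QuantOps code):

import jax.numpy as jnp

embedding = jnp.broadcast_to(jnp.arange(4)[..., None], (4, 3)).astype(jnp.float32)
bound = 2.0 ** (8 - 1) - 1                                # e.g. 8-bit signed weights
row_max = jnp.max(jnp.abs(embedding), axis=-1, keepdims=True)
scale = bound / (row_max + 1e-8)                          # epsilon keeps row [0, 0, 0] finite
quantized = jnp.round(embedding * scale) / scale          # fake-quantize, then rescale back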

Example #2

  def test_float_weights_quantization(self, prec):
    # Tests that quantized and rescaled float weights are close to the
    # original weights.
    weights = jnp.array(
        fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 1))))
    rescaled_weights = QuantOps.create_weights_fake_quant(
        w=weights,
        weight_params=QuantOps.WeightParams(prec=prec, axis=None))
    test_utils.assert_all_close_prec(weights, rescaled_weights, prec=prec)
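
The idea behind this test: symmetric fake quantization rounds each weight to one of roughly 2^prec levels and rescales, so the per-element reconstruction error is bounded by half a quantization step. A rough standalone sketch (fake_quant here is a hypothetical helper, not QuantOps.create_weights_fake_quant):

import jax.numpy as jnp

def fake_quant(w, prec):
  # Symmetric per-tensor fake quantization: map to the signed integer
  # range for `prec` bits, round, then map back to floats.
  bound = 2.0 ** (prec - 1) - 1
  scale = bound / jnp.max(jnp.abs(w))
  return jnp.round(w * scale) / scale

w = jnp.linspace(0.1, 2.0, 10).reshape((10, 1))
w_q = fake_quant(w, prec=8)
print(jnp.max(jnp.abs(w - w_q)))   # bounded by 0.5 / scale, about 0.008 for 8 bits here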
Example #3
  def test_float_weights_should_give_close_output(self, weight_prec):
    inputs = random.uniform(self.rng_key, shape=(2, 3))
    model, state = self.init_model_with_1_layer(
        inputs, num_features=4, weight_prec=weight_prec)
    float_weights = jnp.linspace(-1 / 3, 1 / 3, num=12).reshape((3, 4))

    exp_output_without_quant = jnp.matmul(inputs, float_weights)
    state = state.unfreeze()
    state['params']['kernel'] = float_weights
    state = flax.core.freeze(state)
    outputs_with_quant = model.apply(state, inputs, padding_mask=None)
    onp.testing.assert_raises(AssertionError, onp.testing.assert_array_equal,
                              outputs_with_quant, exp_output_without_quant)
    test_utils.assert_all_close_prec(exp_output_without_quant,
                                     outputs_with_quant, weight_prec)
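
The two assertions together say: quantization must actually change the output (it is not bit-identical to the float matmul), yet stay within a tolerance tied to weight_prec. The same pattern with plain NumPy, using illustrative arrays in place of the model outputs:

import numpy as onp

exact = onp.array([1.0, 2.0, 3.0])
approx = exact + 1e-4   # stands in for the quantized output

# Passing assert_array_equal to assert_raises asserts the arrays are NOT identical...
onp.testing.assert_raises(
    AssertionError, onp.testing.assert_array_equal, approx, exact)
# ...while a precision-dependent tolerance check still passes.
onp.testing.assert_allclose(approx, exact, atol=1e-3)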
Example #4
  def test_group_conv(self, weight_prec=None):
    x = jnp.ones((1, 8, 8, 4))
    conv_module = flax_layers.ConvAqt(
        features=4,
        kernel_size=(3, 3),
        feature_group_count=2,
        padding='VALID',
        paxis_name='batch',
        quant_context=quant_config.QuantContext(update_bounds=False),
        train=False,
        hparams=flax_layers.ConvAqt.HParams(
            weight_prec=weight_prec,
            quant_act=None,
            quant_type=QuantType.fake_quant),
        kernel_init=initializers.ones,
        bias_init=initializers.ones,
        dtype=jnp.float32)
    y, state = conv_module.init_with_output(self.rng_key, x)
    self.assertEqual(state['params']['kernel'].shape, (3, 3, 2, 4))
    test_utils.assert_all_close_prec(y, onp.full((1, 6, 6, 4), 19.),
                                     weight_prec)
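
The expected kernel shape and output value follow from the grouping: with feature_group_count=2, the 4 input channels are split into 2 groups of 2, so the HWIO kernel is (3, 3, 2, 4), and each VALID output element sums 3 * 3 * 2 = 18 ones plus a bias of 1, giving 19. A sketch of the unquantized computation with plain jax.lax (not ConvAqt):

import jax
import jax.numpy as jnp

x = jnp.ones((1, 8, 8, 4))                 # NHWC input
kernel = jnp.ones((3, 3, 2, 4))            # HWIO: 4 in-channels / 2 groups = 2
y = jax.lax.conv_general_dilated(
    x, kernel,
    window_strides=(1, 1),
    padding='VALID',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
    feature_group_count=2)
print(y.shape)         # (1, 6, 6, 4)
print(y[0, 0, 0, 0])   # 18.0; the test's bias_init of ones adds 1, giving 19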