Ejemplo n.º 1
0
 def test_positive_activation_quantization_clips_outside_bounds(self, prec):
   """Inputs outside [0, bounds] saturate at the ends of the quantized range.

   With bounds=6.0, an input below 0 quantizes to 0 and an input above the
   bound quantizes to the largest code, 2**prec - 1. Dequantizing that largest
   code is expected to give (2**prec - 1) * (6.0 / 2**prec), which equals
   6.0 - 6.0 / 2**prec. That is strictly below the bound, so out-of-range
   values do NOT come back as exactly 6.0.
   """
   relu6 = QuantOps.create_positive(bounds=6.0, prec=prec)
   # -0.5 is below the lower bound (0) and 6.2 is above the upper bound (6.0).
   # 3.141 lies inside [0, 6.0]; it is quantized along with the others but is
   # not asserted on below.
   activation = jnp.array(fp32([-0.5, 6.2, 3.141]))
   quantized_activations = relu6.to_quantized(activation, dtype=SCALE_DTYPE)
   # In the quantized domain, clipping is to the code range [0, 2**prec - 1].
   onp.testing.assert_array_equal(quantized_activations[0:2],
                                  [0.0, 2**prec - 1])
   activations = relu6.from_quantized(quantized_activations, dtype=jnp.float32)
   # The largest code dequantizes to 6.0 - 6.0 / 2**prec, not to 6.0.
   max_clipped_val = (2**prec - 1) * (6.0 / 2**prec)
   # NOTE(review): exact float equality assumes from_quantized reproduces this
   # value bit-for-bit in float32; switch to assert_allclose if it is flaky.
   onp.testing.assert_array_equal(activations[0:2], [0.0, max_clipped_val])
    def test_per_feature_dim_scale_invariance_pos_activation_quantization(
            self, prec):
        """Per-column power-of-2 rescaling commutes with positive fake-quant.

        Each column of a random (3, 4) activation matrix is multiplied by a
        distinct power of 2. The upper bound for that column is multiplied by
        the same factor. Fake-quantizing the rescaled input with the rescaled
        bound must then equal the fake-quantized original input rescaled by
        the same per-column factors.
        """
        raw = random.uniform(random.PRNGKey(0), (3, 4))
        # Column factors [1, 2, 4, 8] as a (1, 4) row that broadcasts over rows.
        col_scale = (2**jnp.arange(4))[jnp.newaxis, :]
        base_bound = 6.0 * jnp.ones((3, 4), jnp.float32)

        base_ops = QuantOps.create_positive(bounds=base_bound, prec=prec)
        scaled_ops = QuantOps.create_positive(
            bounds=base_bound * col_scale, prec=prec)

        # Quantize-then-scale on one side, scale-then-quantize on the other.
        expected = base_ops.fake_quant(
            raw, quantized_type=SCALE_DTYPE) * col_scale
        actual = scaled_ops.fake_quant(
            raw * col_scale, quantized_type=SCALE_DTYPE)
        onp.testing.assert_array_equal(expected, actual)
 def test_attributes_create_positive(self, bounds, prec):
     """create_positive stores scale 2**prec / bounds, non-symmetric, and prec."""
     bounds_arr = jnp.array(bounds)
     positive_ops = QuantOps.create_positive(bounds=bounds_arr, prec=prec)
     expected_scale = 2**prec / bounds_arr
     onp.testing.assert_array_equal(positive_ops._scale, expected_scale)
     self.assertEqual(positive_ops._symmetric, False)
     self.assertEqual(positive_ops._prec, prec)