def test_attributes_create_symmetric(self, bounds, prec):
     """Checks the attributes that QuantOps.create_symmetric sets.

     The expected scale is (2**(prec - 1) - 1) / bounds, so that
     bounds * scale equals the largest signed level for `prec` bits.
     The resulting ops must also be flagged symmetric and keep `prec`.
     """
     bounds_arr = jnp.array(bounds)
     max_signed_level = 2**(prec - 1) - 1
     quant_ops = QuantOps.create_symmetric(bounds=bounds_arr, prec=prec)

     expected_scale = max_signed_level / bounds_arr
     onp.testing.assert_array_equal(quant_ops._scale, expected_scale)
     self.assertEqual(quant_ops._symmetric, True)
     self.assertEqual(quant_ops._prec, prec)
 def test_per_feature_dim_unsigned_activation_quantization_clips_outside_bounds(
         self, prec):
     """Checks that per-feature bounds clip out-of-range activations.

     Despite "unsigned" in the name, this builds the quantizer with
     `create_symmetric`, and the expected quantized values include negatives.
     Column j of the activations is bounded by bounds[0, j]. Values below
     -bound saturate at -(2**(prec - 1) - 1), values above +bound saturate at
     +(2**(prec - 1) - 1), and dequantizing those levels gives back exactly
     -bound / +bound.
     """
     feature_bounds = jnp.array([[6.0, 8.0]])
     act_quant = QuantOps.create_symmetric(bounds=feature_bounds, prec=prec)

     # Every non-zero entry lies outside its column's [-bound, bound] range.
     raw_activations = jnp.array(fp32([[-7, -8.9], [6.2, 9.4], [0, 0.]]))
     quantized = act_quant.to_quantized(raw_activations, dtype=SCALE_DTYPE)

     top_level = 2**(prec - 1.0) - 1.0
     per_row_levels = jnp.array([[-top_level], [top_level], [0.0]])
     # Broadcast the (3, 1) column of levels across both feature columns.
     expected_quantized = per_row_levels * jnp.array([[1., 1.]])
     onp.testing.assert_array_equal(quantized, expected_quantized)

     dequantized = act_quant.from_quantized(quantized, dtype=jnp.float32)
     onp.testing.assert_array_equal(dequantized,
                                    [[-6.0, -8.0], [6.0, 8.], [0, 0.]])