def test_floor_and_clip_to_unsigned_int(self, prec):
  """floor_and_clip_to_unsigned_int must return integers in [0, 2**prec - 1].

  Feeds 1024 float32 samples drawn from [0, 32) through the primitive (with
  half_shift disabled), then checks the bounds and the integrality of the
  output.

  Args:
    prec: Bit precision forwarded to the primitive; the upper bound checked is
      2**prec - 1.
  """
  # NOTE(review): the global NumPy RNG is used unseeded here; presumably
  # seeding (if any) happens elsewhere in the test class -- confirm.
  unit_samples = onp.random.uniform(0, 1.0, size=(1024, 1))
  inputs = jnp.array(fp32(2.0**5 * unit_samples))
  clipped = primitives.floor_and_clip_to_unsigned_int(
      inputs, prec=prec, dtype=inputs.dtype, half_shift=False)

  # NOTE(review): inputs are all >= 0, so this lower-bound check never
  # exercises the clipping of negative values.
  self.assertGreaterEqual(onp.min(clipped), 0.0)
  self.assertLessEqual(onp.max(clipped), fp32(2**prec - 1))
  # Integrality: every element must already equal its rounded value.
  onp.testing.assert_allclose(clipped, onp.around(clipped))
def assert_is_integer_in_range(self, x, *, prec, distribution):
  """Assert that `x` is unchanged by the matching integer quantizer.

  `x` passes when applying the quantizer for `distribution` (with half_shift
  disabled) yields an array exactly equal to `x`. That means `x` already holds
  integer values inside the representable range for `prec` bits.

  Args:
    x: Array-like with a `dtype` attribute.
    prec: Bit precision forwarded to the quantizer.
    distribution: 'symmetric' selects round_and_clip_to_signed_int; 'positive'
      selects floor_and_clip_to_unsigned_int.

  Raises:
    ValueError: If `distribution` is neither 'symmetric' nor 'positive'.
    AssertionError: If quantizing `x` changes any element.
  """
  if distribution == 'symmetric':
    quantize = primitives.round_and_clip_to_signed_int
  elif distribution == 'positive':
    quantize = primitives.floor_and_clip_to_unsigned_int
  else:
    raise ValueError(f'Invalid distribution {distribution}')

  quantized = quantize(x, prec=prec, dtype=x.dtype, half_shift=False)
  err_msg = (
      f'Array cannot be losslessly cast to integer with precision {prec} '
      f'and {distribution} distribution.')
  onp.testing.assert_array_equal(x, quantized, err_msg=err_msg)
def grad_fn(x):
  """Return the sum of the 8-bit unsigned floor-and-clip quantization of `x`.

  jnp.sum without an axis reduces every element, so the result is a scalar.
  Presumably the enclosing test differentiates this function -- confirm at
  the call site.
  """
  quantized = primitives.floor_and_clip_to_unsigned_int(
      x, prec=8, dtype=x.dtype, half_shift=False)
  return jnp.sum(quantized)