def test_per_channel_average(self):
  """Stats should be different per channel."""
  stats = Stats.stats_initializer(shape=())
  for i in range(-1, 4):
    x = i * jnp.array([[1., 2.], [2., 4.], [3., 6.]])
    stats = Stats.create_updated_stats(stats, x, axis=(0,))
  self.assertEqual(stats.n, 5)
  # For i in range(-1, 4), the ith array is
  # [[i,     2 * i],
  #  [i * 2, 2 * (i * 2)],
  #  [i * 3, 2 * (i * 3)]]
  exp_mean_ch0 = (-1 + 0 + 1 + 2 + 3) * (1 + 2 + 3) / 15.
  exp_mean_ch1 = (-2 + 0 + 2 + 4 + 6) * (1 + 2 + 3) / 15.
  exp_mean = jnp.array([[exp_mean_ch0, exp_mean_ch1]])
  onp.testing.assert_allclose(stats.mean, exp_mean)

  exp_mean_abs_ch0 = (1 + 0 + 1 + 2 + 3) * (1 + 2 + 3) / 15.
  exp_mean_abs_ch1 = (2 + 0 + 2 + 4 + 6) * (1 + 2 + 3) / 15.
  exp_mean_abs = jnp.array([[exp_mean_abs_ch0, exp_mean_abs_ch1]])
  onp.testing.assert_allclose(stats.mean_abs, exp_mean_abs)

  exp_mean_sq_ch0 = (
      (-1)**2 + 0**2 + 1**2 + 2**2 + 3**2) * (1**2 + 2**2 + 3**2) / 15.
  exp_mean_sq_ch1 = (
      (-2)**2 + 0**2 + 2**2 + 4**2 + 6**2) * (1**2 + 2**2 + 3**2) / 15.
  exp_mean_sq = jnp.array([[exp_mean_sq_ch0, exp_mean_sq_ch1]])
  onp.testing.assert_allclose(stats.mean_sq, exp_mean_sq)
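# A minimal sketch, assuming Stats keeps a uniformly weighted running average
# of per-batch means (consistent with stats.n == 5 above); it reproduces
# exp_mean with plain jnp. `base` mirrors the test input and is purely
# illustrative.
import jax.numpy as jnp

base = jnp.array([[1., 2.], [2., 4.], [3., 6.]])
batch_means = jnp.stack(
    [jnp.mean(i * base, axis=0, keepdims=True) for i in range(-1, 4)])
running_mean = jnp.mean(batch_means, axis=0)  # shape (1, 2)
# running_mean == [[2., 4.]], matching exp_mean in the test above.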
def test_exclude_zeros(self, x, mask, axis, exp_mean):
  """Stats should be different when excluding zeros."""
  stats = Stats.stats_initializer(shape=())
  stats = Stats.create_updated_stats(
      stats, x, axis=axis, mask=mask, exclude_zeros=True)
  onp.testing.assert_allclose(stats.mean, exp_mean)
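# A minimal sketch of the exclude_zeros semantics this test relies on,
# assuming zero-valued entries are dropped from the reduction as if they were
# masked out. `masked_mean` is a hypothetical helper, not the library code.
import jax.numpy as jnp

def masked_mean(x, mask, axis, exclude_zeros=False):
  if exclude_zeros:
    mask = jnp.logical_and(mask, x != 0)  # treat zeros as padding
  total = jnp.sum(jnp.where(mask, x, 0.), axis=axis)
  count = jnp.sum(mask, axis=axis)
  return total / jnp.maximum(count, 1)  # guard against all-masked slices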
def test_update_stats_with_different_axes(self, axis, sample_shape,
                                          stats_shape):
  """Stats should have stats_shape regardless of the reduction axes."""
  stats = Stats.stats_initializer(shape=stats_shape)
  for i in range(-1, 4):
    x = i * jnp.ones(sample_shape)
    stats = Stats.create_updated_stats(stats, x, axis=axis)
  self.assertEqual(stats.n, 5)
  exp_mean = (-1 + 0 + 1 + 2 + 3) / 5. * jnp.ones(stats_shape)
  onp.testing.assert_allclose(stats.mean, exp_mean)
  exp_mean_abs = (1 + 0 + 1 + 2 + 3) / 5. * jnp.ones(stats_shape)
  onp.testing.assert_allclose(stats.mean_abs, exp_mean_abs)
  exp_mean_sq = (
      (-1)**2 + 0**2 + 1**2 + 2**2 + 3**2) / 5. * jnp.ones(stats_shape)
  onp.testing.assert_allclose(stats.mean_sq, exp_mean_sq)
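# For intuition on how `axis` and `stats_shape` relate in this test (an
# assumption based on the parameters it takes): reducing with keepdims=True
# yields the per-statistic shape directly.
import jax.numpy as jnp

x = jnp.ones((4, 3, 2))
print(jnp.mean(x, axis=(0, 1), keepdims=True).shape)  # (1, 1, 2)
print(jnp.mean(x, axis=None, keepdims=True).shape)    # (1, 1, 1)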
def test_get_state_dict_summary(self, keys, expected_summary):
  state_dict = {
      'decoder': {
          'attention': {
              'dense_out': {
                  'bounds': jnp.array([[1., 2.], [2., 4.], [3., 6.]]),
                  'min_per_ch': jnp.array([-6., -5., -4.]),
                  'max_per_ch': jnp.array([20., 21., 22.]),
                  'stats':
                      Stats(
                          n=1,
                          mean=jnp.ones(()),
                          mean_abs=jnp.ones(()),
                          mean_sq=jnp.ones(()),
                          mean_batch_maximum=jnp.ones(()),
                          mean_batch_minimum=jnp.ones(()))
              }
          },
          'mlp': {
              'dense_1': {
                  'stats':
                      Stats(
                          n=1,
                          mean=jnp.ones(()),
                          mean_abs=jnp.ones(()),
                          mean_sq=jnp.ones(()),
                          mean_batch_maximum=jnp.ones(()),
                          mean_batch_minimum=jnp.ones(()))
              }
          },
      }
  }
  summary = summary_utils.get_state_dict_summary(state_dict, keys=keys)
  self.assertNestedDictEqual(summary, expected_summary)
def test_masking(self):
  # We will simulate a situation where we have two batches with two tokens
  # each, and the second token of the second batch is padding. The channel
  # dimension is two.
  stats = Stats.stats_initializer(shape=(1, 1, 2))
  # The shape of 'x' is [batch index, token index, channel index].
  x = jnp.reshape(jnp.arange(8), (2, 2, 2)).astype(jnp.float32)
  token_mask = jnp.array([[True, True], [True, False]])
  # Broadcast the mask over the channel dimension.
  mask = token_mask[Ellipsis, None]
  stats = Stats.create_updated_stats(stats, x, axis=(0, 1), mask=mask)
  exp_mean_ch0 = (0 + 2 + 4) / 3
  exp_mean_ch1 = (1 + 3 + 5) / 3
  exp_mean = jnp.array([[[exp_mean_ch0, exp_mean_ch1]]])
  onp.testing.assert_allclose(stats.mean, exp_mean)
  onp.testing.assert_allclose(stats.mean_abs, exp_mean)
  exp_mean_sq_ch0 = (0**2 + 2**2 + 4**2) / 3
  exp_mean_sq_ch1 = (1**2 + 3**2 + 5**2) / 3
  exp_mean_sq = jnp.array([[[exp_mean_sq_ch0, exp_mean_sq_ch1]]])
  onp.testing.assert_allclose(stats.mean_sq, exp_mean_sq)
  exp_max = [[[4, 5]]]
  onp.testing.assert_allclose(stats.mean_batch_maximum, exp_max)
  exp_min = [[[0, 1]]]
  onp.testing.assert_allclose(stats.mean_batch_minimum, exp_min)

  # Now do the same, but with axis=None.
  stats = Stats.stats_initializer(shape=())
  stats = Stats.create_updated_stats(stats, x, axis=None, mask=mask)
  exp_mean = (0 + 1 + 2 + 3 + 4 + 5) / 6
  onp.testing.assert_allclose(stats.mean, [[[exp_mean]]])
  exp_mean_sq = (0**2 + 1**2 + 2**2 + 3**2 + 4**2 + 5**2) / 6
  onp.testing.assert_allclose(stats.mean_sq, [[[exp_mean_sq]]])
  onp.testing.assert_allclose(stats.mean_batch_maximum, [[[5]]])
  onp.testing.assert_allclose(stats.mean_batch_minimum, [[[0]]])

  # Also try with the reduction axis equal to the broadcasting axis. In this
  # case, we expect a 0 when taking the mean over an array slice that
  # consists solely of masked elements, since only masked elements will not
  # update the initial value of 0.
  stats = Stats.stats_initializer(shape=(2, 2, 1))
  stats = Stats.create_updated_stats(stats, x, axis=(2,), mask=mask)
  exp_mean = [[[(0 + 1) / 2], [(2 + 3) / 2]], [[(4 + 5) / 2], [0]]]
  onp.testing.assert_allclose(stats.mean, exp_mean)
  exp_mean_sq = [[[(0**2 + 1**2) / 2], [(2**2 + 3**2) / 2]],
                 [[(4**2 + 5**2) / 2], [0]]]
  onp.testing.assert_allclose(stats.mean_sq, exp_mean_sq)
  exp_max = [[[1], [3]], [[5], [0]]]
  onp.testing.assert_allclose(stats.mean_batch_maximum, exp_max)
  exp_min = [[[0], [2]], [[4], [0]]]
  onp.testing.assert_allclose(stats.mean_batch_minimum, exp_min)
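# A plain-jnp sketch of the first masked reduction above (illustrative, not
# the library implementation): per-channel means over unmasked entries only.
import jax.numpy as jnp

x = jnp.reshape(jnp.arange(8), (2, 2, 2)).astype(jnp.float32)
mask = jnp.broadcast_to(
    jnp.array([[True, True], [True, False]])[..., None], x.shape)
total = jnp.sum(jnp.where(mask, x, 0.), axis=(0, 1), keepdims=True)
count = jnp.sum(mask, axis=(0, 1), keepdims=True)
print(total / count)  # [[[2., 3.]]], matching exp_mean above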
def test_create_empty_stats_using_initializer(self):
  stats = Stats.stats_initializer(shape=())
  self.assertEqual(stats.n, 0)
  self.assertEqual(stats.mean, 0.)
  self.assertEqual(stats.mean_abs, 0.)
  self.assertEqual(stats.mean_sq, 0.)
def __call__(
    self,
    x,
    *,
    bounds_params,
):
  """Compute the input batch statistics and return the resulting bounds.

  Args:
    x: the input to get bounds from using statistics.
    bounds_params: parameters to compute the input's statistics and bounds.

  Returns:
    Bound value (broadcastable to the shape of the input).
  """
  if bounds_params.mask is not None:
    shape_utils.assert_shapes_compatible(x.shape, bounds_params.mask.shape)
  x = jnp.asarray(x, jnp.float32)
  hyper = self.hyper
  is_initializing = not self.has_variable('get_bounds', 'stats')

  if hyper.granularity == quant_config.QuantGranularity.per_tensor:
    # Equivalently, this could be written as
    # quant_axis = tuple(range(x.ndim))
    quant_axis = None
    stats_shape = (1,) * len(x.shape)
  elif hyper.granularity == quant_config.QuantGranularity.per_channel:
    # Quantize by aggregating activation statistics across all dimensions of
    # the activation tensor EXCEPT the last dimension, which we interpret as
    # the channel dimension. For example, in a transformer context, x might
    # have a shape corresponding to [example, token, channel], in which case
    # this aggregates activation statistics separately for each feature,
    # where for each feature it aggregates over all unmasked tokens in all
    # examples.
    quant_axis = tuple(range(x.ndim - 1))
    stats_shape = (1,) * (x.ndim - 1) + (x.shape[-1],)
  else:
    raise ValueError(f'Unknown granularity {hyper.granularity}')

  stats_state = self.variable('get_bounds', 'stats', Stats.stats_initializer,
                              stats_shape)

  def bound_initializer(shape):
    return hyper.initial_bound * jnp.ones(shape)

  bounds = self.variable('get_bounds', 'bounds', bound_initializer,
                         stats_shape)

  if bounds_params.update_stats and not is_initializing:
    stats_state.value = Stats.create_updated_stats(
        stats_state.value,
        x,
        mask=bounds_params.mask,
        axis=quant_axis,
        paxis_name=bounds_params.paxis_name,
        alpha=hyper.ema_coeff,
        exclude_zeros=hyper.exclude_zeros)

  if bounds_params.update_bounds and not is_initializing:
    bounds.value = self._stats_to_bounds(stats_state.value)
    if hyper.reset_stats:
      stats_state.value = Stats.stats_initializer(stats_shape)
  return bounds.value
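# A standalone sketch of the granularity-to-reduction mapping above, assuming
# an activation tensor shaped [example, token, channel]; it shows why the
# resulting statistics broadcast against `x`.
import jax.numpy as jnp

x = jnp.ones((8, 16, 512))  # [example, token, channel]

# per_tensor: reduce over every axis; keep one broadcastable scalar slot.
per_tensor_axis = None
per_tensor_stats_shape = (1,) * x.ndim  # (1, 1, 1)

# per_channel: reduce over all axes except the trailing channel axis.
per_channel_axis = tuple(range(x.ndim - 1))  # (0, 1)
per_channel_stats_shape = (1,) * (x.ndim - 1) + (x.shape[-1],)  # (1, 1, 512)

assert jnp.mean(
    x, axis=per_channel_axis, keepdims=True).shape == per_channel_stats_shape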