Example #1
0
  def test_masking(self):
    """Masked-out elements must not contribute to the statistics."""
    # Simulate two batches with two tokens each, where the second token of
    # the second batch is padding. The channel dimension is two.
    # 'x' is indexed as [batch index, token index, channel index].
    x = jnp.reshape(jnp.arange(8), (2, 2, 2)).astype(jnp.float32)
    token_mask = jnp.array([[True, True], [True, False]])
    # The trailing singleton axis broadcasts the mask over the channels.
    mask = token_mask[Ellipsis, None]

    # Case 1: reduce over the batch and token axes, one statistic per channel.
    per_channel = Stats.create_updated_stats(
        Stats.stats_initializer(shape=(1, 1, 2)), x, axis=(0, 1), mask=mask)
    want_mean = jnp.array([[[(0 + 2 + 4) / 3, (1 + 3 + 5) / 3]]])
    want_mean_sq = jnp.array(
        [[[(0**2 + 2**2 + 4**2) / 3, (1**2 + 3**2 + 5**2) / 3]]])
    onp.testing.assert_allclose(per_channel.mean, want_mean)
    # Every unmasked input is non-negative, so mean_abs equals the mean.
    onp.testing.assert_allclose(per_channel.mean_abs, want_mean)
    onp.testing.assert_allclose(per_channel.mean_sq, want_mean_sq)
    onp.testing.assert_allclose(per_channel.mean_batch_maximum, [[[4, 5]]])
    onp.testing.assert_allclose(per_channel.mean_batch_minimum, [[[0, 1]]])

    # Case 2: the same input and mask, but with axis=None so that everything
    # is reduced to a single statistic.
    per_tensor = Stats.create_updated_stats(
        Stats.stats_initializer(shape=()), x, axis=None, mask=mask)
    unmasked_values = (0, 1, 2, 3, 4, 5)
    want_mean = sum(unmasked_values) / 6
    want_mean_sq = sum(v**2 for v in unmasked_values) / 6
    onp.testing.assert_allclose(per_tensor.mean, [[[want_mean]]])
    onp.testing.assert_allclose(per_tensor.mean_sq, [[[want_mean_sq]]])
    onp.testing.assert_allclose(per_tensor.mean_batch_maximum, [[[5]]])
    onp.testing.assert_allclose(per_tensor.mean_batch_minimum, [[[0]]])

    # Case 3: the reduction axis is the axis the mask is broadcast over.
    # The slice that consists solely of masked elements is expected to keep
    # the initial value of 0, because masked elements never update it.
    per_token = Stats.create_updated_stats(
        Stats.stats_initializer(shape=(2, 2, 1)), x, axis=(2,), mask=mask)
    want_mean = [[[(0 + 1) / 2], [(2 + 3) / 2]], [[(4 + 5) / 2], [0]]]
    want_mean_sq = [[[(0**2 + 1**2) / 2], [(2**2 + 3**2) / 2]],
                    [[(4**2 + 5**2) / 2], [0]]]
    onp.testing.assert_allclose(per_token.mean, want_mean)
    onp.testing.assert_allclose(per_token.mean_sq, want_mean_sq)
    onp.testing.assert_allclose(per_token.mean_batch_maximum,
                                [[[1], [3]], [[5], [0]]])
    onp.testing.assert_allclose(per_token.mean_batch_minimum,
                                [[[0], [2]], [[4], [0]]])
Example #2
0
  def test_per_channel_average(self):
    """Stats should be different per channel."""
    # Each update feeds `i * base`, i.e. for i in range(-1, 4):
    #   [[i    , 2 * i],
    #    [i * 2, 2 * (i * 2)],
    #    [i * 3, 2 * (i * 3)]]
    # so channel 1 is always twice channel 0.
    multipliers = range(-1, 4)
    row_factors = (1, 2, 3)
    base = jnp.array([[1., 2.], [2., 4.], [3., 6.]])

    stats = Stats.stats_initializer(shape=())
    for i in multipliers:
      stats = Stats.create_updated_stats(stats, i * base, axis=(0,))
    self.assertEqual(stats.n, 5)

    # Every expectation below is a sum over all 5 x 3 entries of one channel
    # divided by 15.
    row_sum = sum(row_factors)
    row_sq_sum = sum(r**2 for r in row_factors)

    want_mean = jnp.array([[
        sum(multipliers) * row_sum / 15.,
        sum(2 * i for i in multipliers) * row_sum / 15.,
    ]])
    onp.testing.assert_allclose(stats.mean, want_mean)

    want_mean_abs = jnp.array([[
        sum(abs(i) for i in multipliers) * row_sum / 15.,
        sum(abs(2 * i) for i in multipliers) * row_sum / 15.,
    ]])
    onp.testing.assert_allclose(stats.mean_abs, want_mean_abs)

    want_mean_sq = jnp.array([[
        sum(i**2 for i in multipliers) * row_sq_sum / 15.,
        sum((2 * i)**2 for i in multipliers) * row_sq_sum / 15.,
    ]])
    onp.testing.assert_allclose(stats.mean_sq, want_mean_sq)
Example #3
0
 def test_exclude_zeros(self, x, mask, axis, exp_mean):
     """Check the mean computed with `exclude_zeros=True` against `exp_mean`.

     Args:
       x: input passed to `Stats.create_updated_stats`.
       mask: mask forwarded to `Stats.create_updated_stats`.
       axis: reduction axis forwarded to `Stats.create_updated_stats`.
       exp_mean: expected value of the resulting `mean` statistic.
     """
     initial = Stats.stats_initializer(shape=())
     updated = Stats.create_updated_stats(
         initial, x, axis=axis, mask=mask, exclude_zeros=True)
     onp.testing.assert_allclose(updated.mean, exp_mean)
Example #4
0
  def test_update_stats_with_different_axes(self, axis, sample_shape,
                                            stats_shape):
    """Check stats after five constant-valued updates reduced over `axis`.

    Args:
      axis: reduction axis forwarded to `Stats.create_updated_stats`.
      sample_shape: shape of each constant input sample.
      stats_shape: shape of the statistics and of the expected values.
    """
    multipliers = range(-1, 4)

    stats = Stats.stats_initializer(shape=stats_shape)
    for i in multipliers:
      # Every entry of the i-th sample equals i.
      stats = Stats.create_updated_stats(
          stats, i * jnp.ones(sample_shape), axis=axis)
    self.assertEqual(stats.n, 5)

    ones = jnp.ones(stats_shape)

    want_mean = sum(multipliers) / 5. * ones
    onp.testing.assert_allclose(stats.mean, want_mean)

    want_mean_abs = sum(abs(i) for i in multipliers) / 5. * ones
    onp.testing.assert_allclose(stats.mean_abs, want_mean_abs)

    want_mean_sq = sum(i**2 for i in multipliers) / 5. * ones
    onp.testing.assert_allclose(stats.mean_sq, want_mean_sq)
    def __call__(
        self,
        x,
        *,
        bounds_params,
    ):
        """Update the activation statistics and bounds, then return the bounds.

        Args:
          x: the input to get bounds from using statistics. It is cast to
            float32 before any statistics are computed.
          bounds_params: parameters to compute input's statistics and bounds.
            This method reads its `mask`, `update_stats`, `update_bounds` and
            `paxis_name` attributes.

        Returns:
          The current value of the 'bounds' variable. The variable is
          initialized with the statistics shape chosen below (one entry per
          tensor or one per channel), not with the shape of `x`.

        Raises:
          ValueError: if `self.hyper.granularity` is neither `per_tensor` nor
            `per_channel`.
        """
        mask = bounds_params.mask
        if mask is not None:
            shape_utils.assert_shapes_compatible(x.shape, mask.shape)

        x = jnp.asarray(x, jnp.float32)
        hyper = self.hyper

        # Queried before the `self.variable` calls below; presumably those
        # calls create the variable when it is missing, which would make this
        # check always succeed afterwards.
        stats_exist = self.has_variable('get_bounds', 'stats')

        rank = x.ndim
        granularity = hyper.granularity
        if granularity == quant_config.QuantGranularity.per_channel:
            # Aggregate activation statistics across every dimension of the
            # activation tensor EXCEPT the last one, which is interpreted as
            # the channel dimension. For example, in a transformer context x
            # might have a shape corresponding to [example, token, channel];
            # statistics are then kept separately for each feature, each one
            # aggregated over all unmasked tokens in all examples.
            reduce_axes = tuple(range(rank - 1))
            stats_shape = (1, ) * (rank - 1) + (x.shape[-1], )
        elif granularity == quant_config.QuantGranularity.per_tensor:
            # None reduces over everything; this could equivalently be
            # written as tuple(range(x.ndim)).
            reduce_axes = None
            stats_shape = (1, ) * rank
        else:
            raise ValueError(f'Unknown granularity {hyper.granularity}')

        stats_var = self.variable('get_bounds', 'stats',
                                  Stats.stats_initializer, stats_shape)
        bounds_var = self.variable(
            'get_bounds', 'bounds',
            lambda shape: hyper.initial_bound * jnp.ones(shape), stats_shape)

        # Neither statistics nor bounds are updated on the initializing call.
        if stats_exist:
            if bounds_params.update_stats:
                stats_var.value = Stats.create_updated_stats(
                    stats_var.value,
                    x,
                    mask=mask,
                    axis=reduce_axes,
                    paxis_name=bounds_params.paxis_name,
                    alpha=hyper.ema_coeff,
                    exclude_zeros=hyper.exclude_zeros)

            if bounds_params.update_bounds:
                bounds_var.value = self._stats_to_bounds(stats_var.value)
                if hyper.reset_stats:
                    # Start accumulating from scratch after a bounds update.
                    stats_var.value = Stats.stats_initializer(stats_shape)

        return bounds_var.value