Example #1
0
def quantized_sum(
        x,
        axis,
        keepdims,
        prec):
    """Sums a tensor while quantizing intermediate accumulations.

    This is almost a drop-in replacement for jnp.sum. It only differs in that
    it takes a 'prec' parameter that controls the quantization of the
    intermediate accumulations during the reduction.

    Arguments:
      x: Input, a Jax array.
      axis: Which axes to reduce over: a single int, an iterable of ints, or
        None to reduce over every axis of 'x' (see jnp.sum docs).
      keepdims: Whether to keep or drop the axes that are reduced (see jnp.sum
        docs).
      prec: Precision to quantize intermediate sums to. It is passed as the
        'fp_spec' of a QuantOps.FloatQuant with is_scaled=False (an unscaled
        floating-point format), or it can be None to indicate that no
        quantization should be applied.

    Returns:
      A Jax array with the quantized sum of 'x'.
    """

    # Don't quantize. In this case, this function just wraps jnp.sum.
    if prec is None:
        return jnp.sum(x, axis=axis, keepdims=keepdims)

    # We bypass QuantOps.create_input_ops and directly call
    # QuantOps.create_symmetric_fp because the former creates an instance of
    # GetBounds, which in turn creates state variables to store activation
    # statistics. We do not want to compute statistics for each individual
    # addition within the sum reduction.
    fp_quant = QuantOps.FloatQuant(is_scaled=False, fp_spec=prec)
    quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant, bounds=None)

    if axis is None:
        # Mirror jnp.sum: axis=None means "reduce over all axes". Without this
        # the quantized path would receive (None,) while the unquantized path
        # above accepts None, making the two paths disagree.
        axis = tuple(range(x.ndim))
    elif not isinstance(axis, Iterable):
        axis = (axis, )
    axis = utils.normalize_axes(axis, x.ndim)
    dtype = x.dtype

    # Every pairwise partial sum is re-quantized, so the accumulator never
    # holds more precision than 'prec' allows.
    zero = jnp.zeros((), dtype=dtype)
    x_quantized_sum = lax.reduce(
        x,
        init_values=zero,
        computation=lambda a, b: quant_ops.to_quantized(a + b, dtype=dtype),
        dimensions=axis)

    # lax.reduce always drops the reduced dimensions; put them back as size-1
    # axes when the caller asked to keep them.
    if keepdims:
        x_quantized_sum = jnp.expand_dims(x_quantized_sum, axis)

    return x_quantized_sum
Example #2
0
    def __call__(
        self,
        x,
        *,
        mask,
    ):
        """Records distribution statistics of 'x' in the 'stats_tag' collection.

        Seven variables are declared in the 'stats_tag' collection (min, max,
        mean, stddev, absdev, and the uncentered stddev/absdev), each
        zero-initialized as float32. They are overwritten with statistics of
        the masked elements of 'x' only when 'self.update_stats' is truthy and
        the 'min_per_ch' variable already existed before this call.

        Args:
          x: the array to compute statistics distributions over.
          mask: boolean array indicating which elements of 'x' should be
            included in the stats calculation ('True' means to include). If
            None, every element is included.

        Returns:
          No value is returned explicitly by the code visible here.
          NOTE(review): earlier documentation promised "x unchanged" -- confirm
          the intended contract with callers.
        """
        if mask is None:
            mask = jnp.full(x.shape, True)
        shape_utils.assert_shapes_compatible(x.shape, mask.shape)
        mask = jnp.broadcast_to(mask, x.shape)

        channel_axis = self.channel_axis
        reduction_axis = None
        if channel_axis is not None:
            if not isinstance(channel_axis, Iterable):
                channel_axis = (channel_axis, )
            channel_axis = normalize_axes(channel_axis, x.ndim)

            # 'x' and 'mask' are subsetted with identical arguments so they
            # stay aligned element-for-element.
            def subset(array):
                return _take_subset_of_axes(
                    array,
                    axis=channel_axis,
                    num_indices_per_ax=self.num_indices_per_ax)

            x = subset(x)
            mask = subset(mask)
            # Everything that is not a channel axis gets reduced away.
            reduction_axis = tuple(
                ax for ax in range(x.ndim) if ax not in channel_axis)

        # Statistics keep one entry per position along the channel axes; with
        # no channel axes they are scalars.
        if channel_axis:
            distr_shape = tuple(
                dim for idx, dim in enumerate(x.shape) if idx in channel_axis)
        else:
            distr_shape = ()

        # TODO(wanglisa): Consider adding configurability to specify which
        # statistics are collected.
        def init_with_zeros(shape):
            return jnp.zeros(shape, dtype=jnp.float32)

        # Must be evaluated before the variables are declared below, since
        # declaring 'min_per_ch' would make this check meaningless.
        is_initializing = not self.has_variable('stats_tag', 'min_per_ch')

        def declare(name):
            return self.variable('stats_tag', name, init_with_zeros,
                                 distr_shape)

        min_per_ch = declare('min_per_ch')
        max_per_ch = declare('max_per_ch')
        mean_per_ch = declare('mean_per_ch')
        stddev_per_ch = declare('stddev_per_ch')
        absdev_per_ch = declare('absdev_per_ch')
        stddev_per_ch_uncentered = declare('stddev_per_ch_uncentered')
        absdev_per_ch_uncentered = declare('absdev_per_ch_uncentered')

        if self.update_stats and not is_initializing:

            def masked_mean(values, keepdims=False):
                return stats.masked_mean(values,
                                         mask=mask,
                                         axis=reduction_axis,
                                         paxis_name=None,
                                         keepdims=keepdims)

            # Masked-out elements are replaced by +/-inf so they can never win
            # the min/max reduction.
            min_per_ch.value = jnp.min(jnp.where(mask, x, math.inf),
                                       axis=reduction_axis)
            max_per_ch.value = jnp.max(jnp.where(mask, x, -math.inf),
                                       axis=reduction_axis)

            # Keep the reduced dims so the mean broadcasts against 'x' when
            # centering below.
            centre = masked_mean(x, keepdims=True)
            mean_per_ch.value = centre.squeeze(axis=reduction_axis)

            stddev_per_ch.value = jnp.sqrt(masked_mean((x - centre)**2))
            absdev_per_ch.value = masked_mean(jnp.abs(x - centre))

            # "Uncentered" variants measure spread around zero instead of
            # around the mean.
            stddev_per_ch_uncentered.value = jnp.sqrt(
                masked_mean(jnp.square(x)))
            absdev_per_ch_uncentered.value = masked_mean(jnp.abs(x))