def quantized_sum(
    x,
    axis,
    keepdims,
    prec):
  """Sums a tensor while quantizing intermediate accumulations.

  This is almost a drop-in replacement for jnp.sum. It only differs in that it
  takes in a 'prec' parameter that controls the quantization of intermediate
  accumulations during the reduction.

  Arguments:
    x: Input, a Jax array.
    axis: Which axes to reduce over (see jnp.sum docs).
    keepdims: Whether to keep or drop axes that are reduced (see jnp.sum docs).
    prec: Precision to quantize intermediate accumulations to. Currently it
      can only be an instance of QuantOps.FloatQuant.FloatPrec, corresponding
      to an unscaled floating-point format, or it can be None to indicate no
      quantization should be applied.

  Returns:
    A Jax array with the quantized sum of 'x'.
  """
  # Don't quantize. In this case, this function just wraps jnp.sum.
  if prec is None:
    return jnp.sum(x, axis=axis, keepdims=keepdims)

  # We bypass QuantOps.create_input_ops and directly call
  # QuantOps.create_symmetric_fp because the former creates an instance of
  # GetBounds, which in turn creates state variables to store activation
  # statistics. We do not want to compute statistics for each individual
  # addition within the sum reduction.
  fp_quant = QuantOps.FloatQuant(is_scaled=False, fp_spec=prec)
  quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant, bounds=None)

  if not isinstance(axis, Iterable):
    axis = (axis,)
  axis = utils.normalize_axes(axis, x.ndim)
  dtype = x.dtype
  zero = jnp.zeros((), dtype=dtype)
  # Each partial sum 'a + b' is rounded to the target format before being fed
  # into the next addition.
  x_quantized_sum = lax.reduce(
      x,
      init_values=zero,
      computation=lambda a, b: quant_ops.to_quantized(a + b, dtype=dtype),
      dimensions=axis)
  if keepdims:
    x_quantized_sum = jnp.expand_dims(x_quantized_sum, axis)
  return x_quantized_sum
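
# A minimal usage sketch of quantized_sum (not from the original module). It
# exercises only the prec=None path, which the function defines as a plain
# jnp.sum; the array values and shapes below are illustrative assumptions.
def _quantized_sum_example():
  x = jnp.arange(12.0).reshape(3, 4)

  # With prec=None, quantized_sum defers to jnp.sum: no intermediate rounding.
  unquantized = quantized_sum(x, axis=1, keepdims=True, prec=None)
  assert unquantized.shape == (3, 1)
  assert bool(jnp.allclose(unquantized, jnp.sum(x, axis=1, keepdims=True)))

  # Passing a QuantOps.FloatQuant.FloatPrec instance instead of None would
  # round each partial sum to the unscaled floating-point format before the
  # next addition, so the result can differ slightly from jnp.sum.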
  def __call__(
      self,
      x,
      *,
      mask,
  ):
    """Applies a tag to track distributions.

    Args:
      x: the array to compute statistics distributions over.
      mask: boolean array indicating which elements of 'x' should be included
        in the stats calculation ('True' means to include).

    Returns:
      x unchanged. The return value can also be ignored.
    """
    if mask is None:
      mask = jnp.full(x.shape, True)
    shape_utils.assert_shapes_compatible(x.shape, mask.shape)
    mask = jnp.broadcast_to(mask, x.shape)
    channel_axis = self.channel_axis
    if channel_axis is not None:
      if not isinstance(channel_axis, Iterable):
        channel_axis = (channel_axis,)
      channel_axis = normalize_axes(channel_axis, x.ndim)
      x = _take_subset_of_axes(
          x, axis=channel_axis, num_indices_per_ax=self.num_indices_per_ax)
      mask = _take_subset_of_axes(
          mask, axis=channel_axis, num_indices_per_ax=self.num_indices_per_ax)
      # Reduce over all non-channel axes.
      reduction_axis = tuple(
          ax for ax in range(x.ndim) if ax not in channel_axis)
    else:
      reduction_axis = None

    distr_shape = ()
    if channel_axis:
      distr_shape = tuple(
          d for i, d in enumerate(x.shape) if i in channel_axis)

    # TODO(wanglisa): Consider adding configurability to specify which
    # statistics are collected.
    init_with_zeros = lambda shape: jnp.zeros(shape, dtype=jnp.float32)
    is_initializing = not self.has_variable('stats_tag', 'min_per_ch')
    min_per_ch = self.variable(
        'stats_tag',
        'min_per_ch',
        init_with_zeros,
        distr_shape,
    )
    max_per_ch = self.variable(
        'stats_tag',
        'max_per_ch',
        init_with_zeros,
        distr_shape,
    )
    mean_per_ch = self.variable(
        'stats_tag',
        'mean_per_ch',
        init_with_zeros,
        distr_shape,
    )
    stddev_per_ch = self.variable(
        'stats_tag',
        'stddev_per_ch',
        init_with_zeros,
        distr_shape,
    )
    absdev_per_ch = self.variable(
        'stats_tag',
        'absdev_per_ch',
        init_with_zeros,
        distr_shape,
    )
    stddev_per_ch_uncentered = self.variable(
        'stats_tag',
        'stddev_per_ch_uncentered',
        init_with_zeros,
        distr_shape,
    )
    absdev_per_ch_uncentered = self.variable(
        'stats_tag',
        'absdev_per_ch_uncentered',
        init_with_zeros,
        distr_shape,
    )
    if self.update_stats and not is_initializing:
      # Masked-out elements are replaced with +/-inf so they never win the
      # min/max reductions.
      min_per_ch.value = jnp.min(
          jnp.where(mask, x, math.inf), axis=reduction_axis)
      max_per_ch.value = jnp.max(
          jnp.where(mask, x, -math.inf), axis=reduction_axis)
      mean_per_ch_keepdims = stats.masked_mean(
          x, mask=mask, axis=reduction_axis, paxis_name=None, keepdims=True)
      mean_per_ch.value = mean_per_ch_keepdims.squeeze(axis=reduction_axis)
      stddev_per_ch.value = jnp.sqrt(
          stats.masked_mean(
              (x - mean_per_ch_keepdims)**2,
              mask=mask,
              axis=reduction_axis,
              paxis_name=None,
              keepdims=False))
      absdev_per_ch.value = stats.masked_mean(
          jnp.abs(x - mean_per_ch_keepdims),
          mask=mask,
          axis=reduction_axis,
          paxis_name=None,
          keepdims=False)
      stddev_per_ch_uncentered.value = jnp.sqrt(
          stats.masked_mean(
              jnp.square(x),
              mask=mask,
              axis=reduction_axis,
              paxis_name=None,
              keepdims=False))
      absdev_per_ch_uncentered.value = stats.masked_mean(
          jnp.abs(x),
          mask=mask,
          axis=reduction_axis,
          paxis_name=None,
          keepdims=False)
    return x
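
# A hypothetical usage sketch, assuming this __call__ belongs to a Flax
# nn.Module (called StatsTag here) whose fields mirror the attributes
# referenced above: channel_axis, num_indices_per_ax, update_stats. The
# module name and constructor are assumptions, not part of this excerpt.
def _stats_tag_example():
  import jax

  tag = StatsTag(channel_axis=-1, num_indices_per_ax=2, update_stats=True)
  x = jax.random.normal(jax.random.PRNGKey(0), (8, 4))

  # During init, is_initializing is True, so the per-channel stats variables
  # are created zero-filled but not yet updated.
  variables = tag.init(jax.random.PRNGKey(1), x, mask=None)

  # A later apply with the 'stats_tag' collection marked mutable recomputes
  # min/max/mean/stddev/absdev over the non-channel axes.
  _, updated = tag.apply(variables, x, mask=None, mutable=['stats_tag'])
  min_per_ch = updated['stats_tag']['min_per_ch']
  assert min_per_ch.shape == (2,)  # one entry per tracked channel index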