Example #1
import jax
import jax.numpy as jnp

# transition_rate_expm and compute_relative_information_removal are helpers
# defined elsewhere in the same module.


def compute_information_removal_samples_by_squaring(rate_matrix,
                                                    initial_distribution,
                                                    min_exponent=1e-4,
                                                    max_exponent=1e5,
                                                    interpolation_steps=256,
                                                    use_perplexity=False):
    """Compute mutual information using repeated squaring.

  Reduces a bunch of repeated work by evaluating power-of-two exponents using
  repeated squaring, starting from a few different test offsets to fill the
  gaps between powers of two.

  Args:
    rate_matrix: Transition rate matrix of shape [vocab_size, vocab_size]
    initial_distribution: Initial distribution of tokens.
    min_exponent: Smallest non-zero exponent to try.
    max_exponent: Largest exponent to try.
    interpolation_steps: Minimum number of interpolation steps to try.
    use_perplexity: Use conditional perplexity(ish) instead of MI

  Returns:
    exponents: Array of exponents for which we computed relative mutual
      information removal.
    information_removals: Array of the information removal for each exponent.
  """
    # How many powers of two do we need to fill the range?
    powers_of_two = 1 + jnp.ceil(
        jnp.log2(max_exponent) - jnp.log2(min_exponent)).astype(jnp.int32)
    # How many shifts should we evaluate between each power of two? For instance,
    # in addition to evaluating at 1, 2, 4, 8, 16, 32 we might also evaluate at
    # 3/2, 3, 6, 12, 24, 48. Increasing interpolation steps will increase this.
    shifts = jnp.ceil(interpolation_steps / powers_of_two).astype(jnp.int32)

    # Figure out the base exponents (1 and 3/2 in the above example, but there
    # may be more)
    base_exponents = jnp.exp2(
        jnp.log2(min_exponent) + jnp.linspace(0, 1, shifts, endpoint=False))

    def from_base(base_exponent):
        base_matrix = transition_rate_expm(base_exponent * rate_matrix)

        def step(mat, i):
            exponent = base_exponent * (2.0**i)
            info_removal = compute_relative_information_removal(
                mat, initial_distribution, use_perplexity=use_perplexity)
            # Square to double the exponent, then renormalize columns so the
            # matrix stays a valid transition matrix despite rounding drift.
            new_mat = jnp.dot(mat, mat, precision=jax.lax.Precision.HIGHEST)
            new_mat = new_mat / jnp.sum(new_mat, axis=0, keepdims=True)
            return new_mat, (exponent, info_removal)

        _, (exponents,
            info_removals) = jax.lax.scan(step,
                                          init=base_matrix,
                                          xs=jnp.arange(powers_of_two))
        return exponents, info_removals

    exponents, info_removals = jax.lax.map(from_base, base_exponents)
    return exponents.reshape([-1]), info_removals.reshape([-1])
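
The scan in from_base squares the carried matrix once per iteration, so a whole geometric grid of exponents costs one matmul per grid point. A minimal self-contained sketch of that repeated-squaring pattern (the helper name and toy matrix are illustrative, not from the original module):

import jax
import jax.numpy as jnp

def powers_by_squaring(mat, num_powers):
    """Returns mat**(2**i) for i = 0 .. num_powers - 1, one squaring each."""
    def step(m, _):
        # Emit the current power, carry its square for the next step.
        return jnp.dot(m, m, precision=jax.lax.Precision.HIGHEST), m

    _, powers = jax.lax.scan(step, init=mat, xs=jnp.arange(num_powers))
    return powers

m = jnp.array([[0.9, 0.1], [0.1, 0.9]])
powers = powers_by_squaring(m, 4)  # stacks m**1, m**2, m**4, m**8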
Example #2
def test_attributes_create_weights_op_fp(
    self,
    weight_range,
    weight_shape,
    fp_quant,
):
    weights = jnp.array(
        fp32(onp.random.uniform(*weight_range, size=weight_shape)))
    axis = None if weight_shape[1] == 1 else 0
    weights_quant_op = QuantOps.create_weights_ops(
        w=weights,
        weight_params=QuantOps.WeightParams(prec=fp_quant,
                                            axis=axis,
                                            half_shift=False))
    max_weight = onp.max(abs(weights), axis=0)
    onp.testing.assert_array_equal(
        jnp.squeeze(weights_quant_op._scale),
        jnp.exp2(-jnp.floor(jnp.log2(max_weight))))
    self.assertEqual(weights_quant_op._symmetric, True)
    self.assertIs(weights_quant_op._prec, fp_quant)
    weights_scaled = (weights * weights_quant_op._scale).astype(weights.dtype)
    weights_quant_expected = fp_cast.downcast_sat_ftz(
        weights_scaled,
        fp_quant.fp_spec.exp_min,
        fp_quant.fp_spec.exp_max,
        fp_quant.fp_spec.sig_bits,
    )
    weights_quant_calculated = weights_quant_op.to_quantized(
        weights, dtype=SCALE_DTYPE)
    onp.testing.assert_array_equal(weights_quant_expected,
                                   weights_quant_calculated)
    # Test that the lower (23 - fp_quant.fp_spec.sig_bits) bits of the
    # calculated quantized weights are zero.
    sig_mask = jnp.int32((1 << (23 - fp_quant.fp_spec.sig_bits)) - 1)
    onp.testing.assert_array_equal(
        weights_quant_calculated.view(jnp.int32) & sig_mask,
        jnp.zeros_like(weights))
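
The expected scale asserted here, exp2(-floor(log2(max_weight))), is a pure power of two chosen so that the largest weight magnitude lands in the unit binade [1, 2). A standalone sketch of that property (unit_binade_scale is an illustrative helper, not part of QuantOps):

import jax.numpy as jnp

def unit_binade_scale(max_abs):
    # Power-of-two scale mapping max_abs into [1, 2).
    return jnp.exp2(-jnp.floor(jnp.log2(max_abs)))

for w in [0.3, 1.0, 5.7]:
    s = unit_binade_scale(jnp.float32(w))
    print(w, float(s), float(w * s))  # w * s always lands in [1, 2)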
Example #3
def _exp_pade_4_4_fwd(x):  # pylint: disable=missing-function-docstring
  x = tf.convert_to_tensor(x, dtype_hint=tf.float32)
  raw_x = x
  dtype = dtype_util.as_numpy_dtype(x.dtype)
  inf = np.float32('inf').astype(dtype)

  # Range reduction: write x = n*log(2) + r with r in [0, log(2)), so that
  # exp(x) = 2**n * exp(r) and the Pade approximant only sees the small r.
  ln2 = np.log(2).astype(dtype)

  n = tf.math.floor(x / ln2)
  x = x - n * ln2

  # Pade [4/4] approximant of exp: denominator coefficients are the numerator
  # coefficients with alternating signs, i.e. Q(x) = P(-x). Coefficients are
  # listed from highest degree to lowest for Horner evaluation.
  coeffs_p = np.array([1 / 1680, 1 / 84, 3 / 28, 1 / 2, 1], dtype)
  coeffs_q = np.array([1 / 1680, -1 / 84, 3 / 28, -1 / 2, 1], dtype)
  res = _horner(x, coeffs_p) / _horner(x, coeffs_q)

  if JAX_MODE:
    import jax.numpy as jnp  # pylint: disable=g-import-not-at-top
    res = res * jnp.exp2(n)
  else:
    res = res * (2**n)

  res = tf.where(tf.equal(raw_x, -inf), tf.zeros_like(x), res)
  res = tf.where(tf.equal(raw_x, inf), inf, res)
  return res, res
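
Because the range reduction hands the approximant only a residual r in [0, log 2), its accuracy is easy to check end to end. A standalone NumPy sketch mirroring the code above (the function name is illustrative):

import numpy as np

def exp_pade_4_4(x):
  x = np.asarray(x, np.float32)
  ln2 = np.float32(np.log(2.0))
  n = np.floor(x / ln2)
  r = x - n * ln2
  # np.polyval takes coefficients from highest degree to lowest.
  p = np.polyval([1 / 1680, 1 / 84, 3 / 28, 1 / 2, 1], r)
  q = np.polyval([1 / 1680, -1 / 84, 3 / 28, -1 / 2, 1], r)
  return (p / q) * np.exp2(n)

xs = np.linspace(-5.0, 5.0, 101).astype(np.float32)
rel_err = np.abs(exp_pade_4_4(xs) - np.exp(xs)) / np.exp(xs)
print(rel_err.max())  # should stay near float32 rounding error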
Example #4
    def create_symmetric_fp(
        cls,
        *,
        bounds,
        fp_quant,
    ):
        """Create QuantOps for symmetric clipping to floating-point bounds.

    Args:
      bounds: The upper (and absolute lower) bound to clip the inputs.
      fp_quant: quantization floating-point specification of the target format.

    Returns:
      QuantOps for quantizing/dequantizing signed activations.
    """
        if bounds is None:
            if fp_quant.is_scaled:
                raise ValueError(
                    'bounds can only be None if fp_quant.is_scaled is False.')
            return cls(prec=fp_quant, scale=None, symmetric=True, bounds=None)
        else:
            initial_bounds = bounds
            bounds = jnp.asarray(bounds, SCALE_DTYPE)
            if not DISABLE_EPSILON_IN_SCALE_FUN_FOR_TESTING:
                bounds += jnp.finfo(SCALE_DTYPE).eps  # to avoid log2(0)
            scale = jnp.exp2(
                -jnp.floor(jnp.log2(bounds)))  # Scale to unit binade.
            # NOTE: stop_gradient is needed here to prevent gradient flow through
            # scale when scale is not a constant, but computed as a function of
            # activations or weights.
            scale = lax.stop_gradient(scale)

            return cls(prec=fp_quant,
                       scale=scale,
                       symmetric=True,
                       bounds=initial_bounds)
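
The eps added before the log is what keeps a zero bound from producing log2(0) = -inf: the resulting scale is huge but finite. A small sketch of that path (the SCALE_DTYPE binding below is an assumption standing in for the module-level constant):

import jax.numpy as jnp

SCALE_DTYPE = jnp.float32  # assumption: stands in for the module constant

def symmetric_fp_scale(bounds):
    b = jnp.asarray(bounds, SCALE_DTYPE) + jnp.finfo(SCALE_DTYPE).eps
    return jnp.exp2(-jnp.floor(jnp.log2(b)))

print(symmetric_fp_scale(jnp.array([0.0, 0.75, 6.0])))
# A zero bound yields a large but finite power-of-two scale instead of inf.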
Example #5
def exp2(x):
  if isinstance(x, JaxArray):
    x = x.value
  return JaxArray(jnp.exp2(x))
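
This is the usual unwrap/rewrap pattern for exposing jnp functions over a library's array wrapper. A minimal stand-in (this JaxArray is a hypothetical sketch, not the library's actual class):

import jax.numpy as jnp

class JaxArray:
  """Hypothetical minimal wrapper holding a raw JAX array in .value."""

  def __init__(self, value):
    self.value = value

def exp2(x):
  if isinstance(x, JaxArray):
    x = x.value  # unwrap to the raw jnp array
  return JaxArray(jnp.exp2(x))  # rewrap so results stay in the wrapper type

out = exp2(JaxArray(jnp.arange(4.0)))
print(out.value)  # [1. 2. 4. 8.]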