Example #1
def compute_information_removal_samples_by_squaring(rate_matrix,
                                                    initial_distribution,
                                                    min_exponent=1e-4,
                                                    max_exponent=1e5,
                                                    interpolation_steps=256,
                                                    use_perplexity=False):
    """Compute mutual information using repeated squaring.

  Avoids redundant work by evaluating power-of-two multiples of a few offset
  base exponents via repeated squaring; the offsets fill in the gaps between
  powers of two.

  Args:
    rate_matrix: Transition rate matrix of shape [vocab_size, vocab_size]
    initial_distribution: Initial distribution of tokens.
    min_exponent: Smallest non-zero exponent to try.
    max_exponent: Largest exponent to try.
    interpolation_steps: Minimum number of interpolation steps to try.
    use_perplexity: Use conditional perplexity(ish) instead of MI

  Returns:
    exponents: Array of exponents for which we computed relative mutual
      information removal.
    information_removals: Array of the information removal for each exponent.
  """
    # How many powers of two do we need to fill the range?
    powers_of_two = 1 + jnp.ceil(
        jnp.log2(max_exponent) - jnp.log2(min_exponent)).astype(jnp.int32)
    # How many shifts should we evaluate between each power of two? For instance,
    # in addition to evaluating at 1, 2, 4, 8, 16, 32 we might also evaluate at
    # 3/2, 3, 6, 12, 24, 48. Increasing interpolation steps will increase this.
    shifts = jnp.ceil(interpolation_steps / powers_of_two).astype(jnp.int32)

    # Figure out the base exponents (1 and 3/2 in the above example, but there
    # may be more)
    base_exponents = jnp.exp2(
        jnp.log2(min_exponent) + jnp.linspace(0, 1, shifts, endpoint=False))

    def from_base(base_exponent):
        base_matrix = transition_rate_expm(base_exponent * rate_matrix)

        def step(mat, i):
            exponent = base_exponent * (2.0**i)
            info_removal = compute_relative_information_removal(
                mat, initial_distribution, use_perplexity=use_perplexity)
            new_mat = jnp.dot(mat, mat, precision=jax.lax.Precision.HIGHEST)
            new_mat = new_mat / jnp.sum(new_mat, axis=0, keepdims=True)
            return new_mat, (exponent, info_removal)

        _, (exponents,
            info_removals) = jax.lax.scan(step,
                                          init=base_matrix,
                                          xs=jnp.arange(powers_of_two))
        return exponents, info_removals

    exponents, info_removals = jax.lax.map(from_base, base_exponents)
    return exponents.reshape([-1]), info_removals.reshape([-1])
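A minimal sketch (plain NumPy, illustrative parameter values) of the exponent grid the comments above describe: a few offset base exponents between min_exponent and 2 * min_exponent, each doubled repeatedly, together cover the whole range roughly log-uniformly.

import numpy as np

min_exponent, max_exponent, interpolation_steps = 1e-4, 1e5, 256
powers_of_two = 1 + int(np.ceil(np.log2(max_exponent) - np.log2(min_exponent)))
shifts = int(np.ceil(interpolation_steps / powers_of_two))
# Offset base exponents in [min_exponent, 2 * min_exponent).
base_exponents = np.exp2(np.log2(min_exponent) + np.linspace(0, 1, shifts, endpoint=False))
# Every exponent the scan above would visit, sorted.
grid = np.sort((base_exponents[:, None] * 2.0 ** np.arange(powers_of_two)).ravel())
print(grid[0], grid[-1])  # spans roughly [min_exponent, max_exponent]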
Example #2
    def _dp(self, log_potentials, length):
        "Compute forward pass by linear scan"
        semiring = self.semiring
        N, K, C, C2 = log_potentials.shape
        assert C == C2, "Transition shape doesn't match"
        log_N = np.log2(N)
        assert log_N % 1 == 0.0

        # Init.
        init = np.full((N, K-1, K-1, C, C), semiring.zero)
        Cs = np.arange(C)

        init = init.at[:, 0, 0, Cs, Cs].set(semiring.one)

        mask = np.arange(N).reshape(N, 1, 1, 1) < length - 1
        log_potentials = np.where(mask, log_potentials, semiring.zero)
        init = init.at[:, 0].set(np.where(mask, semiring.zero, init[:, 0]))
        start = semiring.sum(np.stack([init[:, :K-1, 0], log_potentials[:, 1:K]], axis=-1))         
        init = init.at[:, :K-1, 0].set(start)        
        end = length - 1
        for k in range(1, K - 1):
            mask = np.arange(N).reshape(N, 1) < end - (k - 1)
            v = np.where(mask, semiring.one, init[:, k - 1, k, Cs, Cs])
            init = init.at[:, k - 1, k, Cs, Cs].set(v)
        
        K_1 = K - 1
        chart = (
            init.transpose((0, 1, 3, 2, 4))
            .reshape(N, K_1 * C, K_1 * C)
        )

        for n in range(int(log_N)):
            chart = semiring.matmul(chart[1::2], chart[0::2])
        chart = chart.reshape(K_1, C, K_1, C)        
        return semiring.sum(semiring.sum(chart[0, :, 0, :]))
Example #3
    def two_normalize(m):
        # Divide m by a power of 2 to get its norm close to 1
        norm = np.linalg.norm(m, axis=(2, 3), keepdims=True)
        two_pow = np.floor(np.log2(norm))
        stable_m = m / (2**two_pow)

        return stable_m, np.sum(two_pow, axis=0)
Example #4
 def loss(batch):
     theta = wavenet(batch)[:, :-1, :]
     # now slice the padding off the batch
     sliced_batch = batch[:, receptive_field:, :]
     return (np.mean(discretized_mix_logistic_loss(
         theta, sliced_batch, num_class=1 << 16),
                     axis=0) * np.log2(np.e) / (output_width - 1))
Example #5
def var_gate_exact(top_state, site, bottom_state):
    '''
    Goal:
        to find argmax_{gate} <top_state | gate | bottom_state>
        where gate is acting on (site, site+1)
    Input:
        top_state: the top state (not conjugated yet)
        site: gate is applied on (site, site+1)
        bottom_state
    Return:
        new_gate
    '''
    total_dim = top_state.size
    L = int(np.log2(total_dim))
    top_theta = np.reshape(top_state, [(2**site), 4, 2**(L - (site + 2))])
    bottom_theta = np.reshape(bottom_state,
                              [(2**site), 4, 2**(L - (site + 2))])

    M = np.tensordot(top_theta.conj(), bottom_theta, axes=([0, 2], [
        0, 2
    ]))  # [ ..., upper_p, ...], [..., lower_p, ...] --> upper_p, lower_p
    ## If the convention is lower_p, upper_p
    ## uncomment the following line.
    # M = M.T  # the convention is lower_p, upper_p

    ### For detailed explanation of the formula, see function var_gate
    U, _, Vd = misc.svd(M, full_matrices=False)
    new_gate = np.dot(U, Vd).conj()
    # [TODO:remove] new_gate = new_gate.reshape([2, 2, 2, 2])

    return new_gate
Example #6
def get_bins_and_bincounts(samples, normalized=False):
    """take in samples, create a common set of bins, and compute the counts count(x in bin)
    for each bin and each sample x.
    Parameters
    ------------
    samples : np.array of shape (n,) or shape (k, n).
    - If shape (n,): interpreted as a set of n scalar-valued samples.
    - If shape (k, n): interpreted as k sets of n scalar-valued samples.

    Returns
    --------
    probabilities :
    bins :
    """
    nr_samples = np.prod(samples.shape)
    nr_bins = np.log2(nr_samples)
    nr_bins = int(max(nr_bins, 5))

    lims = [np.min(samples), np.max(samples)]
    bins = np.linspace(*lims, num=nr_bins)

    if samples.ndim == 2:
        out = np.asarray([
            np.histogram(x, bins=bins, density=normalized)[0] for x in samples
        ])
        return out, bins
    elif samples.ndim == 1:
        return np.histogram(samples, bins=bins, density=normalized)[0], bins
    else:
        raise ValueError(
            f"Input must have shape (n,) or shape (k,n). Instead received shape {samples.shape}"
        )
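A minimal usage sketch (synthetic data, assuming the function above is in scope): two sample sets share one set of bin edges and get one row of counts each.

import numpy as np

samples = np.random.randn(2, 1000)               # two sets of 1000 samples each
counts, bins = get_bins_and_bincounts(samples)
print(counts.shape, bins.shape)                  # (2, nr_bins - 1) counts, nr_bins edges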
Example #7
def user_rate(H, W, B):
    HHW2 = np.abs(H.conj() @ W.T) ** 2.0
    p_sig = np.diag(HHW2)
    p_int = HHW2 - np.diag(p_sig)
    SINR = p_sig / (p_int.sum(axis=1) + 1.0)
    rate = B * np.log2(1 + SINR)
    return rate
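A small usage sketch (random channel and precoder, purely illustrative; assumes user_rate above is in scope): the per-user rates come out in the same units as the bandwidth B, here bit/s.

import numpy as np

rng = np.random.default_rng(0)
H = rng.standard_normal((4, 8)) + 1j * rng.standard_normal((4, 8))  # 4 users, 8 antennas
W = rng.standard_normal((4, 8)) + 1j * rng.standard_normal((4, 8))  # one precoding vector per user
print(user_rate(H, W, B=20e6))  # four per-user rates in bit/s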
Example #8
 def partition_entropy_coefficient(self):
     if hasattr(self, 'u'):
         return -np.sum(self.u * np.log2(self.u)) / self.n_samples
     else:
         raise ReferenceError(
             "You need to train the model first. You can use `.fit()` method to this."
         )
Example #9
def transition_rate_expm(matrix, target_diagonal=1e-3, renormalize_cols=True):
    """Slightly improved expm for transition rate matrices.

  A transition rate matrix will always have columns that sum to zero, and will
  have nonnegative entries everywhere except the diagonal. We can ensure some
  stability by controlling the magnitude of the diagonal elements and
  renormalizing during each squaring to reduce error.

  Args:
    matrix: The matrix to compute a matrix exponential for.
    target_diagonal: Maximum magnitude of the diagonal elements for which it is
      "safe" to approximate e(tA) as I + tA. Will automatically perform more
      iterations until this is small enough to be a good approximation.
    renormalize_cols: Whether to renormalize the columns of the result, with the
      assumption that the rate matrix summed to zero across the columns. This
      property should always hold, so renormalizing can prevent errors from
      exploding.

  Returns:
    Approximation of expm(matrix).
  """
    max_diag = jnp.max(-jnp.diag(matrix))
    # Each iteration halves the diagonal. How many do we need to get to at or
    # below the target diagonal?
    iterations_for_diagonal = jnp.ceil(
        jnp.log2(max_diag) - jnp.log2(target_diagonal))
    # Make sure we're also squaring enough so that every element has a chance of
    # transitioning to every other element, in theory.
    iterations_for_mixing = jnp.ceil(jnp.log2(matrix.shape[0]))
    iterations = jnp.maximum(iterations_for_diagonal,
                             iterations_for_mixing).astype(jnp.int32)

    # Locally linear approximation: e^A ~= I + A
    # First divide by 2^iterations so that this approximation is accurate.
    tiny_approx = jnp.eye(matrix.shape[0]) + matrix / (2.0**iterations)

    def step(i, mat):
        del i

        updated = jnp.dot(mat, mat, precision=jax.lax.Precision.HIGHEST)
        if renormalize_cols:
            updated = updated / jnp.sum(updated, axis=0, keepdims=True)
        return updated

    result = jax.lax.fori_loop(0, iterations, step, tiny_approx)
    return result
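A small usage sketch (assuming the definition above is in scope; SciPy is used only for comparison): build a 3x3 rate matrix whose columns sum to zero and compare the squaring-based approximation with a dense matrix exponential.

import jax.numpy as jnp
import numpy as onp
import scipy.linalg

rates = onp.array([[-2.0,  1.0,  0.5],
                   [ 1.5, -1.5,  0.5],
                   [ 0.5,  0.5, -1.0]])   # each column sums to zero
approx = transition_rate_expm(jnp.asarray(rates))
exact = scipy.linalg.expm(rates)
print(onp.max(onp.abs(onp.asarray(approx) - exact)))  # should be small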
Example #10
File: jaxent.py, Project: adamhaber/JaxEnt
 def _calc_entropy(self,p):
     """calc the entropy of a probability 
     
     Parameters
     ----------
     p : array_like
         vector of probabilities
     """
     return -np.sum(p*np.log2(p))
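A quick numeric check of the formula (illustrative distribution): a fair coin has exactly one bit of entropy.

import numpy as np

p = np.array([0.5, 0.5])
print(-np.sum(p * np.log2(p)))  # 1.0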
Example #11
    def _dp(self, log_potentials, length):
        semiring = self.semiring
        N, C, C2 = log_potentials.shape
        assert C == C2, "Transition shape doesn't match"
        log_N = np.array(np.log2(N), int)
        #assert log_N % 1 == 0.0

        extra = np.where(np.eye(C, C), semiring.one, semiring.zero)
        chart = np.where(
            np.arange(N).reshape(N, 1, 1) < length - 1, log_potentials, extra)
        for _ in range(log_N):
            chart = semiring.matmul(chart[1::2], chart[0::2])
        return semiring.sum(semiring.sum(chart[0]))
Example #12
    def _inverted(target):
        init = (jnp.ones_like(target) * left_bound,
                jnp.ones_like(target) * right_bound)
        n_iters = jnp.ceil(-jnp.log2(eps)).astype(int)

        def _body(_, left_right):
            left_bound, right_bound = left_right
            cand = (left_bound + right_bound) / 2
            pred = bijector(cand)
            left_bound = jnp.where(pred < target, cand, left_bound)
            right_bound = jnp.where(pred > target, cand, right_bound)
            return left_bound, right_bound

        return jax.lax.fori_loop(0, n_iters, _body, init)[0]
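The snippet above closes over `bijector`, `left_bound`, `right_bound`, and `eps`. A self-contained sketch of the same bisection idea (bounds and tolerance here are illustrative): halving the bracket ceil(-log2(eps)) times shrinks a unit-width bracket below eps.

import jax
import jax.numpy as jnp

def invert_monotone(f, target, left_bound=0.0, right_bound=10.0, eps=1e-6):
    n_iters = int(jnp.ceil(-jnp.log2(eps)))

    def body(_, bracket):
        lo, hi = bracket
        mid = (lo + hi) / 2
        pred = f(mid)
        lo = jnp.where(pred < target, mid, lo)
        hi = jnp.where(pred > target, mid, hi)
        return lo, hi

    init = (jnp.array(left_bound, jnp.float32), jnp.array(right_bound, jnp.float32))
    return jax.lax.fori_loop(0, n_iters, body, init)[0]

print(invert_monotone(jnp.square, 4.0))  # ~2.0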
Example #13
File: jaxent.py, Project: adamhaber/JaxEnt
        def comp_Z(histogram,densities,energy_bins):
            energy_bins = energy_bins[densities > 0]
            densities = densities[densities > 0]
            densities = densities - densities.min()

            neg_e = energy_bins<0

            logZ = logsumexp(-energy_bins + self.N*np.log(2) + densities - logsumexp(densities))
            Z = np.exp(logZ)

            # separate to negative and positive energy bins and do the same trick?
            logS1 = logsumexp(-energy_bins[neg_e] + np.log(-energy_bins[neg_e]) + self.N*np.log(2) + densities[neg_e] - logsumexp(densities))
            logS2 = logsumexp(-energy_bins[~neg_e] + np.log(energy_bins[~neg_e]) + self.N*np.log(2) + densities[~neg_e] - logsumexp(densities))
            entropy = (np.exp(logS2)-np.exp(logS1))/(np.log(2) * Z) + np.log2(Z) 
            return Z, entropy
Example #14
def run(log_potentials, length, semiring="Log"):
    "Main code, associative forward-backward"
    if semiring == "Log":
        semiring = LogSemiring
    else:
        semiring = MaxSemiring

    N, C, C2 = log_potentials.shape
    assert C == C2, "Transition shape doesn't match"
    log_N = np.log2(N)
    assert log_N % 1 == 0.0

    extra = np.where(np.eye(C, C), semiring.one, semiring.zero)
    chart = np.where(np.arange(N).reshape(N, 1, 1) < length - 1, log_potentials, extra)    
    for _ in range(int(log_N)):
        chart = semiring.matmul(chart[1::2], chart[0::2])
    return semiring.sum(semiring.sum(chart[0]))
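run() only needs `one`, `zero`, a reduction over the last axis, and a batched matrix product from the semiring. A minimal LogSemiring sketch with that interface (an assumption for illustration, not necessarily the project's actual class):

from jax.scipy.special import logsumexp

class LogSemiring:
    one = 0.0
    zero = -1e9      # stands in for -inf in log space

    @staticmethod
    def sum(x):
        # Reduce (log-sum-exp) over the last axis.
        return logsumexp(x, axis=-1)

    @staticmethod
    def matmul(a, b):
        # Batched log-space matmul: out[n, i, j] = logsumexp_k(a[n, i, k] + b[n, k, j]).
        return logsumexp(a[..., :, :, None] + b[..., None, :, :], axis=-2)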
Example #15
  def likelihood_fn(prng, pstate, data):
    """Compute an unbiased estimate to the log-likelihood in bits/dim.

    Args:
      prng: An array of random states. The list dimension equals the number of devices.
      pstate: Replicated training state for running on multiple devices.
      data: A JAX array of shape [#devices, batch size, ...].

    Returns:
      bpd: A JAX array of shape [#devices, batch size]. The log-likelihoods on `data` in bits/dim.
      z: A JAX array of the same shape as `data`. The latent representation of `data` under the
        probability flow ODE.
      nfe: An integer. The number of function evaluations used for running the black-box ODE solver.
    """
    rng, step_rng = jax.random.split(flax.jax_utils.unreplicate(prng))
    shape = data.shape
    if hutchinson_type == 'Gaussian':
      epsilon = jax.random.normal(step_rng, shape)
    elif hutchinson_type == 'Rademacher':
      epsilon = jax.random.randint(step_rng, shape,
                                   minval=0, maxval=2).astype(jnp.float32) * 2 - 1
    else:
      raise NotImplementedError(f"Hutchinson type {hutchinson_type} unknown.")

    def ode_func(t, x):
      sample = mutils.from_flattened_numpy(x[:-shape[0] * shape[1]], shape)
      vec_t = jnp.ones((sample.shape[0], sample.shape[1])) * t
      drift = mutils.to_flattened_numpy(p_drift_fn(pstate, sample, vec_t))
      logp_grad = mutils.to_flattened_numpy(p_div_fn(pstate, sample, vec_t, epsilon))
      return np.concatenate([drift, logp_grad], axis=0)

    init = jnp.concatenate([mutils.to_flattened_numpy(data), np.zeros((shape[0] * shape[1],))], axis=0)
    solution = integrate.solve_ivp(ode_func, (eps, sde.T), init, rtol=rtol, atol=atol, method=method)
    nfe = solution.nfev
    zp = jnp.asarray(solution.y[:, -1])
    z = mutils.from_flattened_numpy(zp[:-shape[0] * shape[1]], shape)
    delta_logp = zp[-shape[0] * shape[1]:].reshape((shape[0], shape[1]))
    prior_logp = p_prior_logp_fn(z)
    bpd = -(prior_logp + delta_logp) / np.log(2)
    N = np.prod(shape[2:])
    bpd = bpd / N
    # A hack to convert log-likelihoods to bits/dim
    # based on the gradient of the inverse data normalizer.
    offset = jnp.log2(jax.grad(inverse_scaler)(0.)) + 8.
    bpd += offset
    return bpd, z, nfe
Example #16
    def mat_power(mat_m, p):
        """Computes mat_m^p, for p == 1, 2, 4 or 8.

    Args:
      mat_m: a square matrix
      p: a positive integer

    Returns:
      mat_m^p
    """
        # We unrolled the loop for performance reasons.
        exponent = jnp.round(jnp.log2(p))
        return lax.switch(jnp.asarray(exponent, jnp.int32), [
            _unrolled_mat_pow_1,
            _unrolled_mat_pow_2,
            _unrolled_mat_pow_4,
            _unrolled_mat_pow_8,
        ], (mat_m))
Example #17
def apply_gate_exact(state, gate, idx):
    '''
    Goal:
        Apply gate on the state vector
        assuming local dimension d=2
    Input:
        state: a vector
        gate: the gate to apply
        idx: the gate is applied on (idx, idx+1)

    Return:
        state
    '''
    total_dim = state.size
    L = np.log2(total_dim).astype(int)
    theta = np.reshape(state, [(2**idx), 4, 2**(L - (idx + 2))])
    theta = np.tensordot(gate, theta,
                         [1, 1])  ## [ij] [..., j, ...] --> [i, ..., ...]
    state = (np.transpose(theta, [1, 0, 2])).flatten()
    return state
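A small usage sketch (random 3-qubit state, identity two-site gate as a trivial check; assumes the function above is in scope): applying the 4x4 identity leaves the state unchanged.

import numpy as np

state = np.random.randn(8) + 1j * np.random.randn(8)
state = state / np.linalg.norm(state)
gate = np.eye(4)                         # two-site identity for local dimension d=2
out = apply_gate_exact(state, gate, idx=1)
print(np.allclose(out, state))           # True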
Example #18
def two_normalize(tensor, axis=None):
    """
    Reduce the norm of a tensor to near one by rescaling by a power of two

    Args:
        tensor:     The tensor we wish to two-normalize. When axis is 
                    specified, the remaining axes are treated as batch dims
        axis:       Int or tuple of ints specifying the axes in which 
                    two-normalization occurs. When axis=None, the entire 
                    tensor is two-normalized
    
    Returns:
        out_tensor: Same as tensor, but where appropriate norms have been
                    two-normalized to be between 1 and 2
        two_pow:    The power of two by which out_tensor was rescaled
    """
    two_pow = jnp.floor(
        jnp.log2(jnp.linalg.norm(tensor, axis=axis, keepdims=True)))
    tensor = tensor / 2**two_pow

    return tensor, jnp.squeeze(two_pow, axis=axis)
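A quick usage sketch (illustrative tensor): the rescaled tensor has Frobenius norm in [1, 2), and multiplying back by 2**two_pow recovers the input.

import jax.numpy as jnp

x = jnp.array([[3.0, 4.0], [0.5, 0.25]])
out, two_pow = two_normalize(x)
print(jnp.linalg.norm(out))                  # in [1, 2)
print(jnp.allclose(out * 2**two_pow, x))     # True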
Example #19
 def test_attributes_create_weights_op_fp(
     self,
     weight_range,
     weight_shape,
     fp_quant,
 ):
     weights = jnp.array(
         fp32(onp.random.uniform(*weight_range, size=weight_shape)))
     axis = None if weight_shape[1] == 1 else 0
     weights_quant_op = QuantOps.create_weights_ops(
         w=weights,
         weight_params=QuantOps.WeightParams(prec=fp_quant,
                                             axis=axis,
                                             half_shift=False))
     max_weight = onp.max(abs(weights), axis=0)
     onp.testing.assert_array_equal(
         jnp.squeeze(weights_quant_op._scale),
         jnp.exp2(-jnp.floor(jnp.log2(max_weight))))
     self.assertEqual(weights_quant_op._symmetric, True)
     self.assertIs(weights_quant_op._prec, fp_quant)
     weights_scaled = (weights * weights_quant_op._scale).astype(
         weights.dtype)
     weights_quant_expected = fp_cast.downcast_sat_ftz(
         weights_scaled,
         fp_quant.fp_spec.exp_min,
         fp_quant.fp_spec.exp_max,
         fp_quant.fp_spec.sig_bits,
     )
     weights_quant_calculated = weights_quant_op.to_quantized(
         weights, dtype=SCALE_DTYPE)
     onp.testing.assert_array_equal(weights_quant_expected,
                                    weights_quant_calculated)
     # Test the lower (23 - fp_quant.fp_spec.sig_bits) bits of the calculated
     # quantized weights are zero.
     sig_mask = jnp.int32((1 << (23 - fp_quant.fp_spec.sig_bits)) - 1)
     onp.testing.assert_array_equal(
         weights_quant_calculated.view(jnp.int32) & sig_mask,
         jnp.zeros_like(weights))
Example #20
def histogram_entropy(data, nbins: int = 10):
    """Calculates the histogram entropy of 1D data.
    This function uses the histogram and then calculates
    the entropy. Does the miller-maddow correction
    
    Parameters
    ----------
    data : np.ndarray, (n_samples,)
        the input data for the entropy
    
    base : int, default=2
        the log base for the calculation.
    
    Returns
    -------
    S : float
        the entropy"""

    # get histogram counts and bin edges
    counts, bin_edges = np.histogram(data, bins=nbins, density=False)

    # get bin centers and sizes
    bin_centers = np.mean(np.vstack((bin_edges[0:-1], bin_edges[1:])), axis=0)

    # get difference between the bins
    delta = bin_centers[3] - bin_centers[2]

    # normalize counts (density)
    pk = 1.0 * np.array(counts) / np.sum(counts)

    # calculate the entropy
    S = univariate_entropy(pk)

    # Miller Maddow Correction
    correction = 0.5 * (np.sum(counts > 0) - 1) / counts.sum()

    return S + correction + np.log2(delta)
Example #21
    def create_symmetric_fp(
        cls,
        *,
        bounds,
        fp_quant,
    ):
        """Create QuantOps for symmetric clipping to floating-point bounds.

    Args:
      bounds: The upper (and absolute lower) bound to clip the inputs.
      fp_quant: quantization floating-point specification of the target format.

    Returns:
      QuantOps for quantizing/dequantizing signed activations.
    """
        if bounds is None:
            if fp_quant.is_scaled:
                raise ValueError(
                    'bounds can only be None if fp_quant.is_scaled is False.')
            return cls(prec=fp_quant, scale=None, symmetric=True, bounds=None)
        else:
            initial_bounds = bounds
            bounds = jnp.asarray(bounds, SCALE_DTYPE)
            if not DISABLE_EPSILON_IN_SCALE_FUN_FOR_TESTING:
                bounds += jnp.finfo(SCALE_DTYPE).eps  # to avoid log2(0)
            scale = jnp.exp2(
                -jnp.floor(jnp.log2(bounds)))  # Scale to unit binade.
            # NOTE: stop_gradient is needed here to prevent gradient flow through
            # scale when scale is not a constant, but computed as a function of
            # activations or weights.
            scale = lax.stop_gradient(scale)

            return cls(prec=fp_quant,
                       scale=scale,
                       symmetric=True,
                       bounds=initial_bounds)
Example #22
 def row(i):
     return jnp.ceil(jnp.log2(i + 1))
Example #23
def log2(x):
  if isinstance(x, JaxArray): x = x.value
  return JaxArray(jnp.log2(x))
Example #24
def dwt_max_level(data_len: int, filt_len: int) -> int:
    return np.floor(np.log2(data_len / (filt_len - 1.))).astype(np.int32)
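A quick check (illustrative numbers, assuming the function above is in scope): a length-1024 signal with a length-8 wavelet filter allows floor(log2(1024 / 7)) = 7 decomposition levels.

print(dwt_max_level(1024, 8))  # 7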
Example #25
 def loss(images):
     images = center(images)
     losses = -(logprob_from_conditional_params(images, *pixel_cnn(images))
                * jnp.log2(jnp.e) / images[0].size)
     assert losses.shape == (images.shape[0], )
     return jnp.mean(losses)
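The jnp.log2(jnp.e) factor converts a log-probability from nats to bits, and dividing by the number of pixels gives bits/dim. A tiny numeric check with hypothetical values:

import numpy as np

nats = -1200.0             # hypothetical log-probability of one image, in nats
dims = 32 * 32 * 3         # dimensions (pixels) per image
print(-nats * np.log2(np.e) / dims)  # ~0.56 bits/dim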
Example #26
File: jax.py, Project: yibit/eagerpy
 def log2(self: TensorType) -> TensorType:
     return type(self)(np.log2(self.raw))
Example #27
def _hist_bits(v, uniq):
  """Number of bits required to encode the histogram of v."""
  d = v.size
  k = uniq.size
  return k * jnp.log2(jnp.exp(1)*(d + k)/k)
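A worked example with hypothetical sizes: encoding the histogram of d = 100 values over k = 10 distinct symbols costs about k * log2(e * (d + k) / k), roughly 49 bits.

import jax.numpy as jnp

print(10 * jnp.log2(jnp.exp(1) * (100 + 10) / 10))  # ~49.0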
Example #28
 def unbatched_loss(rng, image):
     image = centre(image)
     pcnn_out = pixel_cnn(rng, image)
     conditional_params = pcnn_out_to_conditional_params(image, pcnn_out)
     return -(conditional_params_to_logprob(image, conditional_params) *
              np.log2(np.e) / image.size)
Example #29
def _entropy(v, uniq):
  uniq = jnp.concatenate([uniq, jnp.array([jnp.inf])], axis=0)
  hist, _ = jnp.histogram(v, bins=uniq)
  hist = hist / jnp.sum(hist)
  entropy = -jnp.sum(hist * jnp.log2(hist))
  return entropy