Code example #1
def compute_information_removal_samples_by_squaring(rate_matrix,
                                                    initial_distribution,
                                                    min_exponent=1e-4,
                                                    max_exponent=1e5,
                                                    interpolation_steps=256,
                                                    use_perplexity=False):
    """Compute mutual information using repeated squaring.

  Avoids redundant work by evaluating power-of-two exponents with repeated
  squaring, starting from a few offset base exponents to fill the gaps
  between powers of two.

  Args:
    rate_matrix: Transition rate matrix of shape [vocab_size, vocab_size]
    initial_distribution: Initial distribution of tokens.
    min_exponent: Smallest non-zero exponent to try.
    max_exponent: Largest exponent to try.
    interpolation_steps: Minimum number of interpolation steps to try.
    use_perplexity: Use conditional perplexity(ish) instead of MI

  Returns:
    exponents: Array of exponents for which we computed relative mutual
      information removal.
    information_removals: Array of the information removal for each exponent.
  """
    # How many powers of two do we need to fill the range?
    powers_of_two = 1 + jnp.ceil(
        jnp.log2(max_exponent) - jnp.log2(min_exponent)).astype(jnp.int32)
    # How many shifts should we evaluate between each power of two? For instance,
    # in addition to evaluating at 1, 2, 4, 8, 16, 32 we might also evaluate at
    # 3/2, 3, 6, 12, 24, 48. Increasing interpolation steps will increase this.
    shifts = jnp.ceil(interpolation_steps / powers_of_two).astype(jnp.int32)

    # Figure out the base exponents (1 and 3/2 in the above example, but there
    # may be more)
    base_exponents = jnp.exp2(
        jnp.log2(min_exponent) + jnp.linspace(0, 1, shifts, endpoint=False))

    def from_base(base_exponent):
        base_matrix = transition_rate_expm(base_exponent * rate_matrix)

        def step(mat, i):
            exponent = base_exponent * (2.0**i)
            info_removal = compute_relative_information_removal(
                mat, initial_distribution, use_perplexity=use_perplexity)
            new_mat = jnp.dot(mat, mat, precision=jax.lax.Precision.HIGHEST)
            new_mat = new_mat / jnp.sum(new_mat, axis=0, keepdims=True)
            return new_mat, (exponent, info_removal)

        _, (exponents,
            info_removals) = jax.lax.scan(step,
                                          init=base_matrix,
                                          xs=jnp.arange(powers_of_two))
        return exponents, info_removals

    exponents, info_removals = jax.lax.map(from_base, base_exponents)
    return exponents.reshape([-1]), info_removals.reshape([-1])
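A quick usage sketch (hypothetical setup: assumes jax and jax.numpy are
imported, and that compute_relative_information_removal and
transition_rate_expm, shown in code example #9, are in scope):

import jax
import jax.numpy as jnp

vocab_size = 8
# Uniform-transition rate matrix: columns sum to zero, off-diagonals >= 0.
rate_matrix = jnp.ones((vocab_size, vocab_size)) / vocab_size - jnp.eye(vocab_size)
initial_distribution = jnp.full((vocab_size,), 1.0 / vocab_size)

exponents, removals = compute_information_removal_samples_by_squaring(
    rate_matrix, initial_distribution, interpolation_steps=32)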
Code example #2
File: semimarkov.py (project: justinchiu/strux)
    def _dp(self, log_potentials, length):
        "Compute forward pass by linear scan"
        semiring = self.semiring
        N, K, C, C2 = log_potentials.shape
        assert C == C2, "Transition shape doesn't match"
        log_N = np.log2(N)
        assert log_N % 1 == 0.0

        # Init.
        init = np.full((N, K-1, K-1, C, C), semiring.zero)
        Cs = np.arange(C)

        init = init.at[:, 0, 0, Cs, Cs].set(semiring.one)

        mask = np.arange(N).reshape(N, 1, 1, 1) < length - 1
        log_potentials = np.where(mask, log_potentials, semiring.zero)
        init = init.at[:, 0].set(np.where(mask, semiring.zero, init[:, 0]))
        start = semiring.sum(np.stack([init[:, :K-1, 0], log_potentials[:, 1:K]], axis=-1))         
        init = init.at[:, :K-1, 0].set(start)        
        end = length - 1
        for k in range(1, K - 1):
            mask = np.arange(N).reshape(N, 1) < end - (k - 1)
            v = np.where(mask, semiring.one, init[:, k - 1, k, Cs, Cs])
            init = init.at[:, k - 1, k, Cs, Cs].set(v)
        
        K_1 = K - 1
        chart = (
            init.transpose((0, 1, 3, 2, 4))
            .reshape(N, K_1 * C, K_1 * C)
        )

        for n in range(int(log_N)):
            chart = semiring.matmul(chart[1::2], chart[0::2])
        chart = chart.reshape(K_1, C, K_1, C)        
        return semiring.sum(semiring.sum(chart[0, :, 0, :]))
Code example #3
    def two_normalize(m):
        # Divide m by a power of 2 to get its norm close to 1
        norm = np.linalg.norm(m, axis=(2, 3), keepdims=True)
        two_pow = np.floor(np.log2(norm))
        stable_m = m / (2**two_pow)

        return stable_m, np.sum(two_pow, axis=0)
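A small check of the rescaling (a sketch with hypothetical shapes; assumes
numpy is imported as np). The summed two_pow is the total log2 scale, which
is useful when the matrices along axis 0 are later multiplied together:

import numpy as np

m = np.random.randn(5, 2, 4, 4) * 1e6   # a 5 x 2 batch of 4x4 matrices
stable_m, total_two_pow = two_normalize(m)
# Each 4x4 block of stable_m now has Frobenius norm in [1, 2).
norms = np.linalg.norm(stable_m, axis=(2, 3))
print(norms.min(), norms.max())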
Code example #4
 def loss(batch):
     theta = wavenet(batch)[:, :-1, :]
     # now slice the padding off the batch
     sliced_batch = batch[:, receptive_field:, :]
     return (np.mean(discretized_mix_logistic_loss(
         theta, sliced_batch, num_class=1 << 16),
                     axis=0) * np.log2(np.e) / (output_width - 1))
Code example #5
def var_gate_exact(top_state, site, bottom_state):
    '''
    Goal:
        to find argmax_{gate} <top_state | gate | bottom_state>
        where the gate acts on (site, site+1)
    Input:
        top_state: the top state (not yet conjugated!)
        site: the gate is applied on (site, site+1)
        bottom_state: the bottom state
    Return:
        new_gate
    '''
    total_dim = top_state.size
    L = int(np.log2(total_dim))
    top_theta = np.reshape(top_state, [(2**site), 4, 2**(L - (site + 2))])
    bottom_theta = np.reshape(bottom_state,
                              [(2**site), 4, 2**(L - (site + 2))])

    M = np.tensordot(top_theta.conj(), bottom_theta, axes=([0, 2], [
        0, 2
    ]))  # [ ..., upper_p, ...], [..., lower_p, ...] --> upper_p, lower_p
    ## If the convention is lower_p, upper_p
    ## uncomment the following line.
    # M = M.T  # the convention is lower_p, upper_p

    ### For detailed explanation of the formula, see function var_gate
    U, _, Vd = misc.svd(M, full_matrices=False)
    new_gate = np.dot(U, Vd).conj()
    # [TODO:remove] new_gate = new_gate.reshape([2, 2, 2, 2])

    return new_gate
Code example #6
def get_bins_and_bincounts(samples, normalized=False):
    """take in samples, create a common set of bins, and compute the counts count(x in bin)
    for each bin and each sample x.
    Parameters
    ------------
    samples : np.array of shape (n,) or shape (k, n).
    - If shape (n,): interpreted as a set of n scalar-valued samples.
    - If shape (k, n): interpreted as k sets of n scalar-valued samples.

    Returns
    --------
    probabilities :
    bins :
    """
    nr_samples = np.prod(samples.shape)
    nr_bins = np.log2(nr_samples)
    nr_bins = int(max(nr_bins, 5))

    lims = [np.min(samples), np.max(samples)]
    bins = np.linspace(*lims, num=nr_bins)

    if samples.ndim == 2:
        out = np.asarray([
            np.histogram(x, bins=bins, density=normalized)[0] for x in samples
        ])
        return out, bins
    elif samples.ndim == 1:
        return np.histogram(samples, bins=bins, density=normalized)[0], bins
    else:
        raise ValueError(
            f"Input must have shape (n,) or shape (k,n). Instead received shape {samples.shape}"
        )
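A short example call (assumes numpy as np):

import numpy as np

samples = np.random.randn(3, 1000)   # 3 sets of 1000 scalar samples
counts, bins = get_bins_and_bincounts(samples)
# counts has shape (3, len(bins) - 1); all rows share the same bin edges.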
Code example #7
File: local_ee.py (project: szhang104/pycomm)
def user_rate(H, W, B):
    # |h_i^H w_j|^2: power received at user i from user j's beam.
    HHW2 = np.abs(H.conj() @ W.T) ** 2.0
    p_sig = np.diag(HHW2)                      # desired-signal powers
    p_int = HHW2 - np.diag(p_sig)              # cross-user interference
    SINR = p_sig / (p_int.sum(axis=1) + 1.0)   # noise power normalized to 1
    rate = B * np.log2(1 + SINR)               # Shannon rate over bandwidth B
    return rate
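A toy call under assumed conventions (users along the rows of H, W holding
one beam per row, noise power normalized to 1; numpy as np):

import numpy as np

K, M = 4, 8   # users, antennas
H = (np.random.randn(K, M) + 1j * np.random.randn(K, M)) / np.sqrt(2)
W = H         # matched-filter (MRT) precoder: w_j = h_j
rates = user_rate(H, W, B=20e6)   # per-user rates in bit/s over 20 MHz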
Code example #8
 def partition_entropy_coefficient(self):
     if hasattr(self, 'u'):
         return -np.sum(self.u * np.log2(self.u)) / self.n_samples
     else:
         raise ReferenceError(
             "You need to train the model first. You can use the `.fit()` "
             "method to do this."
         )
Code example #9
def transition_rate_expm(matrix, target_diagonal=1e-3, renormalize_cols=True):
    """Slightly improved expm for transition rate matrices.

  A transition rate matrix will always have columns that sum to zero, and will
  have nonnegative entries everywhere except the diagonal. We can ensure some
  stability by controlling the magnitude of the diagonal elements and
  renormalizing during each squaring to reduce error.

  Args:
    matrix: The matrix to compute a matrix exponential for.
    target_diagonal: Maximum magnitude of the diagonal elements for which it is
      "safe" to approximate exp(tA) as I + tA. Will automatically perform more
      iterations until this is small enough to be a good approximation.
    renormalize_cols: Whether to renormalize the columns of the result, with the
      assumption that the rate matrix summed to zero across the columns. This
      property should always hold, so renormalizing can prevent errors from
      exploding.

  Returns:
    Approximation of expm(matrix).
  """
    max_diag = jnp.max(-jnp.diag(matrix))
    # Each iteration halves the diagonal. How many do we need to get to at or
    # below the target diagonal?
    iterations_for_diagonal = jnp.ceil(
        jnp.log2(max_diag) - jnp.log2(target_diagonal))
    # Make sure we're also squaring enough so that every element has a chance of
    # transitioning to every other element, in theory.
    iterations_for_mixing = jnp.ceil(jnp.log2(matrix.shape[0]))
    iterations = jnp.maximum(iterations_for_diagonal,
                             iterations_for_mixing).astype(jnp.int32)

    # Locally linear approximation: e^A ~= I + A
    # First divide by 2^iterations so that this approximation is accurate.
    tiny_approx = jnp.eye(matrix.shape[0]) + matrix / (2.0**iterations)

    def step(i, mat):
        del i

        updated = jnp.dot(mat, mat, precision=jax.lax.Precision.HIGHEST)
        if renormalize_cols:
            updated = updated / jnp.sum(updated, axis=0, keepdims=True)
        return updated

    result = jax.lax.fori_loop(0, iterations, step, tiny_approx)
    return result
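A sanity check against a dense matrix exponential (a sketch; assumes numpy
as onp, jax.numpy as jnp, and scipy installed):

import numpy as onp
import jax.numpy as jnp
from scipy.linalg import expm

rate = jnp.array([[-1.0, 0.5],
                  [1.0, -0.5]])   # columns sum to zero, off-diagonals >= 0
approx = transition_rate_expm(rate)
onp.testing.assert_allclose(onp.asarray(approx), expm(onp.asarray(rate)),
                            atol=1e-2)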
Code example #10
File: jaxent.py (project: adamhaber/JaxEnt)
 def _calc_entropy(self,p):
     """calc the entropy of a probability 
     
     Parameters
     ----------
     p : array_like
         vector of probabilities
     """
     return -np.sum(p*np.log2(p))
Code example #11
File: linearchain.py (project: justinchiu/strux)
    def _dp(self, log_potentials, length):
        semiring = self.semiring
        N, C, C2 = log_potentials.shape
        assert C == C2, "Transition shape doesn't match"
        log_N = int(np.log2(N))
        # N is assumed (but not checked here) to be a power of two.

        extra = np.where(np.eye(C, C), semiring.one, semiring.zero)
        chart = np.where(
            np.arange(N).reshape(N, 1, 1) < length - 1, log_potentials, extra)
        for _ in range(log_N):
            chart = semiring.matmul(chart[1::2], chart[0::2])
        return semiring.sum(semiring.sum(chart[0]))
Code example #12
File: jax_bridge.py (project: noegroup/bgflow)
    def _inverted(target):
        init = (jnp.ones_like(target) * left_bound,
                jnp.ones_like(target) * right_bound)
        n_iters = jnp.ceil(-jnp.log2(eps)).astype(int)

        def _body(_, left_right):
            left_bound, right_bound = left_right
            cand = (left_bound + right_bound) / 2
            pred = bijector(cand)
            left_bound = jnp.where(pred < target, cand, left_bound)
            right_bound = jnp.where(pred > target, cand, right_bound)
            return left_bound, right_bound

        return jax.lax.fori_loop(0, n_iters, _body, init)[0]
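The same bisection pattern, restated standalone (a sketch, not the bgflow
API; invert_by_bisection is a hypothetical name):

import jax
import jax.numpy as jnp

def invert_by_bisection(bijector, target, left, right, eps=1e-6):
    # ceil(-log2(eps)) halvings shrink a unit-length bracket below eps.
    n_iters = int(jnp.ceil(-jnp.log2(eps)))

    def body(_, bounds):
        lo, hi = bounds
        mid = (lo + hi) / 2
        pred = bijector(mid)
        lo = jnp.where(pred < target, mid, lo)
        hi = jnp.where(pred > target, mid, hi)
        return lo, hi

    init = (jnp.full_like(target, left), jnp.full_like(target, right))
    return jax.lax.fori_loop(0, n_iters, body, init)[0]

x = invert_by_bisection(jnp.tanh, jnp.tanh(jnp.array([0.3, -1.2])), -5.0, 5.0)
# x recovers [0.3, -1.2] up to roughly (right - left) * eps.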
Code example #13
File: jaxent.py (project: adamhaber/JaxEnt)
        def comp_Z(histogram,densities,energy_bins):
            energy_bins = energy_bins[densities > 0]
            densities = densities[densities > 0]
            densities = densities - densities.min()

            neg_e = energy_bins<0

            logZ = logsumexp(-energy_bins + self.N*np.log(2) + densities - logsumexp(densities))
            Z = np.exp(logZ)

            # separate to negative and positive energy bins and do the same trick?
            logS1 = logsumexp(-energy_bins[neg_e] + np.log(-energy_bins[neg_e]) + self.N*np.log(2) + densities[neg_e] - logsumexp(densities))
            logS2 = logsumexp(-energy_bins[~neg_e] + np.log(energy_bins[~neg_e]) + self.N*np.log(2) + densities[~neg_e] - logsumexp(densities))
            entropy = (np.exp(logS2)-np.exp(logS1))/(np.log(2) * Z) + np.log2(Z) 
            return Z, entropy
Code example #14
File: benchmark_fb.py (project: justinchiu/strux)
def run(log_potentials, length, semiring="Log"):
    "Main code, associative forward-backward"
    if semiring == "Log":
        semiring = LogSemiring
    else:
        semiring = MaxSemiring

    N, C, C2 = log_potentials.shape
    assert C == C2, "Transition shape doesn't match"
    log_N = np.log2(N)
    assert log_N % 1 == 0.0

    extra = np.where(np.eye(C, C), semiring.one, semiring.zero)
    chart = np.where(np.arange(N).reshape(N, 1, 1) < length - 1, log_potentials, extra)    
    for _ in range(int(log_N)):
        chart = semiring.matmul(chart[1::2], chart[0::2])
    return semiring.sum(semiring.sum(chart[0]))
Code example #15
  def likelihood_fn(prng, pstate, data):
    """Compute an unbiased estimate to the log-likelihood in bits/dim.

    Args:
      prng: An array of random states. The list dimension equals the number of devices.
      pstate: Replicated training state for running on multiple devices.
      data: A JAX array of shape [#devices, batch size, ...].

    Returns:
      bpd: A JAX array of shape [#devices, batch size]. The log-likelihoods on `data` in bits/dim.
      z: A JAX array of the same shape as `data`. The latent representation of `data` under the
        probability flow ODE.
      nfe: An integer. The number of function evaluations used for running the black-box ODE solver.
    """
    rng, step_rng = jax.random.split(flax.jax_utils.unreplicate(prng))
    shape = data.shape
    if hutchinson_type == 'Gaussian':
      epsilon = jax.random.normal(step_rng, shape)
    elif hutchinson_type == 'Rademacher':
      epsilon = jax.random.randint(step_rng, shape,
                                   minval=0, maxval=2).astype(jnp.float32) * 2 - 1
    else:
      raise NotImplementedError(f"Hutchinson type {hutchinson_type} unknown.")

    def ode_func(t, x):
      sample = mutils.from_flattened_numpy(x[:-shape[0] * shape[1]], shape)
      vec_t = jnp.ones((sample.shape[0], sample.shape[1])) * t
      drift = mutils.to_flattened_numpy(p_drift_fn(pstate, sample, vec_t))
      logp_grad = mutils.to_flattened_numpy(p_div_fn(pstate, sample, vec_t, epsilon))
      return np.concatenate([drift, logp_grad], axis=0)

    init = jnp.concatenate([mutils.to_flattened_numpy(data), np.zeros((shape[0] * shape[1],))], axis=0)
    solution = integrate.solve_ivp(ode_func, (eps, sde.T), init, rtol=rtol, atol=atol, method=method)
    nfe = solution.nfev
    zp = jnp.asarray(solution.y[:, -1])
    z = mutils.from_flattened_numpy(zp[:-shape[0] * shape[1]], shape)
    delta_logp = zp[-shape[0] * shape[1]:].reshape((shape[0], shape[1]))
    prior_logp = p_prior_logp_fn(z)
    bpd = -(prior_logp + delta_logp) / np.log(2)
    N = np.prod(shape[2:])
    bpd = bpd / N
    # A hack to convert log-likelihoods to bits/dim
    # based on the gradient of the inverse data normalizer.
    offset = jnp.log2(jax.grad(inverse_scaler)(0.)) + 8.
    bpd += offset
    return bpd, z, nfe
Code example #16
File: linear_algebra.py (project: ksachdeva/optax)
    def mat_power(mat_m, p):
        """Computes mat_m^p, for p == 1, 2, 4 or 8.

    Args:
      mat_m: a square matrix
      p: a positive integer

    Returns:
      mat_m^p
    """
        # We unrolled the loop for performance reasons.
        exponent = jnp.round(jnp.log2(p))
        return lax.switch(jnp.asarray(exponent, jnp.int32), [
            _unrolled_mat_pow_1,
            _unrolled_mat_pow_2,
            _unrolled_mat_pow_4,
            _unrolled_mat_pow_8,
        ], (mat_m))
Code example #17
def apply_gate_exact(state, gate, idx):
    '''
    Goal:
        Apply gate on the state vector
        assuming local dimension d=2
    Input:
        state: a vector
        gate: the gate to apply
        idx: the gate is applying on (idx, idx+1)

    Return:
        state
    '''
    total_dim = state.size
    L = np.log2(total_dim).astype(int)
    theta = np.reshape(state, [(2**idx), 4, 2**(L - (idx + 2))])
    theta = np.tensordot(gate, theta,
                         [1, 1])  ## [ij] [..., j, ...] --> [i, ..., ...]
    state = (np.transpose(theta, [1, 0, 2])).flatten()
    return state
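For instance, applying a SWAP gate on sites (1, 2) of a 4-qubit state
(assumes numpy as np):

import numpy as np

L = 4
state = np.random.randn(2**L) + 1j * np.random.randn(2**L)
state /= np.linalg.norm(state)
swap = np.eye(4)[[0, 2, 1, 3]]   # SWAP on two qubits, shape (4, 4)
new_state = apply_gate_exact(state, swap, idx=1)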
Code example #18
File: utils.py (project: jemisjoky/umps_code)
def two_normalize(tensor, axis=None):
    """
    Reduce the norm of tensor to near one by rescaling using power of two

    Args:
        tensor:     The tensor we wish to two-normalize. When axis is 
                    specified, the remaining axes are treated as batch dims
        axis:       Int or tuple of ints specifying the axes in which 
                    two-normalization occurs. When axis=None, the entire 
                    tensor is two-normalized
    
    Returns:
        out_tensor: Same as tensor, but where appropriate norms have been
                    two-normalized to be between 1 and 2
        two_pow:    The power of two by which out_tensor was rescaled
    """
    two_pow = jnp.floor(
        jnp.log2(jnp.linalg.norm(tensor, axis=axis, keepdims=True)))
    tensor = tensor / 2**two_pow

    return tensor, jnp.squeeze(two_pow, axis=axis)
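A small example (assumes jax.numpy as jnp), two-normalizing each row:

import jax.numpy as jnp

t = jnp.array([[3.0, 4.0],
               [0.3, 0.4]])   # row norms 5.0 and 0.5
out, pows = two_normalize(t, axis=1)
# Row norms of out are 1.25 and 1.0; pows is [2., -1.]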
Code example #19
 def test_attributes_create_weights_op_fp(
     self,
     weight_range,
     weight_shape,
     fp_quant,
 ):
     weights = jnp.array(
         fp32(onp.random.uniform(*weight_range, size=weight_shape)))
     axis = None if weight_shape[1] == 1 else 0
     weights_quant_op = QuantOps.create_weights_ops(
         w=weights,
         weight_params=QuantOps.WeightParams(prec=fp_quant,
                                             axis=axis,
                                             half_shift=False))
     max_weight = onp.max(abs(weights), axis=0)
     onp.testing.assert_array_equal(
         jnp.squeeze(weights_quant_op._scale),
         jnp.exp2(-jnp.floor(jnp.log2(max_weight))))
     self.assertEqual(weights_quant_op._symmetric, True)
     self.assertIs(weights_quant_op._prec, fp_quant)
     weights_scaled = (weights * weights_quant_op._scale).astype(
         weights.dtype)
     weights_quant_expected = fp_cast.downcast_sat_ftz(
         weights_scaled,
         fp_quant.fp_spec.exp_min,
         fp_quant.fp_spec.exp_max,
         fp_quant.fp_spec.sig_bits,
     )
     weights_quant_calculated = weights_quant_op.to_quantized(
         weights, dtype=SCALE_DTYPE)
     onp.testing.assert_array_equal(weights_quant_expected,
                                    weights_quant_calculated)
      # Test that the lower (23 - fp_quant.fp_spec.sig_bits) bits of the
      # calculated quantized weights are zero.
     sig_mask = jnp.int32((1 << (23 - fp_quant.fp_spec.sig_bits)) - 1)
     onp.testing.assert_array_equal(
         weights_quant_calculated.view(jnp.int32) & sig_mask,
         jnp.zeros_like(weights))
Code example #20
def histogram_entropy(data, nbins: int = 10):
    """Calculates the histogram entropy of 1D data.
    This function uses the histogram and then calculates
    the entropy. Does the miller-maddow correction
    
    Parameters
    ----------
    data : np.ndarray, (n_samples,)
        the input data for the entropy
    
    base : int, default=2
        the log base for the calculation.
    
    Returns
    -------
    S : float
        the entropy"""

    # get histogram counts and bin edges
    counts, bin_edges = np.histogram(data, bins=nbins, density=False)

    # get bin centers and sizes
    bin_centers = np.mean(np.vstack((bin_edges[0:-1], bin_edges[1:])), axis=0)

    # uniform bin width (np.histogram returns equally spaced edges)
    delta = bin_centers[1] - bin_centers[0]

    # normalize counts (density)
    pk = 1.0 * np.array(counts) / np.sum(counts)

    # calculate the entropy
    S = univariate_entropy(pk)

    # Miller Maddow Correction
    correction = 0.5 * (np.sum(counts > 0) - 1) / counts.sum()

    return S + correction + np.log2(delta)
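A quick check (a sketch; assumes numpy as np and that univariate_entropy
computes -sum(pk * log2(pk))): for standard-normal samples the result should
approach the differential entropy 0.5 * log2(2 * pi * e), about 2.05 bits.

import numpy as np

data = np.random.randn(100_000)
print(histogram_entropy(data, nbins=100))   # roughly 2.05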
Code example #21
    def create_symmetric_fp(
        cls,
        *,
        bounds,
        fp_quant,
    ):
        """Create QuantOps for symmetric clipping to floating-point bounds.

    Args:
      bounds: The upper (and absolute lower) bound to clip the inputs.
      fp_quant: quantization floating-point specification of the target format.

    Returns:
      QuantOps for quantizing/dequantizing signed activations.
    """
        if bounds is None:
            if fp_quant.is_scaled:
                raise ValueError(
                    'bounds can only be None if fp_quant.is_scaled is False.')
            return cls(prec=fp_quant, scale=None, symmetric=True, bounds=None)
        else:
            initial_bounds = bounds
            bounds = jnp.asarray(bounds, SCALE_DTYPE)
            if not DISABLE_EPSILON_IN_SCALE_FUN_FOR_TESTING:
                bounds += jnp.finfo(SCALE_DTYPE).eps  # to avoid log2(0)
            scale = jnp.exp2(
                -jnp.floor(jnp.log2(bounds)))  # Scale to unit binade.
            # NOTE: stop_gradient is needed here to prevent gradient flow through
            # scale when scale is not a constant, but computed as a function of
            # activations or weights.
            scale = lax.stop_gradient(scale)

            return cls(prec=fp_quant,
                       scale=scale,
                       symmetric=True,
                       bounds=initial_bounds)
Code example #22
 def row(i):
     return jnp.ceil(jnp.log2(i + 1))
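For the first few indices the mapping looks like this (assuming row and
jax.numpy as jnp are in scope):

[int(row(i)) for i in range(8)]   # [0, 1, 2, 2, 3, 3, 3, 3]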
Code example #23
def log2(x):
  if isinstance(x, JaxArray): x = x.value
  return JaxArray(jnp.log2(x))
Code example #24
def dwt_max_level(data_len: int, filt_len: int) -> int:
    return np.floor(np.log2(data_len / (filt_len - 1.))).astype(np.int32)
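For example, a length-1000 signal with a length-8 filter (e.g. db4) allows
floor(log2(1000 / 7)) = 7 decomposition levels:

print(dwt_max_level(1000, 8))   # 7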
Code example #25
File: pixelcnn.py (project: juliuskunze/jaxnet)
 def loss(images):
     images = center(images)
     losses = -(logprob_from_conditional_params(images, *pixel_cnn(images))
                * jnp.log2(jnp.e) / images[0].size)
     assert losses.shape == (images.shape[0], )
     return jnp.mean(losses)
Code example #26
File: jax.py (project: yibit/eagerpy)
 def log2(self: TensorType) -> TensorType:
     return type(self)(np.log2(self.raw))
Code example #27
def _hist_bits(v, uniq):
  """Number of bits required to encode the histogram of v."""
  d = v.size
  k = uniq.size
  return k * jnp.log2(jnp.exp(1)*(d + k)/k)
Code example #28
File: pixelcnn.py (project: yueyedeai/jaxnet)
 def unbatched_loss(rng, image):
     image = centre(image)
     pcnn_out = pixel_cnn(rng, image)
     conditional_params = pcnn_out_to_conditional_params(image, pcnn_out)
     return -(conditional_params_to_logprob(image, conditional_params) *
              np.log2(np.e) / image.size)
Code example #29
def _entropy(v, uniq):
  # Append +inf so the largest value in `uniq` gets its own bin.
  uniq = jnp.concatenate([uniq, jnp.array([jnp.inf])], axis=0)
  hist, _ = jnp.histogram(v, bins=uniq)
  hist = hist / jnp.sum(hist)
  # Guard 0 * log2(0), which would otherwise produce NaN for empty bins.
  entropy = -jnp.sum(jnp.where(hist > 0, hist * jnp.log2(hist), 0.0))
  return entropy
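A tiny worked example (assumes jax.numpy as jnp):

import jax.numpy as jnp

v = jnp.array([0., 0., 1., 2.])
uniq = jnp.array([0., 1., 2.])
print(_entropy(v, uniq))   # 1.5 bits: probabilities [0.5, 0.25, 0.25]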