Ejemplo n.º 1
0
 def cond_1():
     """Return the ``(x, y, z, w)`` components for the x-dominant branch.

     Reads ``m00`` .. ``m22`` and ``eps`` from the enclosing scope
     (presumably the entries of a 3x3 rotation matrix and a small
     stabiliser -- confirm against the enclosing function) and stacks
     the four components along a new last axis.
     """
     # four_x equals 4 * x; eps is added under the square root.
     four_x = 2. * jnp.sqrt(1.0 + m00 - m11 - m22 + eps)
     components = (
         0.25 * four_x,                  # x
         jnp.divide(m01 + m10, four_x),  # y
         jnp.divide(m02 + m20, four_x),  # z
         jnp.divide(m21 - m12, four_x),  # w
     )
     return jnp.stack(components, axis=-1)
Ejemplo n.º 2
0
def decoupled_multivariate_normal_kl_divergence(
        mu_0: Array,
        sigma_0: Numeric,
        mu_1: Array,
        sigma_1: Numeric,
        per_dimension: bool = False) -> Tuple[Array, Array]:
    """Split the KL between two diagonal Gaussians into mean and covariance parts.

    Args:
      mu_0: array like of mean values for policy 0
      sigma_0: scalar or array like of std values for policy 0
      mu_1: array like of mean values for policy 1
      sigma_1: scalar or array like of std values for policy 1
      per_dimension: if True, keep one value per entry of the last axis
        instead of summing over it.

    Returns:
      A pair ``(kl_mean, kl_cov)``: the term that depends on the mean
      difference and the term that depends only on the variances.
    """
    # A scalar `sigma` is broadcast to the shape of the matching `mu`.
    std_1 = jnp.ones_like(mu_1) * sigma_1
    std_0 = jnp.ones_like(mu_0) * sigma_0

    # Variances are clipped to [1e-6, 1e6] before dividing / taking logs.
    var_1 = jnp.clip(jnp.square(std_1), 1e-6, 1e6)
    var_0 = jnp.clip(jnp.square(std_0), 1e-6, 1e6)

    delta = mu_1 - mu_0
    kl_mean = 0.5 * jnp.divide(delta**2, var_1)
    kl_cov = 0.5 * (var_0 / var_1 - jnp.ones_like(mu_1) + jnp.log(var_1) -
                    jnp.log(var_0))

    if per_dimension:
        return kl_mean, kl_cov
    return jnp.sum(kl_mean, axis=-1), jnp.sum(kl_cov, axis=-1)
Ejemplo n.º 3
0
 def objective(params: list, bparam: list, batch_input) -> float:
     """Return 25 plus, for every entry of ``params``, the mean of w**4 / 4 + bparam[0] * w**2 / 2.

     ``batch_input`` is accepted but not used.
     """
     def mean_potential(w1):
         quartic = np.divide(np.power(w1, 4), 4.0)
         quadratic = bparam[0] * np.divide(np.power(w1, 2), 2.0)
         return np.mean(quartic + quadratic)

     # Start the running total at the constant offset 25.0.
     return sum((mean_potential(w1) for w1 in params), 25.0)
Ejemplo n.º 4
0
def multivariate_normal_kl_divergence(
    mu_1: ArrayLike,
    mu_0: ArrayLike,
    sigma_1: ArrayLike,
    sigma_0: ArrayLike,
) -> ArrayLike:
    """Compute the KL between two Gaussians with diagonal covariance matrices.

    Args:
      mu_1: array like of mean values for policy 1
      mu_0: array like of mean values for policy 0
      sigma_1: scalar or array like of std values for policy 1
      sigma_0: scalar or array like of std values for policy 0

    Returns:
      The KL divergence, summed over every element of the inputs.
    """
    # A scalar `sigma` is broadcast to the shape of the matching `mu`.
    std_1 = jnp.ones_like(mu_1) * sigma_1
    std_0 = jnp.ones_like(mu_0) * sigma_0

    # Variances are clipped to [1e-6, 1e6] before dividing / taking logs.
    var_1 = jnp.clip(jnp.square(std_1), 1e-6, 1e6)
    var_0 = jnp.clip(jnp.square(std_0), 1e-6, 1e6)
    delta = mu_1 - mu_0

    variance_ratio = jnp.sum(var_0 / var_1)
    scaled_distance = jnp.sum(jnp.divide(delta**2, var_1))
    dimension = jnp.sum(jnp.ones_like(mu_1))
    log_var_1 = jnp.sum(jnp.log(var_1))
    log_var_0 = jnp.sum(jnp.log(var_0))

    return 0.5 * (variance_ratio + scaled_distance - dimension + log_var_1 -
                  log_var_0)
Ejemplo n.º 5
0
 def tr_positive():
     """Return the ``(x, y, z, w)`` components for the w-dominant branch.

     Reads ``trace`` and the off-diagonal entries ``m01`` .. ``m21`` from
     the enclosing scope (presumably a 3x3 rotation matrix and its trace
     -- confirm against the enclosing function) and stacks the four
     components along a new last axis.
     """
     # four_w equals 4 * w.
     four_w = 2. * jnp.sqrt(trace + 1.0)
     components = (
         jnp.divide(m21 - m12, four_w),  # x
         jnp.divide(m02 - m20, four_w),  # y
         jnp.divide(m10 - m01, four_w),  # z
         0.25 * four_w,                  # w
     )
     return jnp.stack(components, axis=-1)
Ejemplo n.º 6
0
def troe_falloff_correction(
    T: float, lPr: np.ndarray, troe_coeffs: np.ndarray, troe_indices: np.ndarray
) -> np.ndarray:
    """
    Compute the TROE falloff broadening factor F(T, P).

    T: temperature
    lPr: log10 of the reduced pressure (one entry per selected reaction)
    troe_coeffs: table with four TROE coefficients per row
    troe_indices: rows of ``troe_coeffs`` to use
    returns: np.ndarray of F(T,P)
    """
    selected = troe_coeffs[troe_indices]
    c0, c1, c2, c3 = (selected[:, col] for col in range(4))

    # F_cent = (1 - c0) * exp(-T / c3) + c0 * exp(-T / c1) + exp(-c2 / T)
    F_cent = (
        (1 - c0) * np.exp(-T / c3)
        + c0 * np.exp(-T / c1)
        + np.exp(-c2 / T)
    )
    lF_cent = np.log10(F_cent)

    C = -0.4 - 0.67 * lF_cent
    N = 0.75 - 1.27 * lF_cent
    shifted = lPr + C
    f1 = shifted / (N - 0.14 * shifted)

    # log10(F) = log10(F_cent) / (1 + f1**2)
    return np.power(10.0, lF_cent / (1.0 + np.square(f1)))
Ejemplo n.º 7
0
 def cond_3():
     """Return the ``(x, y, z, w)`` components for the z-dominant branch.

     Reads ``m00`` .. ``m22`` and ``eps`` from the enclosing scope
     (presumably the entries of a 3x3 rotation matrix and a small
     stabiliser -- confirm against the enclosing function) and stacks
     the four components along a new last axis.
     """
     # four_z equals 4 * z; eps is added under the square root.
     four_z = 2. * jnp.sqrt(1.0 + m22 - m00 - m11 + eps)
     components = (
         jnp.divide(m02 + m20, four_z),  # x
         jnp.divide(m12 + m21, four_z),  # y
         0.25 * four_z,                  # z
         jnp.divide(m10 - m01, four_z),  # w
     )
     return jnp.stack(components, axis=-1)
Ejemplo n.º 8
0
 def cond_2():
     """Return the ``(x, y, z, w)`` components for the y-dominant branch.

     Reads ``m00`` .. ``m22`` and ``eps`` from the enclosing scope
     (presumably the entries of a 3x3 rotation matrix and a small
     stabiliser -- confirm against the enclosing function) and stacks
     the four components along a new last axis.
     """
     # four_y equals 4 * y; eps is added under the square root.
     four_y = 2. * jnp.sqrt(1.0 + m11 - m00 - m22 + eps)
     components = (
         jnp.divide(m01 + m10, four_y),  # x
         0.25 * four_y,                  # y
         jnp.divide(m12 + m21, four_y),  # z
         jnp.divide(m02 - m20, four_y),  # w
     )
     return jnp.stack(components, axis=-1)
Ejemplo n.º 9
0
def func(var=np.array([0.5, 1.])):
    """Return log(1 / var[0]) * sum((arg - var[1])**2) / var[0]**2.

    ``arg`` is read from the enclosing/module scope, not passed in.
    """
    scale, shift = var[0], var[1]
    squared_dev_sum = np.sum(np.power(np.subtract(arg, shift), 2))
    ratio = np.divide(squared_dev_sum, np.power(scale, 2))
    log_inv_scale = np.log(np.divide(1, scale))
    return np.dot(ratio, log_inv_scale)
Ejemplo n.º 10
0
    def backward(self, time, spike_list, weights, e_gradient):
        """Scale the back-propagated gradient ``weights @ e_gradient``.

        ``spike_list[0]`` is used as ``gamma`` and ``spike_list[1]`` is
        subtracted from ``time`` (presumably a spike count and the last
        spike time -- confirm against the caller). Reads ``self.tau_m``
        and ``self.Vth``.
        """
        gamma, spike_time = spike_list[0], spike_list[1]

        # exp(-(time - spike_time) / tau_m) * (-1 / tau_m)
        decay_exponent = jnp.divide(jnp.subtract(time, spike_time),
                                    -self.tau_m)
        f_prime_t = jnp.exp(decay_exponent) * (-1 / self.tau_m)

        # (1 / Vth) * (1 + f'(t) / gamma)
        aLIFnet = jnp.multiply(1 / self.Vth,
                               1 + jnp.divide(1, gamma) * f_prime_t)

        return jnp.matmul(weights, e_gradient) * aLIFnet
Ejemplo n.º 11
0
def sum_f(var):
    """Return log(1 / var[0]) * ret, with ret = -sum_i (x_i - var[1])**2 / var[0]**2.

    The grid is x_i = 10 / 10**4 * i - 5 for i = 0 .. 10**4 - 1.

    The previous loop did not match these formulas: it computed
    10 / (10**4 * i), which divides by zero at i = 0, and it overwrote
    the accumulator instead of subtracting each term from it.

    Args:
        var: indexable with ``var[0]`` (scale) and ``var[1]`` (shift).
    """
    num_points = 10**4
    # x = 10 / 10**4 * i - 5, for every i at once.
    grid = 10 / num_points * np.arange(num_points) - 5
    # ret -= (x - var[1])**2 / var[0]**2, accumulated over the whole grid.
    terms = np.divide(np.power(grid - var[1], 2), np.power(var[0], 2))
    ret = -np.sum(terms)
    return np.log(np.divide(1, var[0])) * ret
Ejemplo n.º 12
0
def precision(
    y_true: jnp.ndarray,
    y_pred: jnp.ndarray,
    threshold: jnp.ndarray,
    class_id: jnp.ndarray,
    sample_weight: jnp.ndarray,
    true_positives: ReduceConfusionMatrix,
    false_positives: ReduceConfusionMatrix,
) -> jnp.ndarray:
    """Return tp / (tp + fp), with NaN results replaced via ``jnp.nan_to_num``.

    ``y_pred`` is binarised with ``y_pred > threshold`` and cast to the
    dtype of ``y_true`` before being handed to the two reducers, which
    are called with ``y_true``, ``y_pred`` and ``sample_weight`` keywords.
    ``class_id`` is currently ignored.
    """
    # TODO: class_id behavior
    binarised = (y_pred > threshold).astype(jnp.float32)
    if binarised.dtype != y_true.dtype:
        binarised = binarised.astype(y_true.dtype)

    reducer_kwargs = dict(y_true=y_true,
                          y_pred=binarised,
                          sample_weight=sample_weight)
    tp = true_positives(**reducer_kwargs)
    fp = false_positives(**reducer_kwargs)

    # 0 / 0 (no positive predictions) yields NaN, which becomes 0.
    return jnp.nan_to_num(jnp.divide(tp, tp + fp))
Ejemplo n.º 13
0
def get_reverse_rate_constants(
    kf: np.ndarray, Kc: np.ndarray, is_reversible: np.ndarray
) -> np.ndarray:
    """
    Calculate reverse rate constants as kf / Kc, multiplied element-wise
    by ``is_reversible`` (presumably a 0/1 mask that zeroes irreversible
    reactions -- confirm against the caller).
    """
    ratio = np.divide(kf, Kc)
    return np.multiply(ratio, is_reversible)
Ejemplo n.º 14
0
def loss_one_pair0(mu_i, mu_j, s_i, s_j, d, n_components):
    """Return -ncx2_log_pdf(d**2; df=n_components, nc) - log(2 * d / s_ij).

    Here s_ij = s_i + s_j + EPSILON and nc = (||mu_i - mu_j|| + EPSILON)**2 / s_ij.
    """
    # EPSILON is added to both quantities to try to avoid dividing by zero
    # and to keep the distance from being exactly zero.
    combined_scale = s_i + s_j + EPSILON
    separation = jnp.linalg.norm(mu_i - mu_j) + EPSILON

    noncentrality = jnp.divide(separation**2, combined_scale)
    log_pdf = ncx2_log_pdf(x=d * d, df=n_components, nc=noncentrality)
    log_factor = jnp.log(2 * d / combined_scale)
    return -log_pdf - log_factor
Ejemplo n.º 15
0
def get_diffusionEmbedding(points=[],
                           distance=[],
                           distmatrix=None,
                           alpha=1.0,
                           tdiff=0,
                           eps=None):
    """Compute a diffusion-map style embedding from points or a distance matrix.

    Args:
        points: the points to embed; only used when ``distmatrix`` is None.
        distance: forwarded to ``make_distanceMatrix`` when the distance
            matrix has to be built.
        distmatrix: optional precomputed distance matrix; used instead of
            ``points`` / ``distance`` when given.
        alpha: forwarded to ``renormalize_kernel``.
        tdiff: exponent applied to the singular values.
        eps: kernel scale passed to ``make_kernelMatrix``; defaults to
            ``2 * median(d)**2``.

    Returns:
        ``(phi, s)`` where ``s`` are the singular values of the transition
        matrix and ``phi[:, i] = s[i]**tdiff * u[:, i] / u[:, 0]``.
    """
    if distmatrix is None:
        n = len(points)
        idx = jnp.array([[i, j] for i in range(n) for j in range(n)])
        d = make_distanceMatrix(points=points, idx=idx, distance=distance, n=n)
    else:
        d = distmatrix

    if eps is None:
        # using heuristic from the R package for diffusion maps
        eps = 2 * jnp.median(d)**2

    K = make_kernelMatrix(distmatrix=d, eps=eps)
    Kr = renormalize_kernel(K, alpha=alpha)
    P = make_transitionMatrix(Kr)
    u, s, _ = jnp.linalg.svd(P)

    # JAX arrays are immutable: `phi.at[:, i].set(...)` returns a new array,
    # so the previous loop (which discarded that result) left `phi` equal to
    # `u`. Rescale all columns at once instead. This assumes the transition
    # matrix is square, so `s` has one entry per column of `u`.
    phi = (s**tdiff)[jnp.newaxis, :] * jnp.divide(u, u[:, :1])

    return phi, s
Ejemplo n.º 16
0
def clip_eta(eta, norm, eps):
    """
    Helper function to clip the perturbation to epsilon norm ball.
    :param eta: A tensor with the current perturbation.
    :param norm: Order of the norm (mimics Numpy).
                Possible values: np.inf or 2.
    :param eps: Epsilon, bound of the perturbation.
    :return: The clipped perturbation, same shape as ``eta``.
    :raises ValueError: if ``norm`` is neither np.inf nor 2.
    """
    if norm not in [np.inf, 2]:
        raise ValueError("norm must be np.inf or 2.")

    if norm == np.inf:
        # Bounds are passed positionally: the `a_min` / `a_max` keywords are
        # deprecated in jax.numpy.clip, while the positional form is accepted
        # by both NumPy and JAX.
        return np.clip(eta, -eps, eps)

    # Reduce over every axis except the first; a tuple (not a list) is
    # accepted by both NumPy and JAX reductions.
    reduce_axes = tuple(range(1, len(eta.shape)))
    avoid_zero_div = 1e-12
    # avoid_zero_div must go inside sqrt to avoid a divide by zero in the
    # gradient through this operation
    l2_norm = np.sqrt(
        np.maximum(avoid_zero_div,
                   np.sum(np.square(eta), axis=reduce_axes, keepdims=True))
    )
    # We must *clip* to within the norm ball, not *normalize* onto the
    # surface of the ball
    factor = np.minimum(1.0, np.divide(eps, l2_norm))
    return eta * factor
Ejemplo n.º 17
0
def compute_weight_mat(input_size: core.DimSize, output_size: core.DimSize,
                       scale, translation, kernel: Callable, antialias: bool):
    """Build the ``(input_size, output_size)`` resampling weight matrix.

    Each column holds the kernel weights of one output sample over all
    input positions, normalised to sum to one. Columns whose sample
    location falls outside the input range are zeroed.
    """
    inv_scale = 1. / scale
    # When downsampling the kernel should be scaled since we want to low pass
    # filter and interpolate, but when upsampling it should not be since we
    # only want to interpolate.
    if antialias:
        kernel_scale = jnp.maximum(inv_scale, 1.)
    else:
        kernel_scale = 1.

    # Output sample locations expressed in input coordinates.
    sample_f = ((jnp.arange(output_size) + 0.5) * inv_scale -
                translation * inv_scale - 0.5)
    input_positions = jnp.arange(input_size, dtype=sample_f.dtype)
    offsets = jnp.abs(sample_f[jnp.newaxis, :] -
                      input_positions[:, jnp.newaxis]) / kernel_scale
    raw_weights = kernel(offsets)

    # Normalise each column, treating a (near-)zero column sum as "no weight".
    column_sums = jnp.sum(raw_weights, axis=0, keepdims=True)
    safe_sums = jnp.where(column_sums != 0, column_sums, 1)
    is_significant = (
        jnp.abs(column_sums) > 1000. * np.finfo(np.float32).eps)
    normalised = jnp.where(is_significant,
                           jnp.divide(raw_weights, safe_sums), 0)

    # Zero out weights where the sample location is completely outside the
    # input range. sample_f has already had the 0.5 removed, hence the
    # [-0.5, input_size - 0.5] range below.
    upper_bound = core.dimension_as_value(input_size) - 0.5
    inside = jnp.logical_and(sample_f >= -0.5, sample_f <= upper_bound)
    return jnp.where(inside[jnp.newaxis, :], normalised, 0)
Ejemplo n.º 18
0
def func(arg):
    """Scan over the leading axis of ``arg`` and return log(numerator / sqrt(divider)).

    For every slice ``x`` with ``s = sin(x . var) + cos(x . var)``:
    ``divider`` accumulates ``sum(s**2)`` and ``numerator`` accumulates
    ``sum(s)**2``. ``var`` is read from the enclosing/module scope.
    """
    def accumulate(totals, x):
        divider, numerator = totals
        phase = np.dot(x, var)
        wave = np.add(np.sin(phase), np.cos(phase))

        divider = np.add(divider, np.sum(np.power(wave, 2)))
        numerator = np.add(np.power(np.sum(wave), 2), numerator)
        return (divider, numerator), ()

    # Both accumulators start as floats.
    (divider, numerator), _ = lax.scan(accumulate, (0., 0.), arg)

    return np.log(np.divide(numerator, np.power(divider, 1 / 2)))
Ejemplo n.º 19
0
    def forward(self, x, v_current):
        """Advance the potential by one step and report threshold crossings.

        Adds ``(x - v_current) * dt / tau_m`` to ``v_current``. Entries that
        reach ``self.Vth`` are reported as spikes and set to 0; all other
        entries are multiplied by ``exp(-1 / tau_m)``.

        Args:
            x: input, broadcastable against ``v_current``.
            v_current: current potential.

        Returns:
            ``(spike_list, v_current)``: an int32 0/1 array of spikes and the
            updated potential.
        """
        dV_tau = jnp.multiply(jnp.subtract(x, v_current), self.dt)
        dV = jnp.divide(dV_tau, self.tau_m)
        # `jax.ops.index_add` / `jax.ops.index` were removed from JAX; the
        # `.at[...].add(...)` property is the supported equivalent and, like
        # index_add, keeps the shape and dtype of `v_current`.
        v_current = jnp.asarray(v_current).at[:].add(dV)

        fired = jnp.greater_equal(v_current, self.Vth)
        spike_list = fired.astype('int32')
        v_current = jnp.where(fired, 0,
                              v_current * jnp.exp(-1 / self.tau_m))

        return spike_list, v_current
Ejemplo n.º 20
0
 def normal_logpdf(self, x, mu, sigma):
     """Return the log-density of Normal(mu, sigma) evaluated at x.

     Computed directly as -log(sigma * sqrt(2 * pi)) - ((x - mu) / (sqrt(2) * sigma))**2.
     """
     # this is much faster than
     # norm.logpdf(x, loc=mu, scale=sigma)
     # https://codereview.stackexchange.com/questions/69718/fastest-computation-of-n-likelihoods-on-normal-distributions
     sqrt_two = jnp.sqrt(2)
     sqrt_two_pi = jnp.sqrt(2 * jnp.pi)
     standardised = jnp.divide((x - mu), (sqrt_two * sigma))
     return -jnp.log(sigma * sqrt_two_pi) - jnp.square(standardised)
Ejemplo n.º 21
0
        def np_fn(input_np, v_current, gamma, tau_m, Vth, dt):
            """Compute ``(spike, v_current, gamma)`` for one step.

            ``gamma`` is updated in place with ``+=`` and also returned.
            """
            # NOTE(review): v_current is replaced by the increment itself,
            # not by `v_current + increment` -- confirm this is intended.
            drive = input_np - v_current
            v_current = (drive / tau_m) * dt

            # The threshold test uses the new v_current plus one more
            # increment computed from it.
            next_increment = (input_np - v_current) / tau_m * dt
            spike = np.greater_equal(v_current + next_increment,
                                     Vth).astype('float32')

            # NOTE(review): the 0/1 spike flag is compared against Vth
            # here -- confirm this is intended.
            gamma += np.where(spike >= Vth, 1, 0)
            return spike, v_current, gamma
Ejemplo n.º 22
0
 def attention_op(self, query, key, value, mask=None):
     """Return softmax(query @ key^T / sqrt(d)) @ value.

     ``d`` is the size of the last axis of ``query``; ``key`` is transposed
     with ``transpose(0, 2, 1)``, so it must be 3-dimensional. When ``mask``
     is given the scaled scores are matrix-multiplied by it before the
     softmax.
     """
     scale = jnp.sqrt(query.shape[-1])
     logits = jnp.divide(jnp.matmul(query, key.transpose(0, 2, 1)), scale)

     if mask is not None:
         # NOTE(review): the mask is applied with a matrix product rather
         # than an element-wise select/add -- confirm this is intended.
         logits = jnp.matmul(logits, mask)

     weights = nn.softmax(logits, axis=-1)
     return jnp.matmul(weights, value)
Ejemplo n.º 23
0
def get_K_stab(C, alpha, beta, reg):
    """Return exp(-(C - alpha[:, None] - beta[None, :]) / reg).

    The computation uses NumPy when ``C`` is a ``numpy.ndarray`` and
    ``jax.numpy`` when ``C`` is a JAX array.

    Args:
        C: 2-D matrix (NumPy or JAX array).
        alpha: 1-D vector broadcast along the rows of ``C``.
        beta: 1-D vector broadcast along the columns of ``C``.
        reg: divisor applied (negated) before exponentiating.

    Raises:
        NotImplementedError: if ``C`` is neither a NumPy nor a JAX array.
    """
    if isinstance(C, np.ndarray):
        xp = np
    # `jax.Array` is the public array type in current JAX; the former
    # `jax.interpreters.xla.DeviceArray` no longer exists there, so
    # referencing it raised AttributeError for every non-NumPy input.
    elif isinstance(C, jax.Array):
        xp = jnp
    else:
        raise NotImplementedError(f"The type {type(C)} is not supported!")

    # Faster exponent
    K = xp.divide(C - alpha[:, None] - beta[None, :], -reg)
    return xp.exp(K)
Ejemplo n.º 24
0
    def objective(state, bparam, batch_input):
        """
        Computes scalar objective.

        For every ``(w, b)`` pair in ``state`` the summed potential
        ``t**4 / 4 + bparam[0] * t**2 / 2`` of ``w`` and of ``b`` is added
        to the result.

        :param state: iterable of ``(w, b)`` pairs
        :param bparam: indexable; only ``bparam[0]`` is used
        :param batch_input: unused
        :return: scalar
        """
        def potential(t):
            return (np.divide(np.power(t, 4), 4.0) +
                    bparam[0] * np.divide(np.power(t, 2), 2.0))

        return sum(
            (np.mean(np.sum(potential(w)) + np.sum(potential(b)))
             for (w, b) in state),
            0.0)
Ejemplo n.º 25
0
    def amplitude(self, nu: ArrayLike) -> jnp.ndarray:
        """The amplitude of the glitch,
        :math:`a_\\mathrm{CZ} / \\nu^{2}`.

        Args:
            nu (:term:`array_like`): Mode frequency, :math:`\\nu`.

        Returns:
            jax.numpy.ndarray: ``self._a`` divided by the squared frequency.
        """
        nu_squared = nu**2
        return jnp.divide(self._a, nu_squared)
Ejemplo n.º 26
0
def eval_for(op):
  """Evaluate a single ``op`` and return its value broadcast over ``op.all_idxs``.

  Dispatches on ``op.op_name``; raises ``Exception`` for names it does not
  recognise.
  """
  name = op.op_name

  # Element-wise binary arithmetic on two broadcast operands.
  if name in ("IAdd", "IMul", "FAdd", "FMul", "FDiv"):
    lhs, rhs = op.args
    lhs_bc = broadcast_dims(op.all_idxs, lhs.idxs, lhs.atom.val)
    rhs_bc = broadcast_dims(op.all_idxs, rhs.idxs, rhs.atom.val)
    binary_fns = {
        "IAdd": jnp.add,
        "FAdd": jnp.add,
        "IMul": jnp.multiply,
        "FMul": jnp.multiply,
        "FDiv": jnp.divide,
    }
    fn = binary_fns.get(name)
    if fn is None:
      raise Exception("Not implemented: " + str(name))
    return fn(lhs_bc, rhs_bc)

  if name == "Iota":
    n, = op.size_args
    return broadcast_dims(op.all_idxs, [], jnp.arange(n))

  if name == "Id":
    x, = op.args
    return broadcast_dims(op.all_idxs, x.idxs, x.atom.val)

  if name == "Get":
    x, idx = op.args
    out_shape = [i.size for i in op.all_idxs]
    used_flags = get_stack_idxs_used(op.all_idxs, x.idxs)
    # One iota index array for every stack index that x actually uses.
    leading_idx_arrays = [
        nth_iota(out_shape, pos)
        for pos, used in enumerate(used_flags)
        if used
    ]
    payload_idx_array = broadcast_dims(op.all_idxs, idx.idxs, idx.atom.val)
    return x.atom.val[tuple(leading_idx_arrays) + (payload_idx_array,)]

  if name == "IntToReal":
    x, = op.args
    as_real = jnp.array(x.atom.val, dtype="float32")
    return broadcast_dims(op.all_idxs, x.idxs, as_real)

  if name in ("FNeg", "INeg"):
    x, = op.args
    return broadcast_dims(op.all_idxs, x.idxs, jnp.negative(x.atom.val))

  if name == "ThreeFry2x32":
    # Reinterpret one 64-bit value as uint32 words, and back again.
    def split_64(value):
      return np.array([value]).view(np.uint32)

    def join_32s(words):
      return np.int64(np.array(words).view(np.int64).item())

    x, y = op.args
    key = split_64(x.atom.val)
    count = split_64(y.atom.val)
    result = join_32s(random.threefry_2x32(key, count))
    return broadcast_dims(op.all_idxs, x.idxs, result)

  raise Exception("Unrecognized op: {}".format(name))
Ejemplo n.º 27
0
    def _case3(zagf):
        """Compute ``grad`` from ``z`` and ``alpha`` and flip ``flag``.

        ``zagf`` is a 4-tuple ``(z, alpha, grad, flag)``; the incoming
        ``grad`` is ignored. Returns ``(z, alpha, new_grad, ~flag)``.
        """
        z, alpha, _, flag = zagf

        # Formula 59 of [1]
        ratio = np.divide(z, alpha)
        alpha_sq = alpha * alpha
        numerator = 1440 * alpha + 6 * ratio * (53 - 120 * z) - 65 * ratio * ratio + 3600 * z + 107
        denominator = 1244160 * alpha * alpha_sq
        correction = 1 + 24 * alpha + 288 * alpha_sq

        return z, alpha, numerator * correction / denominator, ~flag
Ejemplo n.º 28
0
def SUMO(x, rng, enc_params, dec_params, K, m=16, *args, **kwargs):
    """Return iwelbo[m - 1] plus a reweighted sum of the next K increments.

    ``K + m`` importance-weight samples are drawn, one per split of
    ``rng``. The running log-mean-exp of those samples is ``iwelbo``.
    Each increment ``iwelbo[k] - iwelbo[k - 1]`` for ``k >= m`` is divided
    by ``reverse_cdf`` of its 1-based index. Extra ``*args`` / ``**kwargs``
    are ignored. ``K`` is supplied by the caller.
    """
    num_terms = K + m
    # 1-based term indices 1 .. K + m.
    term_index = lax.iota(dtype=jnp.int32, size=num_terms) + 1
    sample_keys = random.split(rng, num_terms)

    batched_estimator = vmap(iw_estimator, in_axes=(None, 0, None, None))
    log_iw = batched_estimator(x, sample_keys, enc_params, dec_params)
    iwelbo = logcumsumexp(log_iw) - jnp.log(term_index)

    tail_prob = vmap(reverse_cdf, in_axes=(0, ))(term_index[m:])
    inv_weights = jnp.divide(1., tail_prob)
    increments = iwelbo[m:] - iwelbo[m - 1:-1]
    return iwelbo[m - 1] + jnp.sum(inv_weights * increments)
def round_coupling(coupling, a, b):
  """Projects a coupling matrix to a nearby matrix with marginals a and b.

  A finite number of sinkhorn iterations will in general lead to a coupling P
  whose row sums differ from a or whose column sums differ from b.
  This differentiable rounding operation (algorithm 2 of Altschuler et al.)
  maps P to a matrix that does satisfy both constraints.

  Note: some implementations convert coupling, a and b to double precision
  before performing the algorithm. In case of instability, try that.

  Args:
    coupling: jnp.ndarray of shape [N, M]. Approximate coupling that results
      from for instance a Sinkhorn solver.
    a: jnp.ndarray of shape [N,]. The desired marginal of the rows of the
      coupling matrix.
    b: jnp.ndarray of shape [M,]. The desired marginal of the columns of the
      coupling matrix.

  Returns:
    r_coupling: jnp.ndarray of shape [N, M] such that
      r_coupling.sum(0) == b and r_coupling.sum(1) == a.
  """
  # Scale down (never up) every row whose sum exceeds its target: min(a/sum, 1).
  a_div_coupling = jnp.divide(a, coupling.sum(1))
  x = 1. - jax.nn.relu(1. - a_div_coupling)
  pp = x.reshape((-1, 1)) * coupling

  # Same for the columns.
  # NOTE(review): the column factors use the column sums of the original
  # coupling, not of the row-scaled `pp` -- confirm against the reference
  # algorithm.
  b_div_coupling = jnp.divide(b, coupling.sum(0))
  y = 1. - jax.nn.relu(1. - b_div_coupling)
  pp = pp * y.reshape((1, -1))

  err_a = a - pp.sum(1)
  err_b = b - pp.sum(0)

  # Spread the missing mass with a rank-one correction. When the scaled
  # coupling already matches the marginals the total error is exactly 0 and
  # the correction is 0 / 0; divide by 1 in that case so the (all-zero)
  # correction stays zero instead of turning the result into NaN.
  kron_ab = err_a[:, jnp.newaxis] * err_b[jnp.newaxis, :]
  total_err = jnp.sum(jnp.abs(err_a))
  safe_total_err = jnp.where(total_err > 0, total_err, 1.)
  r_coupling = pp + kron_ab / safe_total_err

  return r_coupling
Ejemplo n.º 30
0
def _rescale(centered_oks):
    """
    compute ΔOₖ/√Sₖₖ and √Sₖₖ
    to do scale-invariant regularization (Becca & Sorella 2017, pp. 143)
    Sₖₗ/(√Sₖₖ√Sₗₗ) = ΔOₖᴴΔOₗ/(√Sₖₖ√Sₗₗ) = (ΔOₖ/√Sₖₖ)ᴴ(ΔOₗ/√Sₗₗ)

    Returns the rescaled ``centered_oks`` together with the scale, with
    the kept leading axis of the scale squeezed away.
    """
    # |ΔO|² summed over axis 0, then reduced through mpi_sum_jax.
    squared_magnitude = (centered_oks * centered_oks.conj()).real
    local_sum = jnp.sum(squared_magnitude, axis=0, keepdims=True)
    # Only the first element of the value returned by mpi_sum_jax is used.
    scale = mpi.mpi_sum_jax(local_sum)[0]**0.5

    rescaled_oks = jnp.divide(centered_oks, scale)
    return rescaled_oks, jnp.squeeze(scale, axis=0)