def cond_1():
    sq = jnp.sqrt(1.0 + m00 - m11 - m22 + eps) * 2.  # sq = 4 * x.
    w = jnp.divide(m21 - m12, sq)
    x = 0.25 * sq
    y = jnp.divide(m01 + m10, sq)
    z = jnp.divide(m02 + m20, sq)
    return jnp.stack((x, y, z, w), axis=-1)
def decoupled_multivariate_normal_kl_divergence(
        mu_0: Array, sigma_0: Numeric, mu_1: Array, sigma_1: Numeric,
        per_dimension: bool = False) -> Tuple[Array, Array]:
    """Compute the KL between diagonal Gaussians decomposed into mean and covariance.

    Args:
        mu_0: array like of mean values for policy 0
        sigma_0: array like of std values for policy 0
        mu_1: array like of mean values for policy 1
        sigma_1: array like of std values for policy 1
        per_dimension: Whether to return a separate kl divergence for each
            dimension on the last axis.

    Returns:
        the kl divergence between the distributions decomposed into mean and
        covariance.
    """
    # Support scalar and vector `sigma`. If vector, mu.shape == sigma.shape.
    sigma_1 = jnp.ones_like(mu_1) * sigma_1
    sigma_0 = jnp.ones_like(mu_0) * sigma_0
    v1 = jnp.clip(sigma_1**2, 1e-6, 1e6)
    v0 = jnp.clip(sigma_0**2, 1e-6, 1e6)
    mu_diff = mu_1 - mu_0
    kl_mean = 0.5 * jnp.divide(mu_diff**2, v1)
    kl_cov = 0.5 * (jnp.divide(v0, v1) - jnp.ones_like(mu_1) + jnp.log(v1)
                    - jnp.log(v0))
    if not per_dimension:
        kl_mean = jnp.sum(kl_mean, axis=-1)
        kl_cov = jnp.sum(kl_cov, axis=-1)
    return kl_mean, kl_cov
def objective(params: list, bparam: list, batch_input) -> float:
    result = 25.0
    for w1 in params:
        result += np.mean(
            np.divide(np.power(w1, 4), 4.0)
            + bparam[0] * np.divide(np.power(w1, 2), 2.0))
    return result
def multivariate_normal_kl_divergence(
    mu_1: ArrayLike, mu_0: ArrayLike, sigma_1: ArrayLike, sigma_0: ArrayLike,
) -> ArrayLike:
    """Compute the KL between two Gaussian distributions with diagonal covariance matrices.

    Args:
        mu_1: array like of mean values for policy 1
        mu_0: array like of mean values for policy 0
        sigma_1: array like of std values for policy 1
        sigma_0: array like of std values for policy 0

    Returns:
        the kl divergence between the distributions.
    """
    # Support scalar and vector `sigma`. If vector, mu.shape == sigma.shape.
    sigma_1 = jnp.ones_like(mu_1) * sigma_1
    sigma_0 = jnp.ones_like(mu_0) * sigma_0
    v1 = jnp.clip(sigma_1**2, 1e-6, 1e6)
    v0 = jnp.clip(sigma_0**2, 1e-6, 1e6)
    mu_diff = mu_1 - mu_0
    return 0.5 * (jnp.sum(jnp.divide(v0, v1)) +
                  jnp.sum(jnp.divide(mu_diff**2, v1)) -
                  jnp.sum(jnp.ones_like(mu_1)) +
                  jnp.sum(jnp.log(v1)) - jnp.sum(jnp.log(v0)))
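# Hedged consistency sketch (made-up values; assumes the two KL helpers above
# are in scope): the decoupled mean and covariance terms should sum to the
# full diagonal-Gaussian KL.
import jax.numpy as jnp

mu_0 = jnp.zeros(3)
mu_1 = jnp.ones(3)
kl_mean, kl_cov = decoupled_multivariate_normal_kl_divergence(
    mu_0, 1.0, mu_1, 0.5)
full_kl = multivariate_normal_kl_divergence(mu_1, mu_0, 0.5, 1.0)
# kl_mean + kl_cov ≈ full_kl (up to float tolerance)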
def tr_positive():
    sq = jnp.sqrt(trace + 1.0) * 2.  # sq = 4 * w.
    w = 0.25 * sq
    x = jnp.divide(m21 - m12, sq)
    y = jnp.divide(m02 - m20, sq)
    z = jnp.divide(m10 - m01, sq)
    return jnp.stack((x, y, z, w), axis=-1)
def troe_falloff_correction(
    T: float, lPr: np.ndarray, troe_coeffs: np.ndarray, troe_indices: np.ndarray
) -> np.ndarray:
    """Modify rate constants using Troe falloff parameters.

    Returns:
        np.ndarray of F(T, P)
    """
    troe_coeffs = troe_coeffs[troe_indices]
    F_cent = (
        np.multiply(
            np.subtract(1, troe_coeffs[:, 0]),
            np.exp(np.divide(-T, troe_coeffs[:, 3])))
        + np.multiply(troe_coeffs[:, 0], np.exp(np.divide(-T, troe_coeffs[:, 1])))
        + np.exp(np.divide(-troe_coeffs[:, 2], T))
    )
    lF_cent = np.log10(F_cent)
    C = np.subtract(-0.4, np.multiply(0.67, lF_cent))
    N = np.subtract(0.75, np.multiply(1.27, lF_cent))
    f1_numerator = lPr + C
    f1_denominator_1 = N
    f1_denominator_2 = np.multiply(0.14, f1_numerator)
    f1 = np.divide(f1_numerator, np.subtract(f1_denominator_1, f1_denominator_2))
    # F = 10**(lF_cent / (1. + f1**2.))
    F = np.power(10.0, np.divide(lF_cent, (1.0 + np.square(f1))))
    return F
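# Hedged usage sketch for troe_falloff_correction: the indexing above implies
# a per-reaction coefficient layout [a, T*, T**, T***] in Troe's notation.
# The numbers below are made up for illustration only.
import numpy as np

T = 1000.0               # temperature
lPr = np.array([0.3])    # log10 of the reduced pressure
troe_coeffs = np.array([[0.562, 91.0, 5836.0, 8552.0]])
troe_indices = np.array([0])
F = troe_falloff_correction(T, lPr, troe_coeffs, troe_indices)  # shape (1,)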
def cond_3():
    sq = jnp.sqrt(1.0 + m22 - m00 - m11 + eps) * 2.  # sq = 4 * z.
    w = jnp.divide(m10 - m01, sq)
    x = jnp.divide(m02 + m20, sq)
    y = jnp.divide(m12 + m21, sq)
    z = 0.25 * sq
    return jnp.stack((x, y, z, w), axis=-1)
def cond_2():
    sq = jnp.sqrt(1.0 + m11 - m00 - m22 + eps) * 2.  # sq = 4 * y.
    w = jnp.divide(m02 - m20, sq)
    x = jnp.divide(m01 + m10, sq)
    y = 0.25 * sq
    z = jnp.divide(m12 + m21, sq)
    return jnp.stack((x, y, z, w), axis=-1)
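# Hedged sketch: dispatching the four branch closures above (tr_positive,
# cond_1, cond_2, cond_3) in a rotation-matrix -> quaternion conversion
# (Shepperd's method). They close over m00..m22, trace and eps, so a driver
# like this is assumed to live in the same enclosing function, shown here for
# a single 3x3 matrix; it is an illustration, not the original driver code.
where_2 = jnp.where(m11 > m22, cond_2(), cond_3())
where_1 = jnp.where((m00 > m11) & (m00 > m22), cond_1(), where_2)
quaternion = jnp.where(trace > 0, tr_positive(), where_1)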
def func(var=np.array([0.5, 1.])):
    # `arg` is a free variable supplied by the enclosing scope.
    temp = np.subtract(arg, var[1])
    temp = np.power(temp, 2)
    temp_sum = np.sum(temp)
    divider = np.power(var[0], 2)
    res = np.divide(temp_sum, divider)
    common = np.log(np.divide(1, var[0]))
    return np.dot(res, common)
def backward(self, time, spike_list, weights, e_gradient):
    gamma = spike_list[0]
    t_Tk_divby_tau_m = jnp.divide(
        jnp.subtract(time, spike_list[1]), -self.tau_m)
    f_prime_t = jnp.multiply(jnp.exp(t_Tk_divby_tau_m), (-1 / self.tau_m))
    aLIFnet = jnp.multiply(
        1 / self.Vth,
        (1 + jnp.multiply(jnp.divide(1, gamma), f_prime_t)))
    d_w = jnp.matmul(weights, e_gradient)
    return jnp.multiply(d_w, aLIFnet)
def sum_f(var):
    ret = 0
    for i in range(10**4):
        # x = 10 / 10**4 * i - 5
        x = np.subtract(np.divide(np.dot(10, i), np.power(10, 4)), 5)
        # ret -= (x - var[1])**2 / var[0]**2
        ret = np.subtract(
            ret,
            np.divide(
                np.power(np.subtract(x, var[1]), 2), np.power(var[0], 2)))
    return np.dot(np.log(np.divide(1, var[0])), ret)
def precision(
    y_true: jnp.ndarray,
    y_pred: jnp.ndarray,
    threshold: jnp.ndarray,
    class_id: jnp.ndarray,
    sample_weight: jnp.ndarray,
    true_positives: ReduceConfusionMatrix,
    false_positives: ReduceConfusionMatrix,
) -> jnp.ndarray:
    # TODO: class_id behavior
    y_pred = (y_pred > threshold).astype(jnp.float32)

    if y_true.dtype != y_pred.dtype:
        y_pred = y_pred.astype(y_true.dtype)

    true_positives = true_positives(
        y_true=y_true, y_pred=y_pred, sample_weight=sample_weight)
    false_positives = false_positives(
        y_true=y_true, y_pred=y_pred, sample_weight=sample_weight)

    return jnp.nan_to_num(
        jnp.divide(true_positives, true_positives + false_positives))
def get_reverse_rate_constants(
    kf: np.ndarray, Kc: np.ndarray, is_reversible: np.ndarray
):
    """Calculate reverse rate constants using Kc and kf."""
    return np.divide(kf, Kc) * is_reversible
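# Tiny usage sketch (made-up numbers): irreversible reactions get kr = 0.
import numpy as np

kf = np.array([1e3, 2e2])
Kc = np.array([10.0, 4.0])
is_reversible = np.array([1.0, 0.0])
kr = get_reverse_rate_constants(kf, Kc, is_reversible)  # -> [100., 0.]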
def loss_one_pair0(mu_i, mu_j, s_i, s_j, d, n_components):
    s_ij = s_i + s_j + EPSILON  # avoid division by zero
    # make sure d_ij is not zero
    d_ij = jnp.linalg.norm(mu_i - mu_j) + EPSILON
    nc = jnp.divide(d_ij**2, s_ij)
    factor = 2 * d / s_ij
    return -ncx2_log_pdf(x=d * d, df=n_components, nc=nc) - jnp.log(factor)
def get_diffusionEmbedding(points=[], distance=[], distmatrix=None, alpha=1.0,
                           tdiff=0, eps=None):
    n = len(points)
    if distmatrix is None:
        idx = jnp.array([[i, j] for i in range(n) for j in range(n)])
        d = make_distanceMatrix(points=points, idx=idx, distance=distance, n=n)
    else:
        d = distmatrix

    if eps is None:
        # using heuristic from the R package for diffusion maps
        eps = 2 * jnp.median(d)**2

    K = make_kernelMatrix(distmatrix=d, eps=eps)
    Kr = renormalize_kernel(K, alpha=alpha)
    P = make_transitionMatrix(Kr)
    u, s, v = jnp.linalg.svd(P)

    phi = u
    for i in range(len(u)):
        # `.at[...].set(...)` is functional: rebind the result.
        phi = phi.at[:, i].set((s[i]**tdiff) * jnp.divide(u[:, i], u[:, 0]))

    return phi, s
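# Hedged usage sketch (made-up data; assumes the make_kernelMatrix /
# renormalize_kernel / make_transitionMatrix helpers used above are in scope):
# embed five points from a precomputed symmetric distance matrix.
import jax.numpy as jnp

d = jnp.abs(jnp.arange(5.)[:, None] - jnp.arange(5.)[None, :])
phi, s = get_diffusionEmbedding(points=list(range(5)), distmatrix=d)
# phi[:, 1:] are the diffusion coordinates; s holds the singular values.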
def clip_eta(eta, norm, eps):
    """
    Helper function to clip the perturbation to epsilon norm ball.
    :param eta: A tensor with the current perturbation.
    :param norm: Order of the norm (mimics Numpy). Possible values: np.inf or 2.
    :param eps: Epsilon, bound of the perturbation.
    """

    # Clipping perturbation eta to the eps-ball of the requested norm
    if norm not in [np.inf, 2]:
        raise ValueError("norm must be np.inf or 2.")

    axis = list(range(1, len(eta.shape)))
    avoid_zero_div = 1e-12
    if norm == np.inf:
        eta = np.clip(eta, a_min=-eps, a_max=eps)
    elif norm == 2:
        # avoid_zero_div must go inside sqrt to avoid a divide by zero
        # in the gradient through this operation
        norm = np.sqrt(
            np.maximum(avoid_zero_div,
                       np.sum(np.square(eta), axis=axis, keepdims=True)))
        # We must *clip* to within the norm ball, not *normalize* onto the
        # surface of the ball
        factor = np.minimum(1.0, np.divide(eps, norm))
        eta = eta * factor

    return eta
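# Usage sketch (random data): project a batch of perturbations into an
# eps-ball under each supported norm.
import numpy as np

eta = np.random.randn(4, 3, 32, 32).astype(np.float32)
eta_inf = clip_eta(eta, np.inf, eps=0.03)  # elementwise clip to [-eps, eps]
eta_l2 = clip_eta(eta, 2, eps=1.0)         # per-example L2 norm capped at eps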
def compute_weight_mat(input_size: core.DimSize,
                       output_size: core.DimSize,
                       scale, translation,
                       kernel: Callable,
                       antialias: bool):
    inv_scale = 1. / scale
    # When downsampling the kernel should be scaled since we want to low pass
    # filter and interpolate, but when upsampling it should not be since we only
    # want to interpolate.
    kernel_scale = jnp.maximum(inv_scale, 1.) if antialias else 1.
    sample_f = ((jnp.arange(output_size) + 0.5) * inv_scale -
                translation * inv_scale - 0.5)
    x = (jnp.abs(sample_f[jnp.newaxis, :] -
                 jnp.arange(input_size, dtype=sample_f.dtype)[:, jnp.newaxis]) /
         kernel_scale)
    weights = kernel(x)

    total_weight_sum = jnp.sum(weights, axis=0, keepdims=True)
    weights = jnp.where(
        jnp.abs(total_weight_sum) > 1000. * np.finfo(np.float32).eps,
        jnp.divide(weights,
                   jnp.where(total_weight_sum != 0, total_weight_sum, 1)),
        0)
    # Zero out weights where the sample location is completely outside the input
    # range.
    # Note sample_f has already had the 0.5 removed, hence the weird range below.
    input_size_minus_0_5 = core.dimension_as_value(input_size) - 0.5
    return jnp.where(
        jnp.logical_and(sample_f >= -0.5,
                        sample_f <= input_size_minus_0_5)[jnp.newaxis, :],
        weights, 0)
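# Hedged usage sketch: a triangle (linear-interpolation) kernel for
# compute_weight_mat. The kernel is an assumption consistent with how such
# resize weight matrices are usually built; plain ints are valid DimSize
# arguments here.
import jax.numpy as jnp

def triangle_kernel(x):
    # support [-1, 1], peak 1 at x = 0
    return jnp.maximum(0., 1. - jnp.abs(x))

w = compute_weight_mat(input_size=4, output_size=8, scale=2.0,
                       translation=0.0, kernel=triangle_kernel,
                       antialias=True)  # shape (4, 8); columns sum to ~1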
def func(arg):
    divider = 0.  # NOTE: I made these floats
    numerator = 0.

    def body_fun(carry, x):
        divider, numerator = carry
        temp = np.dot(x, var)
        temp1 = np.sin(temp)
        temp2 = np.cos(temp)
        divid = np.add(temp1, temp2)
        divid = np.power(divid, 2)
        divid = np.sum(divid)
        numer = np.add(temp1, temp2)
        numer = np.sum(numer)
        numer = np.power(numer, 2)
        numerator = np.add(numer, numerator)
        divider = np.add(divider, divid)
        new_carry = divider, numerator
        return new_carry, ()

    (divider, numerator), _ = lax.scan(body_fun, (divider, numerator), arg)
    divider = np.power(divider, 1 / 2)
    return np.log(np.divide(numerator, divider))
def forward(self, x, v_current):
    dV_tau = jnp.multiply(jnp.subtract(x, v_current), self.dt)
    dV = jnp.divide(dV_tau, self.tau_m)
    # Functional in-place add (jax.ops.index_add is no longer available).
    v_current = v_current.at[:].add(dV)
    spike_list = jnp.greater_equal(v_current, self.Vth).astype('int32')
    # Reset neurons that spiked; decay the rest.
    v_current = jnp.where(v_current >= self.Vth, 0,
                          v_current * jnp.exp(-1 / self.tau_m))
    return spike_list, v_current
def normal_logpdf(self, x, mu, sigma):
    # this is much faster than
    #   norm.logpdf(x, loc=mu, scale=sigma)
    # https://codereview.stackexchange.com/questions/69718/fastest-computation-of-n-likelihoods-on-normal-distributions
    root2 = jnp.sqrt(2)
    root2pi = jnp.sqrt(2 * jnp.pi)
    prefactor = -jnp.log(sigma * root2pi)
    summand = -jnp.square(jnp.divide((x - mu), (root2 * sigma)))
    return prefactor + summand
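# Sanity sketch: the closed form above matches jax.scipy.stats.norm.logpdf
# (normal_logpdf is a method, so call it on its owning instance).
import jax.numpy as jnp
from jax.scipy.stats import norm

x, mu, sigma = 0.3, 0.0, 1.5
ref = norm.logpdf(x, loc=mu, scale=sigma)
closed_form = (-jnp.log(sigma * jnp.sqrt(2 * jnp.pi))
               - jnp.square((x - mu) / (jnp.sqrt(2.) * sigma)))
# closed_form ≈ ref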
def np_fn(input_np, v_current, gamma, tau_m, Vth, dt):
    # NumPy reference for one LIF step: integrate the membrane potential,
    # then threshold it and count spikes in `gamma`.
    v_current = v_current + np.multiply(
        np.divide(np.subtract(input_np, v_current), tau_m), dt)
    spike = np.greater_equal(v_current, Vth).astype('float32')
    gamma += np.where(spike > 0, 1, 0)
    return spike, v_current, gamma
def attention_op(self, query, key, value, mask=None):
    d_model = query.shape[-1]
    # scaled dot-product attention
    scores = jnp.divide(
        jnp.matmul(query, key.transpose(0, 2, 1)),
        jnp.sqrt(d_model))
    if mask is not None:
        # Mask out disallowed positions before the softmax.
        scores = jnp.where(mask, scores, -1e9)
    attention = nn.softmax(scores, axis=-1)
    attention = jnp.matmul(attention, value)
    return attention
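# Shape sketch (made-up sizes; attention_op is a method, so call it on its
# owning module). The causal mask below is a hypothetical example input.
import jax.numpy as jnp

q = k = v = jnp.ones((2, 5, 8))                  # batch 2, length 5, dim 8
causal = jnp.tril(jnp.ones((5, 5), dtype=bool))  # True = position may attend
# out = self.attention_op(q, k, v, mask=causal)  # -> shape (2, 5, 8)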
def get_K_stab(C, alpha, beta, reg):
    if isinstance(C, np.ndarray):
        # Faster exponent
        K = np.divide(C - alpha[:, None] - beta[None, :], -reg)
        return np.exp(K)
    elif isinstance(C, jnp.ndarray):  # covers jax.Array (formerly DeviceArray)
        K = jnp.divide(C - alpha[:, None] - beta[None, :], -reg)
        return jnp.exp(K)
    else:
        raise NotImplementedError(f"The type {type(C)} is not supported!")
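# Usage sketch (random data): the stabilized Gibbs kernel of log-domain
# Sinkhorn, K = exp(-(C - alpha_i - beta_j) / reg), with the dual potentials
# absorbed before exponentiation to avoid overflow.
import numpy as np

C = np.random.rand(3, 4)
alpha, beta = np.zeros(3), np.zeros(4)
K = get_K_stab(C, alpha, beta, reg=0.1)  # shape (3, 4)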
def objective(state, bparam, batch_input):
    """Computes scalar objective.

    :param state: pytree PyTreeDef(list, [PyTreeDef(tuple, [*,*])])
    :param bparam: pytree *
    :param batch_input: pytree *
    :return: pytree (scalar) *
    """
    result = 0.0
    for (w, b) in state:
        result += np.mean(
            np.sum(np.divide(np.power(w, 4), 4.0)
                   + bparam[0] * np.divide(np.power(w, 2), 2.0))
            + np.sum(np.divide(np.power(b, 4), 4.0)
                     + bparam[0] * np.divide(np.power(b, 2), 2.0)))
    return result
def amplitude(self, nu: ArrayLike) -> jnp.ndarray:
    """The amplitude of the glitch, :math:`a_\\mathrm{CZ} / \\nu^{2}`.

    Args:
        nu (:term:`array_like`): Mode frequency, :math:`\\nu`.

    Returns:
        jax.numpy.ndarray: Base of the convective zone glitch amplitude.
    """
    return jnp.divide(self._a, nu**2)
def eval_for(op):
    if op.op_name in ("IAdd", "IMul", "FAdd", "FMul", "FDiv"):
        x, y = op.args
        x_bc = broadcast_dims(op.all_idxs, x.idxs, x.atom.val)
        y_bc = broadcast_dims(op.all_idxs, y.idxs, y.atom.val)
        if op.op_name in ("IAdd", "FAdd"):
            return jnp.add(x_bc, y_bc)
        elif op.op_name in ("IMul", "FMul"):
            return jnp.multiply(x_bc, y_bc)
        elif op.op_name in ("FDiv",):
            return jnp.divide(x_bc, y_bc)
        else:
            raise Exception("Not implemented: " + str(op.op_name))
    elif op.op_name == "Iota":
        n, = op.size_args
        val = jnp.arange(n)
        val_bc = broadcast_dims(op.all_idxs, [], val)
        return val_bc
    elif op.op_name == "Id":
        x, = op.args
        x_bc = broadcast_dims(op.all_idxs, x.idxs, x.atom.val)
        return x_bc
    elif op.op_name == "Get":
        x, idx = op.args
        out_shape = [i.size for i in op.all_idxs]
        x_idxs_used = get_stack_idxs_used(op.all_idxs, x.idxs)
        leading_idx_arrays = []
        for i, idx_used in enumerate(x_idxs_used):
            if idx_used:
                leading_idx_arrays.append(nth_iota(out_shape, i))
            else:
                pass
        payload_idx_array = broadcast_dims(op.all_idxs, idx.idxs, idx.atom.val)
        out = x.atom.val[tuple(leading_idx_arrays) + (payload_idx_array,)]
        return out
    elif op.op_name == "IntToReal":
        x, = op.args
        real_val = jnp.array(x.atom.val, dtype="float32")
        x_bc = broadcast_dims(op.all_idxs, x.idxs, real_val)
        return x_bc
    elif op.op_name in ("FNeg", "INeg"):
        x, = op.args
        x_bc = broadcast_dims(op.all_idxs, x.idxs, jnp.negative(x.atom.val))
        return x_bc
    elif op.op_name == "ThreeFry2x32":
        convert_64_to_32s = lambda x: np.array([x]).view(np.uint32)
        convert_32s_to_64 = lambda x: np.int64(np.array(x).view(np.int64).item())
        x, y = op.args
        key, count = convert_64_to_32s(x.atom.val), convert_64_to_32s(y.atom.val)
        result = convert_32s_to_64(random.threefry_2x32(key, count))
        x_bc = broadcast_dims(op.all_idxs, x.idxs, result)
        return x_bc
    else:
        raise Exception("Unrecognized op: {}".format(op.op_name))
def _case3(zagf):
    z, alpha, _, flag = zagf

    # Formula 59 of [1]
    z_div_a = np.divide(z, alpha)
    aa = alpha * alpha
    term1 = (1440 * alpha + 6 * z_div_a * (53 - 120 * z)
             - 65 * z_div_a * z_div_a + 3600 * z + 107)
    term2 = 1244160 * alpha * aa
    term3 = 1 + 24 * alpha + 288 * aa
    grad = term1 * term3 / term2

    return z, alpha, grad, ~flag
def SUMO(x, rng, enc_params, dec_params, K, m=16, *args, **kwargs):
    # K = sampling_tail(rng)
    K_range = lax.iota(dtype=jnp.int32, size=K + m) + 1
    rngs = random.split(rng, K + m)
    vec_iw_estimator = vmap(iw_estimator, in_axes=(None, 0, None, None))
    log_iw = vec_iw_estimator(x, rngs, enc_params, dec_params)
    iwelbo_K = logcumsumexp(log_iw) - jnp.log(K_range)
    vec_reverse_cdf = vmap(reverse_cdf, in_axes=(0,))
    inv_weights = jnp.divide(1., vec_reverse_cdf(K_range[m:]))
    return iwelbo_K[m - 1] + jnp.sum(
        inv_weights * (iwelbo_K[m:] - iwelbo_K[m - 1:-1]))
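# `logcumsumexp` is an assumed helper; a minimal numerically stable sketch
# consistent with its use above (cumulative logsumexp of a 1-D array):
import jax.numpy as jnp

def logcumsumexp(x):
    m = jnp.max(x)
    return m + jnp.log(jnp.cumsum(jnp.exp(x - m)))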
def round_coupling(coupling, a, b):
    """Projects a coupling matrix to the nearest matrix with marginals a and b.

    A finite number of Sinkhorn iterations will always lead to a coupling P
    which does not exactly satisfy the marginal constraints P.sum(1) == a and
    P.sum(0) == b. This differentiable rounding operation, Algorithm 2 of
    Altschuler et al., maps P to the nearest matrix that does satisfy the
    constraints.

    Note: some implementations convert coupling, a and b to double precision
    before performing the algorithm. In case of instability, try that.

    Args:
        coupling: jnp.ndarray of shape [N, M]. Approximate coupling that results
            from, for instance, a Sinkhorn solver.
        a: jnp.ndarray of shape [N,]. The desired marginal of the rows of the
            coupling matrix.
        b: jnp.ndarray of shape [M,]. The desired marginal of the columns of the
            coupling matrix.

    Returns:
        r_coupling: jnp.ndarray of shape [N, M] such that
            r_coupling.sum(0) == b and r_coupling.sum(1) == a.
    """
    a_div_coupling = jnp.divide(a, coupling.sum(1))
    x = 1. - jax.nn.relu(1. - a_div_coupling)
    pp = x.reshape((-1, 1)) * coupling

    b_div_coupling = jnp.divide(b, coupling.sum(0))
    y = 1. - jax.nn.relu(1. - b_div_coupling)
    pp = pp * y.reshape((1, -1))

    err_a = a - pp.sum(1)
    err_b = b - pp.sum(0)
    kron_ab = err_a[:, jnp.newaxis] * err_b[jnp.newaxis, :]
    r_coupling = pp + kron_ab / jnp.sum(jnp.abs(err_a))

    return r_coupling
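# Usage sketch (made-up coupling): after rounding, the marginals match a and b.
import jax.numpy as jnp

p = jnp.array([[0.3, 0.1],
               [0.2, 0.4]])
a = jnp.array([0.5, 0.5])
b = jnp.array([0.4, 0.6])
r = round_coupling(p, a, b)
# jnp.allclose(r.sum(1), a) and jnp.allclose(r.sum(0), b) both hold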
def _rescale(centered_oks):
    """
    compute ΔOₖ/√Sₖₖ and √Sₖₖ
    to do scale-invariant regularization (Becca & Sorella 2017, pp. 143)
    Sₖₗ/(√Sₖₖ√Sₗₗ) = ΔOₖᴴΔOₗ/(√Sₖₖ√Sₗₗ) = (ΔOₖ/√Sₖₖ)ᴴ(ΔOₗ/√Sₗₗ)
    """
    scale = (mpi.mpi_sum_jax(
        jnp.sum((centered_oks * centered_oks.conj()).real,
                axis=0, keepdims=True))[0] ** 0.5)
    centered_oks = jnp.divide(centered_oks, scale)
    scale = jnp.squeeze(scale, axis=0)
    return centered_oks, scale