Beispiel #1
0
def f_rsh(rho, omega, polarized=False, use_jax=True):
    """Enhancement factor for evaluating short-range semilocal exchange.

    10.1063/1.4952647 Eq. 11.

    Args:
      rho: Float numpy array with shape (num_grids,), the electron density.
      omega: Float, the range separation parameter.
      polarized: Boolean, whether the system is spin polarized.
      use_jax: Boolean, if True, use jax.numpy for calculations, otherwise use
        numpy.

    Returns:
      Float numpy array with shape (num_grids,), the RSH enhancement factor.
    """
    # Pick the array backend once; everything below goes through `np` and
    # `special`, so the use_jax=False path never touches jax.numpy.
    if use_jax:
        np = jnp
        special = jax.scipy.special
    else:
        np = onp
        special = scipy.special
    spin_factor = 1 if polarized else 2
    # Fermi wave vector. utils.EPSILON is presumably a small positive guard
    # so the cube root is not taken at exactly zero density — TODO confirm.
    kf = (6 * np.pi**2 * rho / spin_factor + utils.EPSILON)**(1 / 3)
    # Variable a in Eq. 11; the EPSILON shift keeps 1 / a finite when omega == 0
    # (assuming EPSILON > 0).
    a = omega / kf + utils.EPSILON
    return (1 - 2 / 3 * a *
            (2 * np.pi**(1 / 2) * special.erf(1 / a) - 3 * a + a**3 +
             (2 * a - a**3) * np.exp(-1 / a**2)))
Beispiel #2
0
 def __init__(self,
              link='probit'):
     """Configure the link function (and its derivative) for this Bernoulli model.

     Args:
       link: 'logit' for the logistic sigmoid, or 'probit' for a normal CDF
         squashed into [jitter, 1 - jitter].

     Raises:
       NotImplementedError: if ``link`` is neither 'logit' nor 'probit'.
     """
     super().__init__()
     if link == 'logit':
         self.link_fn = lambda f: 1 / (1 + np.exp(-f))
         # sigmoid'(f) = exp(f) / (1 + exp(f))**2 is even in f, so evaluate it
         # at -|f|: exp() then never overflows (the direct form gives
         # inf / inf = nan for large positive f).
         self.dlink_fn = lambda f: np.exp(-np.abs(f)) / (1 + np.exp(-np.abs(f))) ** 2
     elif link == 'probit':
         jitter = 1e-3
         # Scale the normal CDF into [jitter, 1 - jitter] so it never returns
         # exactly 0 or 1.
         self.link_fn = lambda f: 0.5 * (1.0 + erf(f / np.sqrt(2.0))) * (1 - 2 * jitter) + jitter
         # Closed-form derivative of link_fn: (1 - 2*jitter) * N(f; 0, 1).
         # Same values as autodiff at a single point, but also works for
         # vectors of f; the result is returned as a column (n, 1).
         self.dlink_fn = lambda f: (
             (1 - 2 * jitter) * np.exp(-0.5 * np.squeeze(f) ** 2)
             / np.sqrt(2.0 * np.pi)
         ).reshape(-1, 1)
     else:
         raise NotImplementedError('link function not implemented')
     self.link = link
     self.name = 'Bernoulli'
Beispiel #3
0
def _projected_normal_log_prob_3(concentration, value):
    def _dot(x, y):
        return (x[..., None, :] @ y[..., None])[..., 0, 0]

    # We integrate along a ray, factorizing the integrand as a product of:
    # a truncated normal distribution over coordinate t parallel to the ray, and
    # a bivariate normal distribution over coordinate r perpendicular to the ray.
    t = _dot(concentration, value)
    t2 = t * t
    r2 = _dot(concentration, concentration) - t2
    perp_part = (-0.5) * r2 - math.log(2 * math.pi)

    # This is the log of a definite integral, computed by mathematica:
    # Integrate[x^2/(E^((x-t)^2/2) Sqrt[2 Pi]), {x, 0, Infinity}]
    # = t/(E^(t^2/2) Sqrt[2 Pi]) + ((1 + t^2) (1 + Erf[t/Sqrt[2]]))/2
    para_part = jnp.log(t * jnp.exp((-0.5) * t2) / (2 * math.pi)**0.5 +
                        (1 + t2) * (1 + erf(t * 0.5**0.5)) / 2)

    return para_part + perp_part
Beispiel #4
0
def ewald_energy(conf, box, charges, scale_matrix, cutoff, alpha, kmax):
    """Ewald energy: direct + reciprocal - scaled-pair offset - self energy.

    Args:
        conf: Particle configuration passed to the pairwise/reciprocal/self
            helpers.
        box: Periodic box passed to the pairwise and reciprocal helpers.
        charges: Per-particle charges.
        scale_matrix: Pairwise scale factors; ``1 - scale_matrix`` selects the
            contribution removed in step 2.
        cutoff: Direct-space cutoff; must not be None.
        alpha: Ewald splitting parameter.
        kmax: Reciprocal-space cutoff passed to ``reciprocal_energy``.

    Returns:
        The combined energy ``direct + recip - offset - self``.
    """
    # Validate before doing any pairwise work (previously this ran only after
    # pairwise_energy had already been called with the same cutoff).
    assert cutoff is not None

    eij = pairwise_energy(conf, box, charges, cutoff)

    # 1. Assume scale matrix is not used at all (no exceptions, no exclusions)
    # 1a. Direct Space
    # NOTE(review): erfc/erf below are applied to alpha * (pairwise energy);
    # textbook Ewald uses alpha * r_ij (distances). Confirm what
    # pairwise_energy returns.
    eij_direct = eij * erfc(alpha*eij)
    # Each pair appears twice in the full matrix, hence the / 2.
    eij_direct = ONE_4PI_EPS0*np.sum(eij_direct)/2

    # 1b. Reciprocal Space
    eij_recip = reciprocal_energy(conf, box, charges, alpha, kmax)

    # 2. Remove over estimated scale matrix contribution scaled by erf
    # NOTE(review): unlike the direct term, erf here takes the already-scaled
    # offset rather than eij — confirm this is intended.
    eij_offset = (1-scale_matrix) * eij
    eij_offset *= erf(alpha*eij_offset)
    eij_offset = ONE_4PI_EPS0*np.sum(eij_offset)/2

    return eij_direct + eij_recip - eij_offset - self_energy(conf, charges, alpha)
Beispiel #5
0
    def erf(self, tensor_in):
        """
        Evaluate the error function element-wise on a tensor.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> a = pyhf.tensorlib.astensor([-2., -1., 0., 1., 2.])
            >>> pyhf.tensorlib.erf(a)
            DeviceArray([-0.99532227, -0.84270079,  0.        ,  0.84270079,
                          0.99532227], dtype=float64)

        Args:
            tensor_in (:obj:`tensor`): Points at which to evaluate erf.

        Returns:
            JAX ndarray: erf evaluated at each entry of ``tensor_in``.
        """
        values = special.erf(tensor_in)
        return values
Beispiel #6
0
def _erf(x, **kwargs):
    return erf(x)
Beispiel #7
0
def erf(a: Numeric):
    """Element-wise error function, delegated to ``jsps.erf``."""
    result = jsps.erf(a)
    return result
Beispiel #8
0
 def fn(x):
     """Return ``a * erf(b * x) + c`` using ``a``, ``b``, ``c`` from the enclosing scope."""
     scaled = a * erf(b * x)
     return scaled + c
Beispiel #9
0
def gelu(x):
    """Gaussian Error Linear Unit: ``x * Phi(x)`` with Phi the standard normal CDF.

    Phi(x) = 0.5 * (1 + erf(x / sqrt(2))). Note that ``x * erf(x)`` is *not*
    GELU: it is even in x, so it maps negative inputs to positive outputs.

    Args:
        x: Scalar or array accepted by ``erf``.

    Returns:
        The GELU activation of ``x``, same shape as ``x``.
    """
    return 0.5 * x * (1.0 + erf(x / 2.0 ** 0.5))
Beispiel #10
0
def phi(dtmp, sigma2tmp):
    """Cumulative distribution: P(X <= dtmp) for X ~ N(0, sigma2tmp), element-wise.

    Args:
        dtmp: Evaluation point(s).
        sigma2tmp: Variance(s), broadcast against ``dtmp``.

    Returns:
        0.5 * (1 + erf(dtmp / sqrt(2 * sigma2tmp))), shaped like the broadcast
        inputs.
    """
    # The scalar 1.0 broadcasts against the erf term. That gives the same
    # values the old ones((clen,)) vector did for length-clen inputs, without
    # depending on a module-level `clen`.
    return 0.5 * (1.0 + jscisp.erf(jnp.divide(dtmp, jnp.sqrt(2.0 * sigma2tmp))))