Exemplo n.º 1
0
def test_force(a_tensor, b_tensor):
    """Hand-derived gradient of the principal-axis alignment loss.

    Forward pass (same formula as ``simplified_u``): for each principal axis
    i, take the smaller of the angles between axis i of ``a_tensor`` and
    +/- axis i of ``b_tensor``, square it, and sum. The loss ``l`` is
    computed but not returned; the function instead back-propagates it by
    hand to the two input tensors.

    Args:
        a_tensor: symmetric matrix whose eigenvectors define frame A
            (assumed 3x3 -- the code hard-codes ``np.eye(3)``).
        b_tensor: symmetric matrix whose eigenvectors define frame B.

    Returns:
        (dl_datensor, dl_dbtensor): the gradients produced by ``grad_eigh``
        (defined elsewhere) for ``a_tensor`` and ``b_tensor``.
    """
    # a_eval, a_evec = np.linalg.eigh(a_tensor)
    # b_eval, b_evec = np.linalg.eigh(b_tensor)

    # Eigen-decomposition via evp.dsyevv3 (defined elsewhere) instead of
    # np.linalg.eigh -- presumably a specialised 3x3 symmetric solver.
    a_eval, a_evec = evp.dsyevv3(a_tensor)
    b_eval, b_evec = evp.dsyevv3(b_tensor)

    # print("ref w", a_eval)
    # print("test w", evp.dsyevc3(a_tensor))

    # print("ref v", a_evec)
    # print("test v", evp.dsyevv3(a_tensor))

    # assert 0

    # ---- forward pass ----
    # r[i, j] = <a_evec[:, i], b_evec[:, j]>; only its diagonal is used.
    r = np.matmul(np.transpose(a_evec), b_evec)
    I = np.eye(3)
    rI = r * I  # 3x3 -> 3x3
    pos = np.sum(rI, axis=-1)  # 3x3 -> 3
    neg = -np.sum(rI, axis=-1)  # 3x3 -> 3
    # NOTE(review): no clamp before arccos; rounding that pushes a diagonal
    # entry past +-1 would produce NaN here.
    acos_pos = np.arccos(pos)  # 3 -> 3
    acos_neg = np.arccos(neg)  # 3 -> 3
    # Eigenvectors have arbitrary sign, so keep the smaller of the two angles.
    a = np.amin([acos_pos, acos_neg], axis=0)  # 2x3 -> 3
    a2 = a * a  # 3->3
    # Scalar loss; computed for reference only -- it is not returned.
    l = np.sum(a2)  # 3->1

    # derivatives, start backprop
    dl_da2 = np.ones(3)  # 1 x 3
    da2_da = 2 * a * np.eye(3)  # 3 x 3
    # Masks selecting which branch of the min was taken. Strict '<' means that
    # on an exact tie neither branch receives any gradient.
    da_darg = np.stack(
        [np.eye(3) * (acos_pos < acos_neg),
         np.eye(3) * (acos_neg < acos_pos)])

    # d arccos(x) / dx = -1 / sqrt(1 - x^2).
    # NOTE(review): darg_dpn is never used below; the same factors are applied
    # directly when computing dpos / dneg.
    darg_dpn = np.stack([
        np.eye(3) * (-1 / np.sqrt(1 - pos * pos)),
        np.eye(3) * (-1 / np.sqrt(1 - neg * neg))
    ])

    # Shape (2, 3): row 0 is the gradient through the pos branch, row 1 through
    # the neg branch.
    dl_darg = np.matmul(np.matmul(dl_da2, da2_da), da_darg)
    dpos = dl_darg[0] * (-1 / np.sqrt(1 - pos * pos))
    dneg = dl_darg[1] * (-1 / np.sqrt(1 - neg * neg))
    # neg = -sum(rI), so the chain rule flips the sign of that branch.
    dneg = -dneg

    # dpn_dr[k] has ones in row k only, so matmul(d, dpn_dr)[k] == d[k] * ones.
    dpn_dr = np.array([
        [[1, 1, 1], [0, 0, 0], [0, 0, 0]],
        [[0, 0, 0], [1, 1, 1], [0, 0, 0]],
        [[0, 0, 0], [0, 0, 0], [1, 1, 1]],
    ])

    # element wise
    # After masking with eye(3) this is diag(dpos + dneg): only diag(r) enters
    # the loss, so the gradient w.r.t. r is diagonal.
    dr = (np.matmul(dpos, dpn_dr) + np.matmul(dneg, dpn_dr)) * np.eye(3)

    # r = a_evec^T b_evec  =>  dl/da_evec = b_evec dr^T and dl/db_evec = a_evec dr
    # (dr is diagonal, so dr.T == dr in the second line).
    dr_daevec = np.matmul(b_evec, dr.T)
    dr_dbevec = np.matmul(a_evec, dr.T)

    # Propagate eigenvector cotangents through the eigendecomposition; the
    # eigenvalue cotangents are zero because the loss does not use eigenvalues.
    # grad_eigh is defined elsewhere -- presumably the eigh VJP; confirm.
    dl_datensor = grad_eigh(a_eval, a_evec, np.zeros_like(a_eval), dr_daevec)
    dl_dbtensor = grad_eigh(b_eval, b_evec, np.zeros_like(b_eval), dr_dbevec)

    return dl_datensor, dl_dbtensor
Exemplo n.º 2
0
def pmi_u(r):
    """Sum of squared, sign-insensitive angles between rows of ``r`` and the axes.

    Row i of ``r`` is paired with the i-th standard basis vector of R^3
    (pairing stops after three rows). For each pair the angle is taken from
    ``arccos`` of the dot product and of its negation, and the smaller one is
    kept. The rows are assumed to be unit length, so no normalisation is done.
    """
    def _squared_min_angle(row, axis):
        theta_pos = np.arccos(np.sum(row * axis))
        theta_neg = np.arccos(np.sum(-row * axis))
        theta = np.amin([theta_pos, theta_neg])
        return theta * theta

    return np.sum([_squared_min_angle(row, axis)
                   for row, axis in zip(r, np.eye(3))])
Exemplo n.º 3
0
def simplified_u(a_tensor, b_tensor):
    """Alignment loss between the eigenframes of two symmetric 3x3 tensors.

    Eigenvectors of each tensor are obtained with ``np.linalg.eigh`` and
    paired by eigenvalue rank. For each pair, the angle between them (or
    between one and the negation of the other, whichever is smaller) is
    squared; the squares are summed.
    """
    _, frame_a = np.linalg.eigh(a_tensor)
    _, frame_b = np.linalg.eigh(b_tensor)
    overlap = np.matmul(np.transpose(frame_a), frame_b)
    diag_only = overlap * np.eye(3)
    cos_plus = np.sum(diag_only, axis=-1)
    cos_minus = np.sum(-diag_only, axis=-1)
    # Eigenvectors carry an arbitrary sign: per axis keep the smaller angle,
    # i.e. element-wise min over the two candidate rows.
    angles = np.amin([np.arccos(cos_plus), np.arccos(cos_minus)], axis=0)
    return np.sum(angles * angles)
Exemplo n.º 4
0
def _arccos(x, do_backprop):
    """``arccos`` with the argument forced into the valid domain [-1, 1].

    Args:
        x: value(s) whose arc-cosine is wanted.
        do_backprop: if true, values with ``|x| >= 1`` are replaced by
            ``sign(x)`` via ``np.where`` (the workaround referenced in
            https://github.com/google/jax/issues/654); otherwise ``x`` is
            simply clipped to [-1, 1].
    """
    if not do_backprop:
        return np.arccos(np.clip(x, -1, 1))
    # https://github.com/google/jax/issues/654
    saturated = np.abs(x) >= 1
    return np.arccos(np.where(saturated, np.sign(x), x))
Exemplo n.º 5
0
def grassman_distance(y1, y2):
  """Grassman distance between subspaces spanned by Y1 and Y2.

  The columns of each input are orthonormalised with QR. The singular values
  of ``Q1.T @ Q2`` are the cosines of the principal angles. They are rounded
  to 6 decimals before ``arccos`` (which also pulls values that overshoot 1
  by round-off back onto 1). The result is the 2-norm of the principal angles.
  """
  basis_1 = jnp.linalg.qr(y1)[0]
  basis_2 = jnp.linalg.qr(y2)[0]
  cosines = jnp.linalg.svd(basis_1.T @ basis_2)[1]
  principal_angles = jnp.arccos(jnp.round(cosines, decimals=6))
  return jnp.linalg.norm(principal_angles)
Exemplo n.º 6
0
def angle(x):
    """Return the angle (radians) at the middle point of three points.

    Args:
        x: 2-D array whose first three rows are points. The angle is measured
            at ``x[1]`` between ``x[0] - x[1]`` and ``x[2] - x[1]``.

    Returns:
        The angle in ``[0, pi]``. Coincident points make the denominator zero
        and give NaN.
    """
    r1 = x[0, :] - x[1, :]
    r2 = x[2, :] - x[1, :]

    costheta = np.dot(r1, r2) / np.sqrt(np.dot(r1, r1) * np.dot(r2, r2))
    # Round-off can push the cosine of (anti)collinear vectors just outside
    # [-1, 1], where arccos returns NaN; clamp it back into the domain.
    costheta = np.clip(costheta, -1.0, 1.0)
    return np.arccos(costheta)
Exemplo n.º 7
0
def compute_grassman_distance(Y1, Y2):
    """Grassman distance between subspaces spanned by Y1 and Y2.

    Both inputs are orthonormalised with QR. The singular values of
    ``Q1.T @ Q2`` (cosines of the principal angles) are rounded to 6 decimals
    and turned into angles with ``arccos``. The 2-norm of those angles is
    returned.
    """
    Q1, Q2 = (jnp.linalg.qr(Y)[0] for Y in (Y1, Y2))
    singular_values = jnp.linalg.svd(Q1.T @ Q2)[1]
    rounded = jnp.round(singular_values, decimals=6)
    return jnp.linalg.norm(jnp.arccos(rounded))
Exemplo n.º 8
0
def _von_mises_centered(key, concentration, shape, dtype):
    """Draw von Mises samples centred at 0 by rejection sampling.

    The envelope / acceptance steps appear to follow the Best & Fisher (1979)
    scheme used by the TensorFlow Probability implementation linked below --
    confirm against that reference.

    Args:
        key: JAX PRNG key; split on every round of proposals.
        concentration: concentration parameter(s). Its dtype is used for the
            uniform draws.
        shape: shape of the returned sample array.
        dtype: looked up in ``s_cutoff_map`` to pick the small-concentration
            cutoff.

    Returns:
        ``sign(u) * arccos(w)``, i.e. angles in [-pi, pi].

    NOTE(review): the loop stops after 100 rounds (see ``cond_fn``). Elements
    that were never accepted keep their last proposal -- confirm that is
    acceptable for the concentrations used.
    """
    # Cutoff from TensorFlow probability
    # (https://github.com/tensorflow/probability/blob/f051e03dd3cc847d31061803c2b31c564562a993/tensorflow_probability/python/distributions/von_mises.py#L567-L570)
    s_cutoff_map = {
        jnp.dtype(jnp.float16): 1.8e-1,
        jnp.dtype(jnp.float32): 2e-2,
        jnp.dtype(jnp.float64): 1.2e-4,
    }
    # NOTE(review): .get returns None for a dtype not listed above (likely also
    # for a scalar type such as jnp.float32 rather than a jnp.dtype instance,
    # since dict lookup hashes first). The comparison `concentration >
    # s_cutoff` below would then fail -- confirm what callers pass.
    s_cutoff = s_cutoff_map.get(dtype)

    # Exact envelope parameter s, used above the cutoff.
    r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration**2)
    rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
    s_exact = (1.0 + rho**2) / (2.0 * rho)

    # Approximation of s used for small concentrations (presumably because the
    # exact formula loses precision there).
    s_approximate = 1.0 / concentration

    s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)

    def cond_fn(*args):
        """Continue while fewer than 100 rounds ran and some sample is not yet accepted."""
        i, _, done, _, _ = args[0]
        return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))

    def body_fn(*args):
        # Carry is (iteration, key, done mask, last u, current w).
        i, key, done, _, w = args[0]
        uni_ukey, uni_vkey, key = random.split(key, 3)

        # u ~ U(-1, 1). NOTE(review): u is redrawn for every element each
        # round, including already-accepted ones, and only its sign is used at
        # the end. Since w depends on u only via cos(pi*u), which is even in u,
        # this presumably does not bias the result -- confirm.
        u = random.uniform(
            key=uni_ukey,
            shape=shape,
            dtype=concentration.dtype,
            minval=-1.0,
            maxval=1.0,
        )
        z = jnp.cos(jnp.pi * u)
        w = jnp.where(done, w,
                      (1.0 + s * z) / (s + z))  # Update where not done

        y = concentration * (s - w)
        v = random.uniform(key=uni_vkey,
                           shape=shape,
                           dtype=concentration.dtype)

        # Two acceptance tests; a proposal is accepted if either passes.
        accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)

        return i + 1, key, accept | done, u, w

    # NOTE(review): init_u / init_w use the default float dtype while body_fn
    # returns arrays of concentration.dtype. lax.while_loop requires the carry
    # types to match -- confirm these agree for the dtypes used.
    init_done = jnp.zeros(shape, dtype=bool)
    init_u = jnp.zeros(shape)
    init_w = jnp.zeros(shape)

    _, _, done, u, w = lax.while_loop(
        cond_fun=cond_fn,
        body_fun=body_fn,
        init_val=(jnp.array(0), key, init_done, init_u, init_w),
    )

    return jnp.sign(u) * jnp.arccos(w)
Exemplo n.º 9
0
def pmi_restraints(conf, params, box, lamb, a_idxs, b_idxs, masses,
                   angle_force, com_force):
    """Principal-moment-of-inertia alignment term between two atom groups.

    Builds the inertia tensors of the atoms selected by ``a_idxs`` and
    ``b_idxs`` (via ``inertia_tensor``, defined elsewhere). Then, per
    principal axis, it collects the squared smaller angle between the matching
    eigenvectors (sign-insensitive).

    NOTE(review): as visible here the function has no ``return`` (it returns
    ``None``). ``params``, ``box``, ``lamb``, ``angle_force``, ``com_force``,
    ``a_com`` and ``b_com`` are never used. This looks like a truncated copy --
    confirm against the full source.
    """

    a_com, a_tensor = inertia_tensor(conf[a_idxs], masses[a_idxs])
    b_com, b_tensor = inertia_tensor(conf[b_idxs], masses[b_idxs])

    # np.linalg.eigh returns eigenvalues in ascending order, so axis i of A is
    # paired with axis i of B by eigenvalue rank.
    a_eval, a_evec = np.linalg.eigh(a_tensor)  # already sorted
    b_eval, b_evec = np.linalg.eigh(b_tensor)  # already sorted

    # r[i, j] = <a_evec[:, i], b_evec[:, j]>
    r = np.matmul(np.transpose(a_evec), b_evec)
    I = np.eye(3)

    loss = []
    for v, e in zip(r, I):
        # Eigenvectors have arbitrary sign: compare against +e and -e and keep
        # the smaller angle. No clamp -- a value past +-1 from round-off gives NaN.
        a_pos = np.arccos(np.sum(v * e))  # norm is always 1
        a_neg = np.arccos(np.sum(-v * e))  # norm is always 1
        a = np.amin([a_pos, a_neg])
        loss.append(a * a)
Exemplo n.º 10
0
def grassman_distance(Y1, Y2):  # pylint: disable=invalid-name
  """Grassman distance between subspaces spanned by Y1 and Y2.

  Each input is orthonormalised with QR. The singular values of the overlap
  matrix are the cosines of the principal angles; they are rounded to 6
  decimals before ``arccos``. The 2-norm of the resulting angles is returned.
  """
  def _orthonormal_basis(mat):
    return jnp.linalg.qr(mat)[0]

  overlap = _orthonormal_basis(Y1).T @ _orthonormal_basis(Y2)
  cos_angles = jnp.round(jnp.linalg.svd(overlap)[1], decimals=6)
  return jnp.linalg.norm(jnp.arccos(cos_angles))
Exemplo n.º 11
0
def cartesian_to_spherical(x):
    """Convert a Cartesian vector to n-dimensional spherical coordinates.

    Args:
        x: 1-D array of Cartesian coordinates of length n >= 2.

    Returns:
        1-D array ``[r, phi_1, ..., phi_{n-1}]``, where ``r`` is the Euclidean
        norm of ``x`` and ``phi_i = arccos(x_i / ||x[i:]||)``. The last angle
        is mapped onto the full circle using the sign of ``x[-1]``.

    Note:
        If all trailing components of ``x`` are zero, a denominator is 0 and
        the corresponding angle is NaN.
    """
    r = jnp.sqrt(jnp.sum(x**2))
    # Norms of the tails x[i:] for i = 0 .. n-2 (reverse cumulative sum).
    denominators = jnp.sqrt(jnp.cumsum(x[::-1]**2)[::-1])[:-1]
    phi = jnp.arccos(x[:-1] / denominators)

    last_value = jnp.where(x[-1] >= 0, phi[-1], 2 * jnp.pi - phi[-1])
    # jax.ops.index_update was removed from JAX; `.at[...].set(...)` is the
    # supported functional update and returns a new array.
    phi = phi.at[-1].set(last_value)

    return jnp.hstack([r, phi])
Exemplo n.º 12
0
 def inv_kin(self, target=None):
     """Two-link planar inverse kinematics.

     Args:
         target: (x, y) point to reach; defaults to ``self.target``.

     Returns:
         ``np.array([q1, q2])`` joint angles for link lengths ``self.l[0]``
         and ``self.l[1]``. ``q2`` comes from the law of cosines; for an
         unreachable target its cosine leaves [-1, 1] and the angles are NaN.
     """
     goal = self.target if target is None else target
     len0, len1 = self.l[0], self.l[1]
     cos_q2 = (goal[0]**2 + goal[1]**2 - len0**2 - len1**2) / (2 * len0 * len1)
     elbow = np.arccos(cos_q2)
     shoulder = np.arctan2(goal[1], goal[0]) - np.arctan2(
         len1 * np.sin(elbow), (len0 + len1 * np.cos(elbow)))
     return np.array([shoulder, elbow])
Exemplo n.º 13
0
def jac_log(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Computes the Jacobian of the logarithmic map on the sphere.

    With ``r = <x, y>`` and ``f(r) = arccos(r) / sqrt(1 - r**2)``, the
    logarithmic map is ``f(r) * (y - r * x)``. The expression assembled below
    is its derivative with respect to the destination ``y``:
    ``f'(r) * (y - r x) x^T + f(r) * (I - x x^T)``.

    Args:
        x: Origination points on the sphere.
        y: Destination points on the sphere.

    Returns:
        d: The Jacobian of the logarithmic map, with two trailing axes of size
            ``x.shape[-1]`` after the broadcast batch axes.

    Note:
        The formula divides by ``1 - r**2`` and by ``sin(arccos(r))``, so it is
        undefined (inf/NaN) when ``x`` and ``y`` coincide or are antipodal.
    """
    # Cosine of the geodesic angle; keepdims leaves a trailing axis of size 1.
    r = (x * y).sum(axis=-1, keepdims=True)
    # v = f'(r) for f(r) = arccos(r) / sqrt(1 - r^2).
    v = r * jnp.arccos(r) / jnp.power(1 - jnp.square(r), 1.5) - 1 / (1 - jnp.square(r))
    # Outer product f'(r) * (y - r x) x^T, shape (..., n, n).
    a = v[..., jnp.newaxis] * (y - r * x)[..., jnp.newaxis] * x[..., jnp.newaxis, :]
    acr = jnp.arccos(r)
    # f(r) written as theta / sin(theta), shape (..., 1, 1).
    b = (acr / jnp.sin(acr))[..., jnp.newaxis]
    # Projection onto the tangent space at x: I - x x^T.
    c = jnp.eye(x.shape[-1])[jnp.newaxis] - x[..., jnp.newaxis] * x[..., jnp.newaxis, :]
    d = a + b * c
    return d
def table_ambisonics_order_vs_rE(max_order=20):
    """Return a dataframe with rE as a function of order.

    For orders 1..max_order, the table holds:

    * the maximum rE in 2D and 3D (``shelf.max_rE_2d`` / ``shelf.max_rE_3d``);
    * the percent change from the previous order (NaN in the first row);
    * ``2 * arccos(rE)`` converted to degrees (column 'asw').

    Note: the frame has duplicate column labels ('% change' and 'asw' each
    appear twice), so selecting one of those labels returns both columns.
    """
    orders = np.arange(1, max_order + 1, dtype=np.int32)

    def _columns(max_rE):
        rE = np.array([max_rE(n) for n in orders])
        step = np.append(np.nan, np.diff(rE))
        return rE, 100 * step / rE, 2 * np.arccos(rE) * 180 / π

    cols_3d = _columns(shelf.max_rE_3d)
    cols_2d = _columns(shelf.max_rE_2d)

    table = np.column_stack((orders, *cols_2d, *cols_3d))
    return pd.DataFrame(table,
                        columns=('order', '2D', '% change', 'asw', '3D',
                                 '% change', 'asw'))
Exemplo n.º 15
0
def harmonic_angle(conf, params, box, angle_idxs, param_idxs, cos_angles=True):
    """
    Compute the harmonic angle energy given a collection of molecules.

    For each angle with force constant k and ideal angle t0 this evaluates
    V(t) = k/2*(cos(t)-cos(t0))^2 when cos_angles is True, or
    V(t) = k/2*(t - t0)^2 otherwise, where t is the angle at the middle atom.

    Parameters:
    -----------
    conf: shape [num_atoms, 3] np.array
        atomic coordinates

    params: shape [num_params,] np.array
        unique parameters

    box: shape [3, 3] np.array
        periodic boundary vectors, if not None; forwarded to delta_r

    angle_idxs: shape [num_angles, 3] np.array
        each element (a, b, c) is a unique angle in the conformation. atom b is defined
        to be the middle atom.

    param_idxs: shape [num_angles, 2] np.array
        each element (k_idx, t_idx) maps into params for angle constants and ideal angles

    cos_angles: True (default)
        if True, use the cosine form above. This is far more numerically
        stable when the angle is pi: the arccos form is not clamped and its
        derivative diverges there.

    Returns:
    --------
    The energies summed over the last axis (i.e. over all angles).
    """
    end_a, center, end_b = (conf[angle_idxs[:, col]] for col in range(3))

    force_consts = params[param_idxs[:, 0]]
    ideal_angles = params[param_idxs[:, 1]]

    # Displacements from the middle atom, computed by delta_r (defined elsewhere).
    arm_a = delta_r(end_a, center, box)
    arm_b = delta_r(end_b, center, box)

    cos_theta = np.sum(np.multiply(arm_a, arm_b), -1) / (
        np.linalg.norm(arm_a, axis=-1) * np.linalg.norm(arm_b, axis=-1))

    # Squared deviation keeps every energy term non-negative.
    if cos_angles:
        deviation = cos_theta - np.cos(ideal_angles)
    else:
        deviation = np.arccos(cos_theta) - ideal_angles
    energies = force_consts / 2 * np.power(deviation, 2)

    return np.sum(energies, -1)  # reduce over all angles
Exemplo n.º 16
0
def vector_angle(a, b):
    """
    Find the angle between two vectors, or row-wise between two stacks of vectors.

    Parameters:
    -----------
    a, b: np.ndarray
        Either two 1-D vectors of equal length, or two 2-D arrays whose rows
        are paired up.

    Returns:
    --------
    Angle(s) in radians in [0, pi]: a scalar for 1-D input, one angle per row
    for 2-D input.

    Raises:
    -------
    ValueError
        If the inputs are not both 1-D or both 2-D.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    # The original `a.ndim == 2 & b.ndim == 2` parsed as
    # `a.ndim == (2 & b.ndim) == 2` and accepted mismatched ranks.
    if a.ndim != b.ndim or a.ndim not in (1, 2):
        raise ValueError('Input must have 1 or 2 dimensions.')
    # Normalise each vector (each row for 2-D input) by its own length and
    # take the row-wise dot product. A whole-matrix norm combined with the
    # einsum 'ij,ik->i' did not compute per-row angles.
    a_unit = a / np.linalg.norm(a, axis=-1, keepdims=True)
    b_unit = b / np.linalg.norm(b, axis=-1, keepdims=True)
    dot = np.sum(a_unit * b_unit, axis=-1)
    # Round-off can push the cosine of (anti)parallel vectors slightly outside
    # [-1, 1], where arccos returns NaN.
    return np.arccos(np.clip(dot, -1.0, 1.0))
Exemplo n.º 17
0
def get_rotation_pytree(src: Any, dst: Any) -> Any:
    """
    Takes two n-dimensional vectors/Pytree and returns an
    nxn rotation matrix mapping src to dst.

    Inside the plane spanned by the normalized inputs, the matrix rotates by
    the angle between them. On the orthogonal complement of that plane it
    acts as the identity.

    Sanity checks do not raise: each failure is reported with ``print`` and
    execution continues, so the returned matrix is not guaranteed to be a
    valid rotation.
    """
    def __assert_rotation(R):
        # Diagnostic only: failed checks print instead of raising.
        # NOTE(review): if R.ndim != 2, unpacking R.shape below raises
        # ValueError anyway.
        if R.ndim != 2:
            print("R must be a matrix")
        a, b = R.shape
        if a != b:
            print("R must be square")
        # Orthogonality check: R R^T and R^T R should both be the identity
        # (the message says "diagonal", but orthogonality is what is tested).
        # Because the reference value is 0.0, rtol has no effect; the tolerance
        # is jnp.isclose's default atol (1e-8), which float32 round-off may exceed.
        if (not jnp.isclose(
                jnp.abs(jnp.eye(a) - jnp.dot(R, R.T)).max(), 0.0, rtol=0.5)
            ) or (not jnp.isclose(
                jnp.abs(jnp.eye(a) - jnp.dot(R.T, R)).max(), 0.0, rtol=0.5)):
            print("R is not diagonal")

    if not pytree_shape_array_equal(src, dst):
        print("cjax and dst must be 1-dimensional arrays with the same shape.")

    x = pytree_normalized(src)
    y = pytree_normalized(dst)
    # NOTE(review): len() of a general pytree is not its dimension; this
    # presumably assumes dst is a 1-D array -- confirm.
    n = len(dst)

    # compute angle between x and y in their spanning space
    # NOTE(review): no clamp before arccos; a dot product that rounds past
    # +-1 yields NaN.
    theta = jnp.arccos(jnp.dot(
        x, y))  # they are normalized so there is no denominator
    # Only the parallel case is reported. Anti-parallel inputs (theta ~ pi)
    # are equally degenerate: y - (u.y)u below is then ~0.
    if jnp.isclose(theta, 0):
        print("x and y are co-linear")
    # construct the 2d rotation matrix connecting x to y in their spanning space
    R = jnp.array([[jnp.cos(theta), -jnp.sin(theta)],
                   [jnp.sin(theta), jnp.cos(theta)]])
    __assert_rotation(R)
    # get projections onto Span<x,y> and its orthogonal complement
    # u, v form an orthonormal basis of Span<x, y> (Gram-Schmidt on y).
    u = x
    v = pytree_normalized(pytree_sub(y, (jnp.dot(u, y) * u)))
    P = jnp.outer(u, u.T) + jnp.outer(
        v, v.T)  # projection onto 2d space spanned by x and y
    Q = jnp.eye(
        n) - P  # projection onto the orthogonal complement of Span<x,y>
    # lift the rotation matrix into the n-dimensional space
    uv = jnp.hstack((u[:, None], v[:, None]))

    R = Q + jnp.dot(uv, jnp.dot(R, uv.T))
    __assert_rotation(R)
    # Final check that R actually maps x onto y (loose rtol=0.25).
    if jnp.any(jnp.logical_not(jnp.isclose(jnp.dot(R, x), y, rtol=0.25))):
        print("Rotation matrix did not work")
    return R
Exemplo n.º 18
0
def logarithmic(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Computes the logarithmic map on the sphere.

    The result points from ``x`` towards ``y`` along the tangent direction
    ``y - <x, y> x``, scaled by ``theta / sin(theta)`` with
    ``theta = arccos(<x, y>)``. It is undefined (0/0) when ``x == y``.

    Args:
        x: Origination points on the sphere.
        y: Destination points on the sphere.

    Returns:
        lg: The tangent vector in the tangent space of the origination point that
            would produce the destination under the exponential map.
    """
    cos_angle = jnp.sum(x * y, axis=-1, keepdims=True)
    theta = jnp.arccos(cos_angle)
    scale = theta / jnp.sin(theta)
    tangent = y - cos_angle * x
    return scale * tangent
Exemplo n.º 19
0
def _compute_angle(p0: jnp.ndarray, p1: jnp.ndarray,
                   p2: jnp.ndarray) -> jnp.float32:
  """Compute the angle centered at `p1` between the other two points.

  Uses the law of cosines on the three pairwise distances.

  Args:
    p0: first end point.
    p1: vertex of the angle.
    p2: second end point.

  Returns:
    The angle in degrees, in [0, 180]. If an end point coincides with the
    vertex, the denominator is zero and the result is NaN.
  """
  da = np.linalg.norm(p1 - p0)
  db = np.linalg.norm(p1 - p2)
  dc = np.linalg.norm(p0 - p2)

  cos_angle = (da**2 + db**2 - dc**2) / (2.0 * da * db)
  # Round-off in the squared distances can push the cosine of a (nearly)
  # straight or folded angle outside [-1, 1], where arccos returns NaN.
  cos_angle = jnp.clip(cos_angle, -1.0, 1.0)

  # arccos already lies in [0, 180] degrees, so the former
  # `min(angle, 360 - angle)` always returned `angle`. Dropping it also avoids
  # Python control flow on array values, which fails under jit tracing.
  return jnp.arccos(cos_angle) * 180.0 / jnp.pi
Exemplo n.º 20
0
def arccos(x):
  """Element-wise inverse cosine, returned wrapped in a ``JaxArray``.

  A ``JaxArray`` input is unwrapped via its ``.value`` before calling
  ``jnp.arccos``; any other input is passed through unchanged.
  """
  raw = x.value if isinstance(x, JaxArray) else x
  return JaxArray(jnp.arccos(raw))
Exemplo n.º 21
0
def arccos(a: Numeric):
    """Element-wise inverse cosine of ``a``, computed with ``jnp.arccos``."""
    result = jnp.arccos(a)
    return result
Exemplo n.º 22
0
def harmonic_angle(conf,
                   params,
                   box,
                   lamb,
                   angle_idxs,
                   lamb_mult=None,
                   lamb_offset=None,
                   cos_angles=True):
    """
    Compute the harmonic angle energy given a collection of molecules.

    For each angle with force constant k and ideal angle t0 this evaluates
        V(t) = k/2*(t - t0)^2
            if cos_angles=False
        or
        V(t) = k/2*(cos(t)-cos(t0))^2
            if cos_angles=True
    multiplied by the prefactor described below.

    Parameters:
    -----------
    conf: shape [num_atoms, 3] np.array
        atomic coordinates

    params: shape [num_angles, 2] np.array
        column 0: force constants k; column 1: ideal angles t0 (radians)

    box: shape [3, 3] np.array
        periodic boundary vectors; not used by this implementation

    lamb: float or None
        alchemical lambda parameter

    angle_idxs: shape [num_angles, 3] np.array
        each element (a, b, c) is a unique angle in the conformation. atom b is defined
        to be the middle atom.

    lamb_mult: None, or broadcastable to angle_idxs.shape[0]
    lamb_offset: None, or broadcastable to angle_idxs.shape[0]
        when lamb, lamb_mult and lamb_offset are all given, the energies are
        scaled by prefactor = (lamb_offset + lamb_mult * lamb); otherwise
        prefactor = 1.0, and lamb_mult / lamb_offset must both be None
        (AssertionError otherwise).

    cos_angles: True (default)
        if True, use the cosine form. This is far more numerically stable
        when the angle is pi (the arccos form is not clamped).

    Returns:
    --------
    The energies summed over the last axis (i.e. over all angles).
    """
    scaled = not (lamb_mult is None or lamb_offset is None or lamb is None)
    if scaled:
        prefactor = lamb_offset + lamb_mult * lamb
    else:
        assert lamb_mult is None
        assert lamb_offset is None
        prefactor = 1.0

    end_a, center, end_b = (conf[angle_idxs[:, col]] for col in range(3))
    force_consts, ideal_angles = params[:, 0], params[:, 1]

    arm_a = end_a - center
    arm_b = end_b - center
    cos_theta = np.sum(np.multiply(arm_a, arm_b), -1) / (
        np.linalg.norm(arm_a, axis=-1) * np.linalg.norm(arm_b, axis=-1))

    # Squared deviation keeps every energy term non-negative.
    if cos_angles:
        deviation = cos_theta - np.cos(ideal_angles)
    else:
        deviation = np.arccos(cos_theta) - ideal_angles
    energies = prefactor * force_consts / 2 * np.power(deviation, 2)

    return np.sum(energies, -1)  # reduce over all angles
Exemplo n.º 23
0
 def angle(P):
   """Return the angle (radians) whose cosine is ``cosAngle(P)``."""
   cos_value = cosAngle(P)
   return np.arccos(cos_value)
Exemplo n.º 24
0
def _asinc_naive(x):
    """Naive ``arccos(x) / sqrt(1 - x**2)`` with no special-casing.

    At ``x = 1`` this evaluates 0/0 (NaN) even though the limit is 1, and at
    ``x = -1`` it divides by zero.
    """
    denominator = jnp.sqrt(1 - x**2)
    return jnp.arccos(x) / denominator
Exemplo n.º 25
0
 def get_index(x):
     """Map points to grid indices through an arccos transform.

     Uses ``grid.lower``, ``grid.size`` and ``grid.shape`` from the enclosing
     scope. NaN indices (from arguments outside arccos' domain) are replaced
     by ``grid.shape``. The per-axis indices are returned as a tuple in
     reversed axis order.
     """
     mapped = 2 * (grid.lower - x.flatten()) / grid.size + 1
     idx = np.nan_to_num((grid.shape * np.arccos(mapped)) // np.pi,
                         nan=grid.shape)
     return tuple(np.flip(np.uint32(idx)))
Exemplo n.º 26
0
def calc_angles(xyz: np.ndarray, struct_descr_dict):
    """Return the angle (radians) at the middle atom of each listed triple.

    ``struct_descr_dict['angles']`` holds rows ``(i, j, k)``. The angle is
    taken between ``xyz[j] - xyz[i]`` and ``xyz[j] - xyz[k]``. The cosine is
    clipped to [-1, 1] before ``arccos`` so round-off cannot produce NaN.
    """
    triples = struct_descr_dict['angles']
    centers = xyz[triples[:, 1]]
    r_ij = centers - xyz[triples[:, 0]]
    r_kj = centers - xyz[triples[:, 2]]
    norms = np.linalg.norm(r_ij, axis=1) * np.linalg.norm(r_kj, axis=1)
    cos = np.sum(r_ij * r_kj, axis=1) / norms
    return np.arccos(np.clip(cos, -1, 1))
Exemplo n.º 27
0
def safe_acos(t, eps=1e-8):
    """Arc-cosine of ``t`` after clamping it into ``[-1 + eps, 1 - eps]``.

    Keeping the argument strictly inside the domain avoids the infinite
    derivative of arccos at +-1. Note: for float32 inputs the default
    ``eps=1e-8`` is below the float32 spacing near 1, so the bounds round to
    exactly -1 and 1; pass a larger ``eps`` if that matters.
    """
    lower, upper = -1.0 + eps, 1.0 - eps
    return jnp.arccos(jnp.clip(t, lower, upper))
Exemplo n.º 28
0
def _top_k(input, k=1, sorted=True, name=None):  # pylint: disable=unused-argument,redefined-builtin
    """Placeholder for ``top_k``; not implemented in this backend.

    Always raises ``NotImplementedError``, whatever the arguments.
    """
    raise NotImplementedError()


# --- Begin Public Functions --------------------------------------------------

# Each public op pairs a TensorFlow function with a NumPy implementation via
# utils.copy_docstring (defined elsewhere; presumably it attaches the TF
# docstring to the lambda). The `name` argument exists only for TF API
# compatibility and is ignored by every implementation below.
abs = utils.copy_docstring(  # pylint: disable=redefined-builtin
    tf.math.abs,
    lambda x, name=None: np.abs(x))

# Sums the inputs (Python sum, starting from 0) and casts the result to the
# NumPy equivalent of `tensor_dtype`; the `shape` argument is ignored.
accumulate_n = utils.copy_docstring(
    tf.math.accumulate_n,
    lambda inputs, shape=None, tensor_dtype=None, name=None: (  # pylint: disable=g-long-lambda
        sum(map(np.array, inputs)).astype(utils.numpy_dtype(tensor_dtype))))

acos = utils.copy_docstring(tf.math.acos, lambda x, name=None: np.arccos(x))

acosh = utils.copy_docstring(tf.math.acosh, lambda x, name=None: np.arccosh(x))

add = utils.copy_docstring(tf.math.add, lambda x, y, name=None: np.add(x, y))

# Element-wise sum of a list of arrays (no dtype cast, unlike accumulate_n).
add_n = utils.copy_docstring(
    tf.math.add_n, lambda inputs, name=None: sum(map(np.array, inputs)))

angle = utils.copy_docstring(tf.math.angle,
                             lambda input, name=None: np.angle(input))

argmax = utils.copy_docstring(
    tf.math.argmax,
    lambda input, axis=None, output_type=tf.int64, name=None: (  # pylint: disable=g-long-lambda
        np.argmax(input, axis=0 if axis is None else _astuple(axis)).astype(
Exemplo n.º 29
0
def _angle(pos: jnp.ndarray, indices: jnp.ndarray,
           tvecs: jnp.ndarray) -> float:
    """Angle (radians) at atom ``indices[1]`` for the triple in ``indices``.

    ``tvecs[0]`` and ``tvecs[1]`` are added to the two bond displacements
    (presumably periodic-image translations -- confirm with callers). No clamp
    is applied before ``arccos``.
    """
    first, middle, last = indices[0], indices[1], indices[2]
    bond_in = -(pos[middle] - pos[first] + tvecs[0])
    bond_out = pos[last] - pos[middle] + tvecs[1]
    lengths = jnp.linalg.norm(bond_in) * jnp.linalg.norm(bond_out)
    return jnp.arccos(bond_in @ bond_out / lengths)
Exemplo n.º 30
0
    def test_sorted_piecewise_constant_pdf_train_mode(self):
        """Test that piecewise-constant sampling reproduces its distribution.

        Builds ``batch_size`` random piecewise-constant PDFs plus one with
        all-zero weights. For each, it draws ``num_samples`` samples with
        ``math.sorted_piecewise_constant_pdf`` in randomized and
        non-randomized mode. It checks that the samples are sorted and that
        their histogram matches the target, both by angle (<= 0.5 degrees)
        and by Jensen-Shannon divergence (<= 1e-5).
        """
        batch_size = 4
        num_bins = 16
        num_samples = 1000000
        # Bin widths are integer multiples of 1 / precision (see below).
        precision = 1e5
        rng = random.PRNGKey(20202020)

        # Generate a series of random PDFs to sample from.
        data = []
        for _ in range(batch_size):
            rng, key = random.split(rng)
            # Randomly initialize the distances between bins.
            # We're rolling our own fixed precision here to make cumsum exact.
            bins_delta = jnp.round(precision * jnp.exp(
                random.uniform(
                    key, shape=(num_bins + 1, ), minval=-3, maxval=3)))

            # Set some of the bin distances to 0.
            # (each width is kept with probability 0.9 and zeroed otherwise)
            rng, key = random.split(rng)
            bins_delta *= random.uniform(key, shape=bins_delta.shape) < 0.9

            # Integrate the bins.
            bins = jnp.cumsum(bins_delta) / precision
            # Shift all edges by one random offset.
            rng, key = random.split(rng)
            bins += random.normal(key) * num_bins / 2
            rng, key = random.split(rng)

            # Randomly generate weights, allowing some to be zero.
            # (uniform on [-0.5, 1) clipped at 0, so roughly a third are zero)
            weights = jnp.maximum(
                0,
                random.uniform(key, shape=(num_bins, ), minval=-0.5,
                               maxval=1.))
            gt_hist = weights / weights.sum()
            data.append((bins, weights, gt_hist))

        # Tack on an "all zeros" weight matrix, which is a common cause of NaNs.
        # It reuses the bins from the last loop iteration, and the expected
        # histogram for it is uniform over the bins.
        weights = jnp.zeros_like(weights)
        gt_hist = jnp.ones_like(gt_hist) / num_bins
        data.append((bins, weights, gt_hist))

        # Stack into batched arrays with one row per PDF.
        bins, weights, gt_hist = [jnp.stack(x) for x in zip(*data)]

        for randomized in [True, False]:
            rng, key = random.split(rng)
            # Draw samples from the batch of PDFs.
            samples = math.sorted_piecewise_constant_pdf(
                key,
                bins,
                weights,
                num_samples,
                randomized,
            )
            self.assertEqual(samples.shape[-1], num_samples)

            # Check that samples are sorted.
            self.assertTrue(jnp.all(samples[..., 1:] >= samples[..., :-1]))

            # Verify that each set of samples resembles the target distribution.
            for i_samples, i_bins, i_gt_hist in zip(samples, bins, gt_hist):
                i_hist = jnp.float32(jnp.histogram(i_samples,
                                                   i_bins)[0]) / num_samples
                i_gt_hist = jnp.array(i_gt_hist)

                # Merge any of the zero-span bins until there aren't any left.
                # Each pass folds zero-width bin j into bin j + 1 and drops the
                # duplicate edge. NOTE(review): if the zero-width bin is the
                # last one, j + 1 indexes past the end of i_hist; JAX clamps
                # out-of-range indices instead of raising, so bin j would be
                # added to itself -- confirm this case is intended.
                while jnp.any(i_bins[:-1] == i_bins[1:]):
                    j = int(jnp.where(i_bins[:-1] == i_bins[1:])[0][0])
                    i_hist = jnp.concatenate([
                        i_hist[:j],
                        jnp.array([i_hist[j] + i_hist[j + 1]]), i_hist[j + 2:]
                    ])
                    i_gt_hist = jnp.concatenate([
                        i_gt_hist[:j],
                        jnp.array([i_gt_hist[j] + i_gt_hist[j + 1]]),
                        i_gt_hist[j + 2:]
                    ])
                    i_bins = jnp.concatenate([i_bins[:j], i_bins[j + 1:]])

                # Angle between the two histograms in degrees.
                # (arccos of their cosine similarity; jnp.minimum caps the
                # argument at 1 so round-off cannot leave arccos' domain)
                angle = 180 / jnp.pi * jnp.arccos(
                    jnp.minimum(
                        1.,
                        jnp.mean((i_hist * i_gt_hist) / jnp.sqrt(
                            jnp.mean(i_hist**2) * jnp.mean(i_gt_hist**2)))))
                # Jensen-Shannon divergence.
                m = (i_hist + i_gt_hist) / 2
                js_div = jnp.sum(
                    sp.special.kl_div(i_hist, m) +
                    sp.special.kl_div(i_gt_hist, m)) / 2
                self.assertLessEqual(angle, 0.5)
                self.assertLessEqual(js_div, 1e-5)