def test_origin(self):
        """Check jt.searchsorted against hand-computed insertion indices.

        Covers a batched (2-D) sorted sequence in both ``right`` modes, and a
        1-D sorted sequence shared by every row of 2-D query values.
        """
        # Named ``seq`` rather than ``sorted`` to avoid shadowing the builtin.
        seq = jt.array([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]])
        values = jt.array([[3, 6, 9], [3, 6, 9]])

        # Left insertion points: a query equal to an element (3 and 9 in
        # row 1, 6 in row 2) is placed before that element.
        ret = jt.searchsorted(seq, values)
        assert (ret == [[1, 3, 4], [1, 2, 4]]).all(), ret

        # Right insertion points: ties are placed after the equal element.
        ret = jt.searchsorted(seq, values, right=True)
        assert (ret == [[2, 3, 5], [1, 3, 4]]).all(), ret

        # A 1-D sorted sequence is searched independently for every row.
        seq_1d = jt.array([1, 3, 5, 7, 9])
        ret = jt.searchsorted(seq_1d, values)
        assert (ret == [[1, 3, 4], [1, 3, 4]]).all(), ret
    def test_searchsorted_cpu(self):
        """Compare jt.searchsorted with torch.searchsorted on random data.

        Runs 1-D and 2-D (batched) inputs in both ``right`` modes. The query
        values include copies of the sorted values so exact ties occur and
        the left/right variants actually produce different indices.
        """
        # Local, fixed-seed generator: failures reproduce, global state untouched.
        rng = np.random.RandomState(0)
        for ndim in range(1, 3):
            shape = (10, ) * ndim
            s = np.sort(rng.rand(*shape), -1)
            # Leading dims match ``s``; the last axis is 20 = 10 random + 10 ties.
            v = np.concatenate([rng.rand(*shape), s], -1)
            s_jt = jt.array(s)
            v_jt = jt.array(v)
            s_tc = torch.from_numpy(s)
            v_tc = torch.from_numpy(v)

            for right in (True, False):
                y_jt = jt.searchsorted(s_jt, v_jt, right=right)
                y_tc = torch.searchsorted(s_tc, v_tc, right=right)
                # Indices are integers: require exact equality, not allclose.
                assert np.array_equal(y_jt.numpy(), y_tc.numpy()), (ndim, right)
# Example #3
# 0
def sample_pdf(bins, weights, N_samples, det=False):
    """Draw samples from the piecewise-constant PDF defined by ``weights``.

    Uses inverse-transform sampling:
    - normalise ``weights`` into a PDF;
    - build its CDF with a leading 0;
    - map each uniform value ``u`` back through the CDF, linearly
      interpolating between the two ``bins`` entries whose CDF values
      bracket it.

    Args:
        bins: bin edges with shape ``(..., n_bins)``. The last dimension must
            be one larger than that of ``weights``, because the CDF gains a
            leading 0. Leading dimensions are expected to match ``weights``.
        weights: per-bin weights with shape ``(..., n_bins - 1)``. Assumed
            non-negative (TODO confirm with callers). ``1e-5`` is added so an
            all-zero row does not produce NaNs.
        N_samples: number of samples drawn per row.
        det: if True, use ``N_samples`` evenly spaced ``u`` in [0, 1] instead
            of uniform random draws.

    Returns:
        Samples with shape ``(..., N_samples)``.
    """
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / jt.sum(weights, -1, keepdims=True)
    cdf = jt.cumsum(pdf, -1)
    # Prepend 0 so cdf[..., i] is the probability mass left of bins[..., i].
    cdf = jt.concat([jt.zeros_like(cdf[..., :1]), cdf],
                    -1)  # (..., len(bins))

    # Take uniform samples
    sample_shape = list(cdf.shape[:-1]) + [N_samples]
    if det:
        # Count passed positionally so the call does not depend on the
        # keyword name (``steps`` in PyTorch; Jittor's may differ).
        u = jt.linspace(0., 1., N_samples)
        u = u.expand(sample_shape)
    else:
        u = jt.random(sample_shape)

    # Invert CDF: locate the CDF interval [below, above] that contains each u,
    # clamped to valid indices at both ends.
    inds = jt.searchsorted(cdf, u, right=True)
    prev = inds - 1
    below = jt.maximum(jt.zeros_like(prev), prev)
    above = jt.minimum((cdf.shape[-1] - 1) * jt.ones_like(inds), inds)
    inds_g = jt.stack([below, above], -1)  # (..., N_samples, 2)

    # Broadcast cdf/bins to (..., N_samples, len(bins)) so both bracketing
    # entries can be gathered along the last axis. The target shape comes
    # from inds_g's leading dims rather than a hard-coded (batch, N_samples),
    # so inputs with any number of leading dimensions (including none) work.
    # For 2-D inputs this is exactly unsqueeze(1) / gather on dim 2.
    matched_shape = list(inds_g.shape[:-1]) + [cdf.shape[-1]]
    gather_dim = len(matched_shape) - 1
    cdf_g = jt.gather(
        cdf.unsqueeze(len(cdf.shape) - 1).expand(matched_shape),
        gather_dim, inds_g)
    bins_g = jt.gather(
        bins.unsqueeze(len(bins.shape) - 1).expand(matched_shape),
        gather_dim, inds_g)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    # Near-zero-width intervals (tiny-weight bins, or below == above after
    # clamping) would blow up t; use 1 so t stays finite.
    denom[denom < 1e-5] = 1.0
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples