def test_origin(self):
    """Check jt.searchsorted against hand-computed insertion indices.

    Covers a batched (2-D) sorted input searched with both ``right=False``
    (the default) and ``right=True``, and a 1-D sorted input whose single
    row is searched for every row of a 2-D ``values`` array.
    """
    # Named ``sorted_seq`` rather than ``sorted`` so the builtin is not shadowed.
    sorted_seq = jt.array([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]])
    values = jt.array([[3, 6, 9], [3, 6, 9]])

    # Left insertion points: first index i with sorted_seq[i] >= value
    # (e.g. 6 in [2, 4, 6, 8, 10] -> 2).
    ret = jt.searchsorted(sorted_seq, values)
    assert (ret == [[1, 3, 4], [1, 2, 4]]).all(), ret

    # Right insertion points: first index i with sorted_seq[i] > value
    # (e.g. 6 in [2, 4, 6, 8, 10] -> 3).
    ret = jt.searchsorted(sorted_seq, values, right=True)
    assert (ret == [[2, 3, 5], [1, 3, 4]]).all(), ret

    # A 1-D sorted sequence is searched for each row of the 2-D ``values``.
    sorted_1d = jt.array([1, 3, 5, 7, 9])
    ret = jt.searchsorted(sorted_1d, values)
    assert (ret == [[1, 3, 4], [1, 3, 4]]).all(), ret
def test_searchsorted_cpu(self):
    """Cross-check jt.searchsorted against torch.searchsorted on random data.

    Runs 1-D and 2-D inputs with side length 10. For each input it checks both
    ``right=True`` and ``right=False``, and requires the returned insertion
    indices to match torch's exactly.
    """
    # Fixed seed so any mismatch is reproducible.
    rng = np.random.default_rng(0)
    for ndim in range(1, 3):
        shape = (10,) * ndim
        # NOTE(review): jittor may convert float64 arrays to float32 by
        # default while torch keeps float64; drawing float32 data up front
        # guarantees both libraries search exactly the same values.
        s = np.sort(rng.random(shape, dtype=np.float32), -1)
        v = rng.random(shape, dtype=np.float32)
        s_jt, v_jt = jt.array(s), jt.array(v)
        s_tc, v_tc = torch.from_numpy(s), torch.from_numpy(v)
        for right in (True, False):
            y_jt = jt.searchsorted(s_jt, v_jt, right=right)
            y_tc = torch.searchsorted(s_tc, v_tc, right=right)
            # Insertion indices are integers: demand exact equality, not
            # allclose (value comparison, so int32 vs int64 is fine).
            np.testing.assert_array_equal(y_jt.numpy(), y_tc.numpy())
def sample_pdf(bins, weights, N_samples, det=False):
    """Draw ``N_samples`` positions per row by inverse-transform sampling.

    The ``weights`` (padded by ``1e-5``) are normalized into a pdf and
    accumulated into a cdf that starts at 0. Each uniform value ``u`` is then
    located in the cdf with ``searchsorted``. The result is linearly
    interpolated between the two ``bins`` entries that bracket it.

    Args:
        bins: Values to interpolate between. Presumably one entry longer than
            ``weights`` along the last axis so it lines up with the cdf.
            TODO: confirm against callers.
        weights: Non-normalized weights. ``1e-5`` is added before normalizing
            to prevent NaNs (e.g. from an all-zero row).
        N_samples: Number of samples drawn per row.
        det: If true, use ``N_samples`` evenly spaced values from
            ``linspace(0, 1)``; otherwise use ``jt.random`` draws.

    Returns:
        Sampled positions, one per uniform value, shaped like
        ``cdf.shape[:-1] + [N_samples]``.

    NOTE(review): the gather step builds its shape from ``shape[0]`` and
    ``shape[1]`` only, so this assumes 2-D ``(batch, n)`` inputs.
    """
    # Build the cdf: normalize the padded weights, accumulate, prepend a 0.
    padded = weights + 1e-5  # prevent nans
    pdf = padded / jt.sum(padded, -1, keepdims=True)
    partial_sums = jt.cumsum(pdf, -1)
    zero_col = jt.zeros_like(partial_sums[..., :1])
    cdf = jt.concat([zero_col, partial_sums], -1)  # (batch, len(bins))

    # Uniform positions to invert: random draws, or evenly spaced when det.
    sample_shape = list(cdf.shape[:-1]) + [N_samples]
    if not det:
        u = jt.random(sample_shape)
    else:
        u = jt.linspace(0., 1., steps=N_samples).expand(sample_shape)

    # Invert the cdf: locate each u, then clamp the bracketing indices
    # into [0, cdf.shape[-1] - 1].
    idx = jt.searchsorted(cdf, u, right=True)
    prev_idx = idx - 1
    lo_idx = jt.maximum(jt.zeros_like(prev_idx), prev_idx)
    hi_idx = jt.minimum((cdf.shape[-1] - 1) * jt.ones_like(idx), idx)
    bracket = jt.stack([lo_idx, hi_idx], -1)  # (batch, N_samples, 2)

    # Gather cdf values and bin entries at both ends of each bracket.
    expanded = [bracket.shape[0], bracket.shape[1], cdf.shape[-1]]
    cdf_pair = jt.gather(cdf.unsqueeze(1).expand(expanded), 2, bracket)
    bin_pair = jt.gather(bins.unsqueeze(1).expand(expanded), 2, bracket)
    cdf_lo = cdf_pair[..., 0]
    bin_lo = bin_pair[..., 0]

    # Linear interpolation. Near-flat cdf segments (span < 1e-5) use a span
    # of 1 instead, to avoid dividing by ~0.
    span = cdf_pair[..., 1] - cdf_lo
    span[span < 1e-5] = 1.0
    frac = (u - cdf_lo) / span
    return bin_lo + frac * (bin_pair[..., 1] - bin_lo)