Example 1
    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # NB: because the domain and codomain have different dimensions, the determinant
        # of the Jacobian is not well-defined. Here we return the `log_abs_det_jacobian`
        # of `x` and the flattened lower triangular part of `y`.

        # stick_breaking_logdet = log(y / r) = log(z_cumprod)  (modulo a right shift)
        z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1)
        # by taking diagonal=-2, we don't need to shift z_cumprod to the right
        # NB: diagonal=-2 works fine for (2 x 2) matrix, where we get an empty array
        z1m_cumprod_tril = matrix_to_tril_vec(z1m_cumprod, diagonal=-2)
        stick_breaking_logdet = 0.5 * jnp.sum(jnp.log(z1m_cumprod_tril), axis=-1)

        tanh_logdet = -2 * jnp.sum(x + softplus(-2 * x) - jnp.log(2.), axis=-1)
        return stick_breaking_logdet + tanh_logdet
Example 2
    def fetch_minibatch(self):  # Generate time + a Brownian motion
        T = self.T
        M = self.M
        N = self.N
        D = self.D

        Dt = jnp.zeros((M, N + 1, 1))  # M x (N+1) x 1
        DW = jnp.zeros((M, N + 1, D))  # M x (N+1) x D

        dt = T / N

        # Dt[:, 1:, :] = dt
        new_Dt = index_update(Dt, index[:, 1:, :], dt)

        # DW[:, 1:, :] = sqrt(dt) * N(0, 1)
        # NB: assumes `self.key` holds a JAX PRNG key (jnp has no `random.normal`)
        new_DW = index_update(DW, index[:, 1:, :],
                              jnp.sqrt(dt) * random.normal(self.key, (M, N, D)))

        t = jnp.cumsum(new_Dt, axis=1)  # M x (N+1) x 1
        W = jnp.cumsum(new_DW, axis=1)  # M x (N+1) x D
        # The original PyTorch version converted t and W to torch tensors here;
        # with JAX arrays that step is unnecessary.

        return t, W
Example 3
def compute_alpha_weights(density, t_vals, dirs):
    """Helper function for computing alpha compositing weights."""
    t_dists = t_vals[Ellipsis, 1:] - t_vals[Ellipsis, :-1]
    delta = t_dists * jnp.linalg.norm(dirs[Ellipsis, None, :], axis=-1)
    density_delta = density * delta

    alpha = 1 - jnp.exp(-density_delta)
    trans = jnp.exp(-jnp.concatenate([
        jnp.zeros_like(density_delta[Ellipsis, :1]),
        jnp.cumsum(density_delta[Ellipsis, :-1], axis=-1)
    ], axis=-1))
    weights = alpha * trans
    return weights, alpha, trans, delta
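
A quick shape check with made-up values (one ray, four sample positions, so three intervals; assumes `jax.numpy` is imported as `jnp`):

t_vals = jnp.array([[0.0, 1.0, 2.0, 3.0]])
density = jnp.ones((1, 3))
dirs = jnp.array([[0.0, 0.0, 1.0]])  # unit-norm ray direction
weights, alpha, trans, delta = compute_alpha_weights(density, t_vals, dirs)
# weights has shape (1, 3) and sums to at most 1 along the last axis.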
Example 4
    def _inverse(self, y):
        # inverse stick-breaking
        z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1)
        pad_width = [(0, 0)] * y.ndim
        pad_width[-1] = (1, 0)
        z1m_cumprod_shifted = jnp.pad(z1m_cumprod[..., :-1],
                                      pad_width,
                                      mode="constant",
                                      constant_values=1.)
        t = matrix_to_tril_vec(y, diagonal=-1) / jnp.sqrt(
            matrix_to_tril_vec(z1m_cumprod_shifted, diagonal=-1))
        # inverse of tanh
        x = jnp.log((1 + t) / (1 - t)) / 2
        return x
Example 5
def multinomial(rng, logits):
    """Draws samples from a multinomial distribution given by logits.

  Args:
    rng: A JAX PRNGKey.
    logits: array with unnormalized log-probabilities in last axis.

  Returns:
    Array with sampled categories in last axis.
  """
    probs = jax.nn.softmax(logits)
    cum_probs = jnp.cumsum(probs, axis=-1)
    uniform_variates = jax.random.uniform(rng, logits.shape[:-1] + (1, ))
    return jnp.argmin(uniform_variates > cum_probs, axis=-1)
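
Inverting the CDF with `argmin` works because `uniform_variates > cum_probs` is monotone (True, then False) along the last axis, so `argmin` returns the first index whose cumulative probability reaches the variate. A minimal usage sketch with made-up logits (assumes `jax` and `jax.numpy as jnp` are imported):

rng = jax.random.PRNGKey(0)
logits = jnp.log(jnp.array([[0.1, 0.2, 0.3, 0.4],
                            [0.7, 0.1, 0.1, 0.1]]))
samples = multinomial(rng, logits)  # shape (2,), each entry in {0, 1, 2, 3}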
Example 6
def piecewise_constant_pdf(key, bins, weights, num_coarse_samples,
                           use_stratified_sampling):
    """Piecewise-Constant PDF sampling.

  Args:
    key: jnp.ndarray(float32), [2,], random number generator.
    bins: jnp.ndarray(float32), [batch_size, n_bins + 1].
    weights: jnp.ndarray(float32), [batch_size, n_bins].
    num_coarse_samples: int, the number of samples.
    use_stratified_sampling: bool, use use_stratified_sampling samples.

  Returns:
    z_samples: jnp.ndarray(float32), [batch_size, num_coarse_samples].
  """
    eps = 1e-5

    # Get pdf
    weights += eps  # prevent nans
    pdf = weights / weights.sum(axis=-1, keepdims=True)
    cdf = jnp.cumsum(pdf, axis=-1)
    cdf = jnp.concatenate([jnp.zeros(list(cdf.shape[:-1]) + [1]), cdf],
                          axis=-1)

    # Take uniform samples
    if use_stratified_sampling:
        u = random.uniform(key, list(cdf.shape[:-1]) + [num_coarse_samples])
    else:
        u = jnp.linspace(0., 1., num_coarse_samples)
        u = jnp.broadcast_to(u, list(cdf.shape[:-1]) + [num_coarse_samples])

    # Invert CDF. This takes advantage of the fact that `bins` is sorted.
    mask = (u[..., None, :] >= cdf[..., :, None])

    def minmax(x):
        x0 = jnp.max(jnp.where(mask, x[..., None], x[..., :1, None]), -2)
        x1 = jnp.min(jnp.where(~mask, x[..., None], x[..., -1:, None]), -2)
        x0 = jnp.minimum(x0, x[..., -2:-1])
        x1 = jnp.maximum(x1, x[..., 1:2])
        return x0, x1

    bins_g0, bins_g1 = minmax(bins)
    cdf_g0, cdf_g1 = minmax(cdf)

    denom = (cdf_g1 - cdf_g0)
    denom = jnp.where(denom < eps, 1., denom)
    t = (u - cdf_g0) / denom
    z_samples = bins_g0 + t * (bins_g1 - bins_g0)

    # Prevent gradient from backprop-ing through samples
    return lax.stop_gradient(z_samples)
Example 7
  def __call__(self, inputs, prev_state):
    current_input, return_target = inputs

    em_state, core_state = prev_state
    (counter, memories) = em_state

    if self._apply_core_to_input:
      current_input, core_state = self._core(current_input, core_state)

    # Synthetic return for the current state
    synth_return = jnp.squeeze(self._synthetic_return(current_input), -1)

    # Current state bias term
    bias = self._bias(current_input)

    # Gate computed from current state
    gate = self._gate(current_input)

    # When counter > capacity, mask will be all ones
    mask = 1 - jnp.cumsum(jax.nn.one_hot(counter, self._capacity), axis=1)
    mask = jnp.expand_dims(mask, axis=2)

    # Synthetic returns for each state in memory
    past_synth_returns = hk.BatchApply(self._synthetic_return)(memories)

    # Sum of synthetic returns from previous states
    sr_sum = jnp.sum(past_synth_returns * mask, axis=1)

    prediction = jnp.squeeze(sr_sum * gate + bias, -1)
    sr_loss = self._loss(prediction, return_target)

    augmented_return = jax.lax.stop_gradient(
        self._alpha * synth_return + self._beta * return_target)

    # Write current state to memory
    _, em_state = self._em(current_input, em_state)

    if not self._apply_core_to_input:
      output, core_state = self._core(current_input, core_state)
    else:
      output = current_input

    output = SRCoreWrapperOutput(
        output=output,
        synthetic_return=synth_return,
        augmented_return=augmented_return,
        sr_loss=sr_loss,
    )
    return output, (em_state, core_state)
Example 8
def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size) - 2
    elif isinstance(constraint, constraints._GreaterThan):
        return constraint.lower_bound - np.exp(random.normal(key, size))
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = np.broadcast_to(constraint.lower_bound, size)
        return random.randint(key, size, lower_bound - 1, lower_bound)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound - poisson(key, 5, shape=size)
    elif isinstance(constraint, constraints._Interval):
        upper_bound = np.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key,
                              size,
                              minval=upper_bound,
                              maxval=upper_bound + 1.)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return lax.full(size, np.nan)
    elif isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=np.ones(
            (size[-1], )), size=size[:-1]) + 1e-2
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key,
                           p=np.ones((n, )) / n,
                           n=constraint.upper_bound,
                           shape=size[:-1]) + 1
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key,
                           size[:-2] + (size[-1] * (size[-1] - 1) // 2, ),
                           minval=-1,
                           maxval=1)) + 1e-2
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = 1e-2 + signed_stick_breaking_tril(
            random.uniform(key,
                           size[:-2] + (size[-1] * (size[-1] - 1) // 2, ),
                           minval=-1,
                           maxval=1))
        return np.matmul(cholesky, np.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return random.uniform(key, size)
    elif isinstance(constraint, constraints._PositiveDefinite):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._OrderedVector):
        x = np.cumsum(random.exponential(key, size), -1)
        return x[..., ::-1]
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
Example 9
def rtrun_direct(dtau, S):
    """Radiative Transfer using direct integration.

    Note:
        Use dtau/mu instead of dtau for a non-unity viewing angle, where mu = cos(theta).

    Args:
        dtau: opacity (optical-depth) matrix [N_layer, N_nus]
        S: source matrix [N_layer, N_nus]

    Returns:
        flux in the unit of [erg/cm2/s/cm-1] if using piBarr as a source function.
    """
    taupmu = jnp.cumsum(dtau, axis=0)
    return jnp.sum(S * jnp.exp(-taupmu) * dtau, axis=0)
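
A toy shape check with arbitrary values (3 layers, 2 wavenumber bins are assumptions for illustration):

dtau = jnp.full((3, 2), 0.1)  # hypothetical optical depths per layer
S = jnp.ones((3, 2))          # hypothetical source matrix
flux = rtrun_direct(dtau, S)  # shape (2,): one flux value per wavenumber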
Example 10
def _entmax15(x, axis):
    x = x / 2

    # get indices of elements in the right axis
    # and reshape to allow broadcasting to other dimensions
    idxs = jnp.arange(x.shape[axis]) + 1
    idxs = reshape_to_broadcast(idxs, x.shape, axis)

    # calculate number of elements that belong to the support
    sorted_x = jnp.flip(lax.sort(x, dimension=axis), axis=axis)
    cum_x = jnp.cumsum(sorted_x, axis=axis)
    cum_x_sq = jnp.cumsum(sorted_x**2, axis=axis)
    mean = cum_x / idxs
    var = cum_x_sq - (mean**2) * idxs
    delta = (1 - var) / idxs
    delta = jnp.maximum(delta, 0)  # guard: rounding can push delta slightly negative, and sqrt would give NaN
    thresholds = mean - jnp.sqrt(delta)
    k = jnp.sum(jnp.where(thresholds <= sorted_x, 1, 0),
                axis=axis,
                keepdims=True)

    # calculate threshold and project to simplex
    threshold = jnp.take_along_axis(thresholds, k - 1, axis=axis)
    return jnp.maximum(x - threshold, 0)**2
Example 11
def generate_data():
    T = 1000
    tec = jnp.cumsum(15. * random.normal(random.PRNGKey(0), shape=(T, )))
    TEC_CONV = -8.4479745e6  # mTECU/Hz
    freqs = jnp.linspace(121e6, 168e6, 24)
    phase = tec[:, None] / freqs * TEC_CONV
    Y = jnp.concatenate([jnp.cos(phase), jnp.sin(phase)], axis=1)
    Y_obs = Y + 0.75 * random.normal(random.PRNGKey(1), shape=Y.shape)
    # Y_obs[500:550:2, :] += 3. * random.normal(random.PRNGKey(1),shape=Y[500:550:2, :].shape)
    Sigma = 0.5**2 * jnp.eye(48)
    Omega = jnp.diag(jnp.array([30.]))**2
    mu0 = jnp.zeros(1)
    Gamma0 = jnp.diag(jnp.array([200.]))**2
    amp = jnp.ones_like(phase)
    return Gamma0, Omega, Sigma, T, Y_obs, amp, mu0, tec, freqs
Example 12
def _csr_fromdense_impl(mat, *, nnz, index_dtype):
    mat = jnp.asarray(mat)
    assert mat.ndim == 2
    m = mat.shape[0]

    row, col = jnp.nonzero(mat, size=nnz)
    data = mat[row, col]

    true_nonzeros = jnp.arange(nnz) < (mat != 0).sum()
    data = jnp.where(true_nonzeros, data, 0)
    row = jnp.where(true_nonzeros, row, m)
    indices = col.astype(index_dtype)
    indptr = jnp.zeros(m + 1, dtype=index_dtype).at[1:].set(
        jnp.cumsum(jnp.bincount(row, length=m)))
    return data, indices, indptr
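
For instance, converting a small made-up dense matrix (as the implementation requires, `nnz` must be at least the true number of nonzeros; padded entries of `jnp.nonzero` default to 0):

mat = jnp.array([[0., 2., 0.],
                 [1., 0., 3.]])
data, indices, indptr = _csr_fromdense_impl(mat, nnz=4, index_dtype=jnp.int32)
# data    -> [2., 1., 3., 0.]   (nonzeros in row-major order, zero-padded to nnz)
# indices -> [1, 0, 2, 0]       (column of each stored entry)
# indptr  -> [0, 1, 3]          (row i spans data[indptr[i]:indptr[i + 1]])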
Example 13
def spline_unconstrained_transform(
        thetax: jnp.ndarray, thetay: jnp.ndarray,
        thetad: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Transform the unconstrained parameters of the spline transform into their
    constrained counterparts.

    Args:
        thetax: Unconstrained x-coordinates of the spline intervals.
        thetay: Unconstrained y-coordinates of the spline intervals.
        thetad: Unconstrained derivatives at internal points.

    Returns:
        xk: The x-coordinates of the intervals on which the rational quadratics
            are defined.
        yk: The y-coordinates of the destination intervals of the rational
            quadratic transforms.
        delta: Derivatives at internal points.

    """
    xk = jnp.atleast_2d(jnp.cumsum(2 * nn.softmax(thetax), axis=-1) - 1.)
    xk = jnp.hstack((-jnp.ones((xk.shape[0], 1)), xk))
    yk = jnp.atleast_2d(jnp.cumsum(2 * nn.softmax(thetay), axis=-1) - 1.)
    yk = jnp.hstack((-jnp.ones((yk.shape[0], 1)), yk))
    delta = nn.softplus(thetad)
    return jnp.squeeze(xk), jnp.squeeze(yk), jnp.squeeze(delta)
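
With all-zero unconstrained parameters (and assuming `nn` refers to `jax.nn`), the softmaxes are uniform and the transform yields evenly spaced knots on [-1, 1]:

xk, yk, delta = spline_unconstrained_transform(
    jnp.zeros(4), jnp.zeros(4), jnp.zeros(3))
# xk == yk == [-1., -0.5, 0., 0.5, 1.]; delta == softplus(0.) ~= 0.693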
Example 14
def _ravel_list(*leaves):
    leaves_metadata = tree_map(
        lambda l: pytree_metadata(np.ravel(l), np.shape(l), np.size(l),
                                  lax.dtype(l)), leaves)
    leaves_idx = np.cumsum(
        np.array((0, ) + tuple(d.size for d in leaves_metadata)))

    def unravel_list(arr):
        return [
            np.reshape(lax.dynamic_slice_in_dim(arr, leaves_idx[i], m.size),
                       m.shape).astype(m.dtype)
            for i, m in enumerate(leaves_metadata)
        ]

    return np.concatenate([m.flat for m in leaves_metadata]), unravel_list
Example 15
    def _inverse(self, y):
        # inverse stick-breaking
        remainder = 1 - jnp.cumsum(jnp.abs(y[..., :-1]), axis=-1)
        pad_width = [(0, 0)] * (y.ndim - 1) + [(1, 0)]
        remainder = jnp.pad(remainder,
                            pad_width,
                            mode="constant",
                            constant_values=1.0)
        finfo = jnp.finfo(y.dtype)
        remainder = jnp.clip(remainder, a_min=finfo.tiny)
        t = y / remainder

        # inverse of tanh
        t = jnp.clip(t, a_min=-1 + finfo.eps, a_max=1 - finfo.eps)
        return jnp.arctanh(t)
Example 16
    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # compute stick-breaking logdet
        #   t1 -> t1
        #   t2 -> t2 * (1 - abs(t1))
        #   t3 -> t3 * (1 - abs(t1)) * (1 - abs(t2))
        # hence jacobian is triangular and logdet is the sum of the log
        # of the diagonal part of the jacobian
        one_minus_remainder = jnp.cumsum(jnp.abs(y[..., :-1]), axis=-1)
        eps = jnp.finfo(y.dtype).eps
        one_minus_remainder = jnp.clip(one_minus_remainder, a_max=1 - eps)
        # log(remainder) = log1p(remainder - 1) = log1p(-one_minus_remainder)
        stick_breaking_logdet = jnp.sum(jnp.log1p(-one_minus_remainder), axis=-1)

        tanh_logdet = -2 * jnp.sum(x + softplus(-2 * x) - jnp.log(2.0), axis=-1)
        return stick_breaking_logdet + tanh_logdet
Example 17
def create_position_ids_from_input_ids(input_ids, padding_idx):
    """
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        input_ids: jnp.ndarray
        padding_idx: int

    Returns: jnp.ndarray
    """
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    mask = (input_ids != padding_idx).astype("i4")
    incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
    return incremental_indices.astype("i4") + padding_idx
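
For example, with a single padded sequence (a padding_idx of 1 is an arbitrary choice here):

input_ids = jnp.array([[5, 7, 9, 1, 1]])
create_position_ids_from_input_ids(input_ids, padding_idx=1)
# -> [[2, 3, 4, 1, 1]]: real tokens count up from padding_idx + 1,
#    while padding positions stay at padding_idx.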
Example 18
def multinomial(rng, logits):
  """Draws samples from a multinomial distribution.

  Args:
    rng: A jax.random.PRNGKey.
    logits: An array of shape (..., num_categories) containing unnormalized
      log-probabilities.

  Returns:
    An array of shape (...) containing sampled category indices.
  """
  probs = jax.nn.softmax(logits)
  probs = jnp.cumsum(probs, axis=-1)
  a = jax.random.uniform(rng, logits.shape[:-1] + (1,))
  out = jnp.argmin(a > probs, axis=-1)
  return out
Example 19
def _sparsemax(x, axis):
    # get indices of elements in the right axis
    # and reshape to allow broadcasting to other dimensions
    idxs = jnp.arange(x.shape[axis]) + 1
    idxs = reshape_to_broadcast(idxs, x.shape, axis)

    # calculate number of elements that belong to the support
    sorted_x = jnp.flip(lax.sort(x, dimension=axis), axis=axis)
    cum = jnp.cumsum(sorted_x, axis=axis)
    k = jnp.sum(jnp.where(1 + sorted_x * idxs > cum, 1, 0),
                axis=axis,
                keepdims=True)

    # calculate threshold and project to simplex
    threshold = (jnp.take_along_axis(cum, k - 1, axis=axis) - 1) / k
    return jnp.maximum(x - threshold, 0)
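
A minimal check, with a stand-in for the `reshape_to_broadcast` helper that is not shown above (its behavior is an assumption: reshape `idxs` with singleton dims everywhere except `axis` so it broadcasts):

def reshape_to_broadcast(array, shape, axis):
    # Assumed behavior: keep only the `axis` dimension, singletons elsewhere.
    new_shape = [1] * len(shape)
    new_shape[axis] = shape[axis]
    return jnp.reshape(array, new_shape)

p = _sparsemax(jnp.array([0.5, 1.2, -0.3]), axis=-1)
# p == [0.15, 0.85, 0.]: non-negative, sums to 1, and sparse.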
Example 20
def mimofoeaf(scope: Scope,
              signal,
              framesize=100,
              w0=0,
              train=False,
              preslicer=lambda x: x,
              foekwargs={},
              mimofn=af.rde,
              mimokwargs={},
              mimoinitargs={}):

    sps = 2
    dims = 2
    tx = signal.t
    # MIMO
    slisig = preslicer(signal)
    auxsig = scope.child(mimoaf,
                         mimofn=mimofn,
                         train=train,
                         mimokwargs=mimokwargs,
                         mimoinitargs=mimoinitargs,
                         name='MIMO4FOE')(slisig)
    y, ty = auxsig  # assume y is continuous in time
    yf = xop.frame(y, framesize, framesize)

    foe_init, foe_update, _ = af.array(af.frame_cpr_kf, dims)(**foekwargs)
    state = scope.variable('af_state', 'framefoeaf', lambda *_:
                           (0., 0, foe_init(w0)), ())
    phi, af_step, af_stats = state.value

    af_step, (af_stats, (wf, _)) = af.iterate(foe_update, af_step, af_stats,
                                              yf)
    wp = wf.reshape((-1, dims)).mean(axis=-1)
    w = jnp.interp(
        jnp.arange(y.shape[0] * sps) / sps,
        jnp.arange(wp.shape[0]) * framesize + (framesize - 1) / 2, wp) / sps
    psi = phi + jnp.cumsum(w)
    state.value = (psi[-1], af_step, af_stats)

    # apply FOE to original input signal via linear extrapolation
    psi_ext = jnp.concatenate([
        w[0] * jnp.arange(tx.start - ty.start * sps, 0) + phi, psi,
        w[-1] * jnp.arange(tx.stop - ty.stop * sps) + psi[-1]
    ])

    signal = signal * jnp.exp(-1j * psi_ext)[:, None]
    return signal
Example 21
def _ravel_list(*leaves):
    leaves_metadata = tree_map(
        lambda l: pytree_metadata(jnp.ravel(l), jnp.shape(l), jnp.size(l),
                                  canonicalize_dtype(lax.dtype(l))), leaves)
    leaves_idx = jnp.cumsum(
        jnp.array((0, ) + tuple(d.size for d in leaves_metadata)))

    def unravel_list(arr):
        return [
            jnp.reshape(lax.dynamic_slice_in_dim(arr, leaves_idx[i], m.size),
                        m.shape).astype(m.dtype)
            for i, m in enumerate(leaves_metadata)
        ]

    flat = jnp.concatenate([m.flat for m in leaves_metadata
                            ]) if leaves_metadata else jnp.array([])
    return flat, unravel_list
Example 22
    def Encoding(self, intensities):
        assert jnp.all(intensities >= 0), "Inputs must be non-negative"
        assert intensities.dtype == jnp.float32 or intensities.dtype == jnp.float64, "Intensities must be of type Float."

        # Get shape and size of data.
        shape, size = jnp.shape(intensities), jnp.size(intensities)

        intensities = intensities.reshape(-1)

        time = self.duration // self.dt

        # Compute firing rates in seconds as a function of data intensity,
        # accounting for the simulation time step.
        rate_p = jnp.zeros(size)
        non_zero = intensities != 0

        rate = index_update(rate_p, index[non_zero],
                            1 / intensities[non_zero] * (1000 / self.dt))
        del rate_p

        # Create Poisson distribution and sample inter-spike intervals
        # (incrementing by 1 to avoid zero intervals).
        intervals_p = random.poisson(key=self.key_x,
                                     lam=rate,
                                     shape=(time,
                                            len(rate))).astype(jnp.float32)

        intervals = index_add(intervals_p, index[:, intensities != 0],
                              (intervals_p[:, intensities != 0] == 0).astype(
                                  jnp.float32))

        del intervals_p

        # Calculate spike times by cumulatively summing over time dimension.

        times_p = jnp.cumsum(intervals, dtype='float32', axis=0)
        times = index_update(times_p, times_p >= time + 1, 0).astype(bool)

        del times_p

        spikes_p = jnp.zeros(shape=(time + 1, size))
        spikes = index_update(spikes_p, index[times], 1)
        spikes = spikes[1:]
        spikes = jnp.transpose(spikes, (1, 0)).astype(jnp.float32)
        return spikes.reshape(time, *shape)
Example 23
def l1_unit_projection(x):
    """Euclidean projection to L1 unit ball i.e. argmin_{|v|_1<= 1} |x-v|_2."""
    # https://dl.acm.org/citation.cfm?id=1390191
    xshape = x.shape
    if len(x.shape) == 1:
        x = x.reshape(1, -1)
    eshape = x.shape
    v = jnp.abs(x.reshape((eshape[0], -1)))
    u = jnp.sort(v, axis=1)
    u = u[:, ::-1]  # descending
    arange = (1 + jnp.arange(eshape[1])).reshape((1, -1))
    usum = (jnp.cumsum(u, axis=1) - 1) / arange
    rho = jnp.max(((u - usum) > 0) * arange - 1, axis=1, keepdims=True)
    thx = jnp.take_along_axis(usum, rho, axis=1)
    w = (v - thx).clip(a_min=0)
    w = jnp.where(jnp.linalg.norm(v, ord=1, axis=1, keepdims=True) > 1, w, v)
    x = w.reshape(eshape) * jnp.sign(x)
    return x.reshape(xshape)
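
A quick sanity check on made-up inputs:

x = jnp.array([0.6, -0.8, 0.3])   # L1 norm 1.7, outside the unit ball
v = l1_unit_projection(x)
# jnp.sum(jnp.abs(v)) == 1 and v keeps the sign pattern of x;
# inputs already inside the ball are returned unchanged.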
Example 24
def _ravel_list(*leaves, batch_dims):
    leaves_metadata = tree_map(lambda l: pytree_metadata(
        np.reshape(l, (*np.shape(l)[:batch_dims], -1)), np.shape(l), 
        np.prod(np.shape(l)[batch_dims:], dtype='int32'), canonicalize_dtype(lax.dtype(l))), leaves)
    leaves_idx = np.cumsum(np.array((0,) + tuple(d.event_size for d in leaves_metadata)))

    def unravel_list(arr):
        return [np.reshape(lax.dynamic_slice_in_dim(arr, leaves_idx[i], m.event_size),
                           m.shape[batch_dims:]).astype(m.dtype)
                for i, m in enumerate(leaves_metadata)]

    def unravel_list_batched(arr):
        return [np.reshape(lax.dynamic_slice_in_dim(arr, leaves_idx[i], m.event_size, axis=batch_dims),
                           m.shape).astype(m.dtype)
                for i, m in enumerate(leaves_metadata)]

    flat = np.concatenate([m.flat for m in leaves_metadata], axis=-1) if leaves_metadata else np.array([])
    return flat, unravel_list, unravel_list_batched
Example 25
  def prune_neighbor_list_dense(R, idx, **kwargs):
    d = partial(metric_sq, **kwargs)
    d = space.map_neighbor(d)

    N = R.shape[0]
    neigh_R = R[idx]
    dR = d(R, neigh_R)

    mask = (dR < cutoff_sq) & (idx < N)
    out_idx = N * jnp.ones(idx.shape, jnp.int32)

    cumsum = jnp.cumsum(mask, axis=1)
    index = jnp.where(mask, cumsum - 1, idx.shape[1] - 1)
    p_index = jnp.arange(idx.shape[0])[:, None]
    out_idx = out_idx.at[p_index, index].set(idx)
    max_occupancy = jnp.max(cumsum[:, -1])

    return out_idx[:, :-1], max_occupancy
Example 26
  def prune_neighbor_list_dense(position: Array, idx: Array,
                                **kwargs) -> Tuple[Array, Array]:
    d = partial(metric_sq, **kwargs)
    d = space.map_neighbor(d)

    N = position.shape[0]
    neigh_position = position[idx]
    dR = d(position, neigh_position)

    mask = (dR < cutoff_sq) & (idx < N)
    out_idx = N * jnp.ones(idx.shape, i32)

    cumsum = jnp.cumsum(mask, axis=1)
    index = jnp.where(mask, cumsum - 1, idx.shape[1] - 1)
    p_index = jnp.arange(idx.shape[0])[:, None]
    out_idx = out_idx.at[p_index, index].set(idx)
    max_occupancy = jnp.max(cumsum[:, -1])

    return out_idx[:, :-1], max_occupancy
Example 27
    def prune_neighbor_list(R, idx, **kwargs):
        d = partial(metric_sq, **kwargs)
        d = vmap(vmap(d, (None, 0)))

        N = R.shape[0]
        neigh_R = R[idx]
        dR = d(R, neigh_R)

        mask = np.logical_and(dR < cutoff_sq, idx < N)
        out_idx = N * np.ones(idx.shape, np.int32)

        cumsum = np.cumsum(mask, axis=1)
        index = np.where(mask, cumsum - 1, idx.shape[1] - 1)
        p_index = np.arange(idx.shape[0])[:, None]
        out_idx = ops.index_update(out_idx, ops.index[p_index, index], idx)
        max_occupancy = np.max(cumsum[:, -1])

        return out_idx, max_occupancy
Example 28
def to_jraph(neighbor: NeighborList, mask: Array = None) -> jraph.GraphsTuple:
    """Convert a sparse neighbor list to a `jraph.GraphsTuple`.

  As in jraph, padding here is accomplished by adding a fictitious graph with a
  single node.

  Args:
    neighbor: A neighbor list that we will convert to the jraph format. Must be
      sparse.
    mask: An optional mask on the edges.

  Returns:
    A `jraph.GraphsTuple` that contains the topology of the neighbor list.
  """
    if not is_sparse(neighbor.format):
        raise ValueError(
            'Cannot convert a dense neighbor list to jraph format. '
            'Please use either NeighborListFormat.Sparse or '
            'NeighborListFormat.OrderedSparse.')

    receivers, senders = neighbor.idx
    N = len(neighbor.reference_position)

    _mask = neighbor_list_mask(neighbor)

    if mask is not None:
        _mask = _mask & mask
        cumsum = jnp.cumsum(_mask)
        index = jnp.where(_mask, cumsum - 1, len(receivers))
        ordered = N * jnp.ones((len(receivers) + 1, ), i32)
        receivers = ordered.at[index].set(receivers)[:-1]
        senders = ordered.at[index].set(senders)[:-1]
        mask = receivers < N

    return jraph.GraphsTuple(
        nodes=None,
        edges=None,
        receivers=receivers,
        senders=senders,
        globals=None,
        n_node=jnp.array([N, 1]),
        n_edge=jnp.array([jnp.sum(_mask), jnp.sum(~_mask)]),
    )
Example 29
def build_par_pack_and_unpack(model):
    """ Build utility functions to pack and unpack paramater pytrees
    for the scipy optimizers. """
    value_flat, value_tree = tree_flatten(model.params)
    section_shapes = [item.shape for item in value_flat]
    section_sizes = jnp.cumsum(jnp.array([item.size for item in value_flat]))

    def par_from_array(arr):
        value_flat = jnp.split(arr, section_sizes)
        value_flat = [x.reshape(s) for x, s in zip(value_flat, section_shapes)]

        params = tree_unflatten(value_tree, value_flat)
        return params

    def array_from_par(params):
        value_flat, value_tree = tree_flatten(params)
        return jnp.concatenate([item.ravel() for item in value_flat])

    return par_from_array, array_from_par
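
A round-trip sketch with a hypothetical stand-in for `model` (anything exposing a `.params` pytree works; assumes `tree_flatten`/`tree_unflatten` come from `jax.tree_util` as in the function above):

from collections import namedtuple

Model = namedtuple('Model', ['params'])
model = Model(params={'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)})

par_from_array, array_from_par = build_par_pack_and_unpack(model)
flat = array_from_par(model.params)               # 1-D vector of length 9
roundtrip = array_from_par(par_from_array(flat))  # identical to flat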
Example 30
def sample_pdf(bins, weights, num_importance, perturbation, rng):
    """Hierarchical sampler.
    Sample `num_importance` rays from `bins` with distribution defined by `weights`.
    Args:
        bins: (num_rays, num_samples - 1) bins to sample from
        weights: (num_rays, num_samples - 2) weights assigned to each sampled color for the coarse model
        num_importance: the number of samples to draw from the distribution
        perturbation: whether to apply jitter on each ray or not
        rng: random key
    Returns:
        samples: (num_rays, num_importance) the sampled rays
    """
    # get pdf
    weights = jnp.clip(weights, 1e-5)  # prevent NaNs
    pdf = weights / jnp.sum(weights, axis=-1, keepdims=True)
    cdf = jnp.cumsum(pdf, axis=-1)
    cdf = jnp.concatenate([jnp.zeros_like(cdf[..., :1]), cdf], axis=-1)

    # take uniform samples
    samples_shape = [*cdf.shape[:-1], num_importance]
    if perturbation:
        uni_samples = random.uniform(rng, shape=samples_shape)
    else:
        uni_samples = jnp.linspace(0.0, 1.0, num_importance)
        uni_samples = jnp.broadcast_to(uni_samples, samples_shape)

    # invert CDF
    idx = jax.vmap(lambda x, y: jnp.searchsorted(x, y, side="right"))(
        cdf, uni_samples)

    below = jnp.maximum(0, idx - 1)
    above = jnp.minimum(cdf.shape[-1] - 1, idx)
    inds_g = jnp.stack([below, above], axis=-1)

    cdf_g = jnp.take_along_axis(cdf[..., None], inds_g, axis=1)
    bins_g = jnp.take_along_axis(bins[..., None], inds_g, axis=1)

    denom = cdf_g[..., 1] - cdf_g[..., 0]
    # denom = jnp.where(denom < 1e-5, jnp.ones_like(denom), denom)
    denom = lax.select(denom < 1e-5, jnp.ones_like(denom), denom)
    t = (uni_samples - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
    return samples
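
A toy draw following the docstring shapes (1 ray, 5 bin edges, 4 weights; the values are made up, and `random` is assumed to be `jax.random` as in the function body):

rng = random.PRNGKey(0)
bins = jnp.array([[0.0, 1.0, 2.0, 3.0, 4.0]])
weights = jnp.array([[0.1, 0.4, 0.4, 0.1]])
z = sample_pdf(bins, weights, num_importance=8, perturbation=True, rng=rng)
# z has shape (1, 8); samples concentrate where the weights are largest.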