def log_abs_det_jacobian(self, x, y, intermediates=None):
    # NB: because the domain and codomain are two spaces with different dimensions, the
    # determinant of the Jacobian is not well-defined. Here we return `log_abs_det_jacobian`
    # of `x` and the flattened lower triangular part of `y`.

    # stick_breaking_logdet = log(y / r) = log(z_cumprod)  (modulo right shifted)
    z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1)
    # by taking diagonal=-2, we don't need to shift z_cumprod to the right
    # NB: diagonal=-2 works fine for (2 x 2) matrix, where we get an empty array
    z1m_cumprod_tril = matrix_to_tril_vec(z1m_cumprod, diagonal=-2)
    stick_breaking_logdet = 0.5 * jnp.sum(jnp.log(z1m_cumprod_tril), axis=-1)

    tanh_logdet = -2 * jnp.sum(x + softplus(-2 * x) - jnp.log(2.), axis=-1)
    return stick_breaking_logdet + tanh_logdet
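# Note on the tanh term above: since d/dx tanh(x) = 1 - tanh(x)^2 and
# log(1 - tanh(x)^2) = 2*log(2) - 2*x - 2*softplus(-2*x), summing this
# per-element contribution over the last axis gives `tanh_logdet`; the
# softplus form is numerically stable for large |x|.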
def fetch_minibatch(self):  # Generate time + a Brownian motion
    T = self.T
    M = self.M
    N = self.N
    D = self.D

    Dt = jnp.zeros((M, N + 1, 1))  # M x (N+1) x 1
    DW = jnp.zeros((M, N + 1, D))  # M x (N+1) x D

    dt = T / N
    # Dt[:, 1:, :] = dt
    new_Dt = index_update(Dt, index[:, 1:, :], dt)
    # DW[:, 1:, :] = sqrt(dt) * standard-normal draws. NB: `jnp.random.normal`
    # does not exist, so a `jax.random` key is needed; `self.rng_key` is a
    # hypothetical attribute standing in for however the class stores its key.
    new_DW = index_update(DW, index[:, 1:, :],
                          jnp.sqrt(dt) * random.normal(self.rng_key, (M, N, D)))

    t = jnp.cumsum(new_Dt, axis=1)  # M x (N+1) x 1
    W = jnp.cumsum(new_DW, axis=1)  # M x (N+1) x D

    # t = torch.from_numpy(t).float().to(self.device) <- cancel these out so stays as numpy
    # W = torch.from_numpy(W).float().to(self.device) <- cancel these out so stays as numpy
    return t, W
def compute_alpha_weights(density, t_vals, dirs):
    """Helper function for computing alpha compositing weights."""
    t_dists = t_vals[Ellipsis, 1:] - t_vals[Ellipsis, :-1]
    delta = t_dists * jnp.linalg.norm(dirs[Ellipsis, None, :], axis=-1)
    density_delta = density * delta

    alpha = 1 - jnp.exp(-density_delta)
    trans = jnp.exp(-jnp.concatenate([
        jnp.zeros_like(density_delta[Ellipsis, :1]),
        jnp.cumsum(density_delta[Ellipsis, :-1], axis=-1)
    ], axis=-1))
    weights = alpha * trans
    return weights, alpha, trans, delta
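# Usage sketch for compute_alpha_weights (toy shapes; assumes jax.numpy is
# available as jnp, as in the function above):
t_vals = jnp.linspace(0., 1., 5)[None, :]   # one ray, 4 intervals
density = jnp.ones((1, 4))                  # constant density along the ray
dirs = jnp.array([[0., 0., 1.]])            # unit-norm ray direction
weights, alpha, trans, delta = compute_alpha_weights(density, t_vals, dirs)
# weights sum to < 1 along the ray: the remainder is light passing through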
def _inverse(self, y):
    # inverse stick-breaking
    z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1)
    pad_width = [(0, 0)] * y.ndim
    pad_width[-1] = (1, 0)
    z1m_cumprod_shifted = jnp.pad(z1m_cumprod[..., :-1], pad_width,
                                  mode="constant", constant_values=1.)
    t = matrix_to_tril_vec(y, diagonal=-1) / jnp.sqrt(
        matrix_to_tril_vec(z1m_cumprod_shifted, diagonal=-1))
    # inverse of tanh
    x = jnp.log((1 + t) / (1 - t)) / 2
    return x
def multinomial(rng, logits):
    """Draws samples from a multinomial distribution given by logits.

    Args:
      rng: A JAX PRNGKey.
      logits: array with unnormalized log-probabilities in last axis.

    Returns:
      Array with sampled categories in last axis.
    """
    probs = jax.nn.softmax(logits)
    cum_probs = jnp.cumsum(probs, axis=-1)
    uniform_variates = jax.random.uniform(rng, logits.shape[:-1] + (1,))
    return jnp.argmin(uniform_variates > cum_probs, axis=-1)
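# Usage sketch (hypothetical values; assumes `jax`/`jnp` imported as used in
# the function above): one draw per batch row from unnormalized log-probs.
rng = jax.random.PRNGKey(0)
logits = jnp.log(jnp.array([[0.1, 0.6, 0.3]]))
sample = multinomial(rng, logits)   # shape (1,), a value in {0, 1, 2}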
def piecewise_constant_pdf(key, bins, weights, num_coarse_samples,
                           use_stratified_sampling):
    """Piecewise-Constant PDF sampling.

    Args:
      key: jnp.ndarray(float32), [2,], random number generator.
      bins: jnp.ndarray(float32), [batch_size, n_bins + 1].
      weights: jnp.ndarray(float32), [batch_size, n_bins].
      num_coarse_samples: int, the number of samples.
      use_stratified_sampling: bool, whether to use stratified sampling.

    Returns:
      z_samples: jnp.ndarray(float32), [batch_size, num_coarse_samples].
    """
    eps = 1e-5

    # Get pdf
    weights += eps  # prevent nans
    pdf = weights / weights.sum(axis=-1, keepdims=True)
    cdf = jnp.cumsum(pdf, axis=-1)
    cdf = jnp.concatenate([jnp.zeros(list(cdf.shape[:-1]) + [1]), cdf], axis=-1)

    # Take uniform samples
    if use_stratified_sampling:
        u = random.uniform(key, list(cdf.shape[:-1]) + [num_coarse_samples])
    else:
        u = jnp.linspace(0., 1., num_coarse_samples)
        u = jnp.broadcast_to(u, list(cdf.shape[:-1]) + [num_coarse_samples])

    # Invert CDF. This takes advantage of the fact that `bins` is sorted.
    mask = (u[..., None, :] >= cdf[..., :, None])

    def minmax(x):
        x0 = jnp.max(jnp.where(mask, x[..., None], x[..., :1, None]), -2)
        x1 = jnp.min(jnp.where(~mask, x[..., None], x[..., -1:, None]), -2)
        x0 = jnp.minimum(x0, x[..., -2:-1])
        x1 = jnp.maximum(x1, x[..., 1:2])
        return x0, x1

    bins_g0, bins_g1 = minmax(bins)
    cdf_g0, cdf_g1 = minmax(cdf)

    denom = (cdf_g1 - cdf_g0)
    denom = jnp.where(denom < eps, 1., denom)
    t = (u - cdf_g0) / denom
    z_samples = bins_g0 + t * (bins_g1 - bins_g0)

    # Prevent gradient from backprop-ing through samples
    return lax.stop_gradient(z_samples)
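# Sketch: draw 8 stratified samples from a 4-bin histogram PDF (assumes
# `random` is jax.random, as used in the function above):
key = random.PRNGKey(0)
bins = jnp.linspace(0., 1., 5)[None, :]   # [1, n_bins + 1]
weights = jnp.array([[1., 3., 3., 1.]])   # [1, n_bins], unnormalized mass
z = piecewise_constant_pdf(key, bins, weights, num_coarse_samples=8,
                           use_stratified_sampling=True)   # [1, 8]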
def __call__(self, inputs, prev_state):
    current_input, return_target = inputs
    em_state, core_state = prev_state
    (counter, memories) = em_state

    if self._apply_core_to_input:
        current_input, core_state = self._core(current_input, core_state)

    # Synthetic return for the current state
    synth_return = jnp.squeeze(self._synthetic_return(current_input), -1)
    # Current state bias term
    bias = self._bias(current_input)
    # Gate computed from current state
    gate = self._gate(current_input)

    # When counter > capacity, mask will be all ones
    mask = 1 - jnp.cumsum(jax.nn.one_hot(counter, self._capacity), axis=1)
    mask = jnp.expand_dims(mask, axis=2)
    # Synthetic returns for each state in memory
    past_synth_returns = hk.BatchApply(self._synthetic_return)(memories)
    # Sum of synthetic returns from previous states
    sr_sum = jnp.sum(past_synth_returns * mask, axis=1)

    prediction = jnp.squeeze(sr_sum * gate + bias, -1)
    sr_loss = self._loss(prediction, return_target)

    augmented_return = jax.lax.stop_gradient(
        self._alpha * synth_return + self._beta * return_target)

    # Write current state to memory
    _, em_state = self._em(current_input, em_state)

    if not self._apply_core_to_input:
        output, core_state = self._core(current_input, core_state)
    else:
        output = current_input

    output = SRCoreWrapperOutput(
        output=output,
        synthetic_return=synth_return,
        augmented_return=augmented_return,
        sr_loss=sr_loss,
    )
    return output, (em_state, core_state)
def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size) - 2
    elif isinstance(constraint, constraints._GreaterThan):
        return constraint.lower_bound - np.exp(random.normal(key, size))
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = np.broadcast_to(constraint.lower_bound, size)
        return random.randint(key, size, lower_bound - 1, lower_bound)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound - poisson(key, 5, shape=size)
    elif isinstance(constraint, constraints._Interval):
        upper_bound = np.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=upper_bound,
                              maxval=upper_bound + 1.)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return lax.full(size, np.nan)
    elif isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=np.ones((size[-1],)),
                                 size=size[:-1]) + 1e-2
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=np.ones((n,)) / n,
                           n=constraint.upper_bound, shape=size[:-1]) + 1
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1)) + 1e-2
    elif isinstance(constraint, constraints._CorrMatrix):
        cholesky = 1e-2 + signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1))
        return np.matmul(cholesky, np.swapaxes(cholesky, -2, -1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return random.uniform(key, size)
    elif isinstance(constraint, constraints._PositiveDefinite):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._OrderedVector):
        x = np.cumsum(random.exponential(key, size), -1)
        return x[..., ::-1]
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
def rtrun_direct(dtau, S):
    """Radiative transfer using direct integration.

    Note:
        Use dtau/mu instead of dtau when you want a non-unity mu, where
        mu = cos(theta).

    Args:
        dtau: opacity matrix
        S: source matrix [N_layer, N_nus]

    Returns:
        flux in the unit of [erg/cm2/s/cm-1] if using piBarr as a source
        function.
    """
    taupmu = jnp.cumsum(dtau, axis=0)
    return jnp.sum(S * jnp.exp(-taupmu) * dtau, axis=0)
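# Toy check (assumed values): 100 thin layers with total optical depth 5 and a
# unit source function should give flux close to 1 - exp(-5) ~ 0.993.
dtau = jnp.full((100, 1), 0.05)   # [N_layer, N_nus]
S = jnp.ones((100, 1))
flux = rtrun_direct(dtau, S)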
def _entmax15(x, axis):
    x = x / 2

    # get indices of elements in the right axis
    # and reshape to allow broadcasting to other dimensions
    idxs = jnp.arange(x.shape[axis]) + 1
    idxs = reshape_to_broadcast(idxs, x.shape, axis)

    # calculate number of elements that belong to the support
    sorted_x = jnp.flip(lax.sort(x, dimension=axis), axis=axis)
    cum_x = jnp.cumsum(sorted_x, axis=axis)
    cum_x_sq = jnp.cumsum(sorted_x**2, axis=axis)
    mean = cum_x / idxs
    var = cum_x_sq - (mean**2) * idxs
    delta = (1 - var) / idxs
    delta = jnp.maximum(delta, 0)  # TODO: understand why we need this
    thresholds = mean - jnp.sqrt(delta)
    k = jnp.sum(jnp.where(thresholds <= sorted_x, 1, 0), axis=axis,
                keepdims=True)

    # calculate threshold and project to simplex
    threshold = jnp.take_along_axis(thresholds, k - 1, axis=axis)
    return jnp.maximum(x - threshold, 0)**2
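# `reshape_to_broadcast` is assumed by `_entmax15` above and `_sparsemax`
# below but not shown; a minimal sketch consistent with how it is called
# (place a 1-D index vector on `axis`, with singleton dims elsewhere):
def reshape_to_broadcast(x, shape, axis):
    new_shape = [1] * len(shape)
    new_shape[axis] = shape[axis]
    return jnp.reshape(x, new_shape)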
def generate_data():
    T = 1000
    tec = jnp.cumsum(15. * random.normal(random.PRNGKey(0), shape=(T,)))
    TEC_CONV = -8.4479745e6  # mTECU/Hz
    freqs = jnp.linspace(121e6, 168e6, 24)
    phase = tec[:, None] / freqs * TEC_CONV
    Y = jnp.concatenate([jnp.cos(phase), jnp.sin(phase)], axis=1)
    Y_obs = Y + 0.75 * random.normal(random.PRNGKey(1), shape=Y.shape)
    # Y_obs[500:550:2, :] += 3. * random.normal(random.PRNGKey(1),
    #                                           shape=Y[500:550:2, :].shape)
    Sigma = 0.5**2 * jnp.eye(48)
    Omega = jnp.diag(jnp.array([30.]))**2
    mu0 = jnp.zeros(1)
    Gamma0 = jnp.diag(jnp.array([200.]))**2
    amp = jnp.ones_like(phase)
    return Gamma0, Omega, Sigma, T, Y_obs, amp, mu0, tec, freqs
def _csr_fromdense_impl(mat, *, nnz, index_dtype):
    mat = jnp.asarray(mat)
    assert mat.ndim == 2
    m = mat.shape[0]

    row, col = jnp.nonzero(mat, size=nnz)
    data = mat[row, col]

    true_nonzeros = jnp.arange(nnz) < (mat != 0).sum()
    data = jnp.where(true_nonzeros, data, 0)
    row = jnp.where(true_nonzeros, row, m)
    indices = col.astype(index_dtype)
    indptr = jnp.zeros(m + 1, dtype=index_dtype).at[1:].set(
        jnp.cumsum(jnp.bincount(row, length=m)))
    return data, indices, indptr
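# Round-trip sketch on a tiny matrix, with `nnz` at least the true number of
# nonzeros (hypothetical values):
mat = jnp.array([[1., 0., 2.],
                 [0., 0., 3.]])
data, indices, indptr = _csr_fromdense_impl(mat, nnz=4, index_dtype=jnp.int32)
# data    -> [1., 2., 3., 0.]  (padding entry zeroed out)
# indices -> [0, 2, 2, 0]      (column of each stored entry)
# indptr  -> [0, 2, 3]         (row i spans data[indptr[i]:indptr[i+1]])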
def spline_unconstrained_transform(thetax: jnp.ndarray, thetay: jnp.ndarray,
                                   thetad: jnp.ndarray) -> jnp.ndarray:
    """Transform the unconstrained parameters of the spline transform into
    their constrained counterparts.

    Args:
        thetax: Unconstrained x-coordinates of the spline intervals.
        thetay: Unconstrained y-coordinates of the spline intervals.
        thetad: Unconstrained derivatives at internal points.

    Returns:
        xk: The x-coordinates of the intervals on which the rational
            quadratics are defined.
        yk: The y-coordinates of the destination intervals of the rational
            quadratic transforms.
        delta: Derivatives at internal points.
    """
    xk = jnp.atleast_2d(jnp.cumsum(2 * nn.softmax(thetax), axis=-1) - 1.)
    xk = jnp.hstack((-jnp.ones((xk.shape[0], 1)), xk))
    yk = jnp.atleast_2d(jnp.cumsum(2 * nn.softmax(thetay), axis=-1) - 1.)
    yk = jnp.hstack((-jnp.ones((yk.shape[0], 1)), yk))
    delta = nn.softplus(thetad)
    return jnp.squeeze(xk), jnp.squeeze(yk), jnp.squeeze(delta)
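# Sketch (assuming `nn` here is jax.nn): three equal unconstrained widths map
# to evenly spaced knots on [-1, 1], anchored at -1.
thetax = jnp.zeros((1, 3))
thetay = jnp.zeros((1, 3))
thetad = jnp.zeros((1, 2))
xk, yk, delta = spline_unconstrained_transform(thetax, thetay, thetad)
# xk == yk -> [-1., -1/3, 1/3, 1.]; delta -> softplus(0) = log(2) per knot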
def _ravel_list(*leaves):
    leaves_metadata = tree_map(
        lambda l: pytree_metadata(np.ravel(l), np.shape(l), np.size(l),
                                  lax.dtype(l)), leaves)
    leaves_idx = np.cumsum(
        np.array((0,) + tuple(d.size for d in leaves_metadata)))

    def unravel_list(arr):
        return [
            np.reshape(lax.dynamic_slice_in_dim(arr, leaves_idx[i], m.size),
                       m.shape).astype(m.dtype)
            for i, m in enumerate(leaves_metadata)
        ]

    return np.concatenate([m.flat for m in leaves_metadata]), unravel_list
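# `pytree_metadata` above is assumed to be a simple record type; a minimal
# compatible definition (note: the batched variant further below uses an
# `event_size` field in place of `size`):
from collections import namedtuple
pytree_metadata = namedtuple('pytree_metadata',
                             ['flat', 'shape', 'size', 'dtype'])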
def _inverse(self, y):
    # inverse stick-breaking
    remainder = 1 - jnp.cumsum(jnp.abs(y[..., :-1]), axis=-1)
    pad_width = [(0, 0)] * (y.ndim - 1) + [(1, 0)]
    remainder = jnp.pad(remainder, pad_width, mode="constant",
                        constant_values=1.0)
    finfo = jnp.finfo(y.dtype)
    remainder = jnp.clip(remainder, a_min=finfo.tiny)
    t = y / remainder

    # inverse of tanh
    t = jnp.clip(t, a_min=-1 + finfo.eps, a_max=1 - finfo.eps)
    return jnp.arctanh(t)
def log_abs_det_jacobian(self, x, y, intermediates=None):
    # compute stick-breaking logdet
    #   t1 -> t1
    #   t2 -> t2 * (1 - abs(t1))
    #   t3 -> t3 * (1 - abs(t1)) * (1 - abs(t2))
    # hence the Jacobian is triangular and the logdet is the sum of the log
    # of the diagonal part of the Jacobian
    one_minus_remainder = jnp.cumsum(jnp.abs(y[..., :-1]), axis=-1)
    eps = jnp.finfo(y.dtype).eps
    one_minus_remainder = jnp.clip(one_minus_remainder, a_max=1 - eps)
    # log(remainder) = log1p(remainder - 1)
    stick_breaking_logdet = jnp.sum(jnp.log1p(-one_minus_remainder), axis=-1)

    tanh_logdet = -2 * jnp.sum(x + softplus(-2 * x) - jnp.log(2.0), axis=-1)
    return stick_breaking_logdet + tanh_logdet
def create_position_ids_from_input_ids(input_ids, padding_idx):
    """
    Replace non-padding symbols with their position numbers. Position numbers
    begin at padding_idx+1. Padding symbols are ignored. This is modified from
    fairseq's `utils.make_positions`.

    Args:
        input_ids: jnp.ndarray
        padding_idx: int

    Returns: jnp.ndarray
    """
    # The series of casts and type-conversions here are carefully balanced to
    # both work with ONNX export and XLA.
    mask = (input_ids != padding_idx).astype("i4")
    incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
    return incremental_indices.astype("i4") + padding_idx
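# Example (pad token id 1, as in RoBERTa-style models): pad positions keep
# padding_idx, real tokens count up from padding_idx + 1.
input_ids = jnp.array([[1, 5, 7, 1, 9]])
create_position_ids_from_input_ids(input_ids, padding_idx=1)
# -> [[1, 2, 3, 1, 4]]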
def multinomial(rng, logits):
    """Draws samples from a multinomial distribution.

    Args:
      rng: A jax.random.PRNGKey.
      logits: An array of shape (..., num_categories) containing unnormalized
        log-probabilities.

    Returns:
      An array of shape (...) containing sampled category indices.
    """
    probs = jax.nn.softmax(logits)
    probs = jnp.cumsum(probs, axis=-1)
    a = jax.random.uniform(rng, logits.shape[:-1] + (1,))
    out = jnp.argmin(a > probs, axis=-1)
    return out
def _sparsemax(x, axis):
    # get indices of elements in the right axis
    # and reshape to allow broadcasting to other dimensions
    idxs = jnp.arange(x.shape[axis]) + 1
    idxs = reshape_to_broadcast(idxs, x.shape, axis)

    # calculate number of elements that belong to the support
    sorted_x = jnp.flip(lax.sort(x, dimension=axis), axis=axis)
    cum = jnp.cumsum(sorted_x, axis=axis)
    k = jnp.sum(jnp.where(1 + sorted_x * idxs > cum, 1, 0), axis=axis,
                keepdims=True)

    # calculate threshold and project to simplex
    threshold = (jnp.take_along_axis(cum, k - 1, axis=axis) - 1) / k
    return jnp.maximum(x - threshold, 0)
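# Sanity check: sparsemax returns a sparse probability vector summing to 1,
# with exact zeros outside the support (hypothetical input):
x = jnp.array([0.5, 1.8, 0.1, 1.6])
p = _sparsemax(x, axis=0)   # -> [0., 0.6, 0., 0.4]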
def mimofoeaf(scope: Scope,
              signal,
              framesize=100,
              w0=0,
              train=False,
              preslicer=lambda x: x,
              foekwargs={},
              mimofn=af.rde,
              mimokwargs={},
              mimoinitargs={}):
    sps = 2
    dims = 2
    tx = signal.t

    # MIMO
    slisig = preslicer(signal)
    auxsig = scope.child(mimoaf,
                         mimofn=mimofn,
                         train=train,
                         mimokwargs=mimokwargs,
                         mimoinitargs=mimoinitargs,
                         name='MIMO4FOE')(slisig)
    y, ty = auxsig  # assume y is continuous in time
    yf = xop.frame(y, framesize, framesize)

    foe_init, foe_update, _ = af.array(af.frame_cpr_kf, dims)(**foekwargs)
    state = scope.variable('af_state', 'framefoeaf',
                           lambda *_: (0., 0, foe_init(w0)), ())
    phi, af_step, af_stats = state.value
    af_step, (af_stats, (wf, _)) = af.iterate(foe_update, af_step, af_stats, yf)
    wp = wf.reshape((-1, dims)).mean(axis=-1)
    w = jnp.interp(
        jnp.arange(y.shape[0] * sps) / sps,
        jnp.arange(wp.shape[0]) * framesize + (framesize - 1) / 2, wp) / sps
    psi = phi + jnp.cumsum(w)
    state.value = (psi[-1], af_step, af_stats)

    # apply FOE to original input signal via linear extrapolation
    psi_ext = jnp.concatenate([
        w[0] * jnp.arange(tx.start - ty.start * sps, 0) + phi,
        psi,
        w[-1] * jnp.arange(tx.stop - ty.stop * sps) + psi[-1]
    ])
    signal = signal * jnp.exp(-1j * psi_ext)[:, None]
    return signal
def _ravel_list(*leaves):
    leaves_metadata = tree_map(
        lambda l: pytree_metadata(jnp.ravel(l), jnp.shape(l), jnp.size(l),
                                  canonicalize_dtype(lax.dtype(l))), leaves)
    leaves_idx = jnp.cumsum(
        jnp.array((0,) + tuple(d.size for d in leaves_metadata)))

    def unravel_list(arr):
        return [
            jnp.reshape(lax.dynamic_slice_in_dim(arr, leaves_idx[i], m.size),
                        m.shape).astype(m.dtype)
            for i, m in enumerate(leaves_metadata)
        ]

    flat = (jnp.concatenate([m.flat for m in leaves_metadata])
            if leaves_metadata else jnp.array([]))
    return flat, unravel_list
def Encoding(self, intensities):
    assert jnp.all(intensities >= 0), "Inputs must be non-negative"
    assert intensities.dtype == jnp.float32 or intensities.dtype == jnp.float64, \
        "Intensities must be of type Float."

    # Get shape and size of data.
    shape, size = jnp.shape(intensities), jnp.size(intensities)
    intensities = intensities.reshape(-1)
    time = self.duration // self.dt

    # Compute firing rates in seconds as function of data intensity,
    # accounting for simulation time step.
    rate_p = jnp.zeros(size)
    non_zero = intensities != 0
    rate = index_update(rate_p, index[non_zero],
                        1 / intensities[non_zero] * (1000 / self.dt))
    del rate_p

    # Create Poisson distribution and sample inter-spike intervals
    # (incrementing by 1 to avoid zero intervals).
    intervals_p = random.poisson(key=self.key_x, lam=rate,
                                 shape=(time, len(rate))).astype(jnp.float32)
    intervals = index_add(
        intervals_p, index[:, intensities != 0],
        (intervals_p[:, intensities != 0] == 0).astype(jnp.float32))
    del intervals_p

    # Calculate spike times by cumulatively summing over time dimension.
    times_p = jnp.cumsum(intervals, dtype='float32', axis=0)
    times = index_update(times_p, times_p >= time + 1, 0).astype(bool)
    del times_p

    spikes_p = jnp.zeros(shape=(time + 1, size))
    spikes = index_update(spikes_p, index[times], 1)
    spikes = spikes[1:]
    spikes = jnp.transpose(spikes, (1, 0)).astype(jnp.float32)
    return spikes.reshape(time, *shape)
def l1_unit_projection(x):
    """Euclidean projection to L1 unit ball i.e. argmin_{|v|_1 <= 1} |x - v|_2."""
    # https://dl.acm.org/citation.cfm?id=1390191
    xshape = x.shape
    if len(x.shape) == 1:
        x = x.reshape(1, -1)
    eshape = x.shape
    v = jnp.abs(x.reshape((eshape[0], -1)))
    u = jnp.sort(v, axis=1)
    u = u[:, ::-1]  # descending
    arange = (1 + jnp.arange(eshape[1])).reshape((1, -1))
    usum = (jnp.cumsum(u, axis=1) - 1) / arange
    rho = jnp.max(((u - usum) > 0) * arange - 1, axis=1, keepdims=True)
    thx = jnp.take_along_axis(usum, rho, axis=1)
    w = (v - thx).clip(a_min=0)
    w = jnp.where(jnp.linalg.norm(v, ord=1, axis=1, keepdims=True) > 1, w, v)
    x = w.reshape(eshape) * jnp.sign(x)
    return x.reshape(xshape)
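# Quick check: a point outside the L1 unit ball is projected onto its surface,
# preserving signs; points already inside are returned unchanged.
x = jnp.array([0.6, -0.8])        # |x|_1 = 1.4 > 1
v = l1_unit_projection(x)         # -> [0.4, -0.6], so |v|_1 == 1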
def _ravel_list(*leaves, batch_dims):
    leaves_metadata = tree_map(
        lambda l: pytree_metadata(
            np.reshape(l, (*np.shape(l)[:batch_dims], -1)), np.shape(l),
            np.prod(np.shape(l)[batch_dims:], dtype='int32'),
            canonicalize_dtype(lax.dtype(l))), leaves)
    leaves_idx = np.cumsum(
        np.array((0,) + tuple(d.event_size for d in leaves_metadata)))

    def unravel_list(arr):
        return [
            np.reshape(
                lax.dynamic_slice_in_dim(arr, leaves_idx[i], m.event_size),
                m.shape[batch_dims:]).astype(m.dtype)
            for i, m in enumerate(leaves_metadata)
        ]

    def unravel_list_batched(arr):
        return [
            np.reshape(
                lax.dynamic_slice_in_dim(arr, leaves_idx[i], m.event_size,
                                         axis=batch_dims),
                m.shape).astype(m.dtype)
            for i, m in enumerate(leaves_metadata)
        ]

    flat = (np.concatenate([m.flat for m in leaves_metadata], axis=-1)
            if leaves_metadata else np.array([]))
    return flat, unravel_list, unravel_list_batched
def prune_neighbor_list_dense(R, idx, **kwargs):
    d = partial(metric_sq, **kwargs)
    d = space.map_neighbor(d)

    N = R.shape[0]
    neigh_R = R[idx]
    dR = d(R, neigh_R)

    mask = (dR < cutoff_sq) & (idx < N)
    out_idx = N * jnp.ones(idx.shape, jnp.int32)

    cumsum = jnp.cumsum(mask, axis=1)
    index = jnp.where(mask, cumsum - 1, idx.shape[1] - 1)
    p_index = jnp.arange(idx.shape[0])[:, None]
    out_idx = out_idx.at[p_index, index].set(idx)
    max_occupancy = jnp.max(cumsum[:, -1])

    return out_idx[:, :-1], max_occupancy
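# The cumsum/scatter pattern above is a stream compaction: entries that pass
# `mask` are packed to the front of each row, and everything that fails is
# scattered into a sentinel slot that the caller drops. A standalone demo of
# the same trick (hypothetical values):
mask = jnp.array([[True, False, True, False]])
vals = jnp.array([[7, 8, 9, 3]])
cumsum = jnp.cumsum(mask, axis=1)
index = jnp.where(mask, cumsum - 1, vals.shape[1] - 1)   # failures -> last col
rows = jnp.arange(vals.shape[0])[:, None]
packed = jnp.full(vals.shape, -1).at[rows, index].set(vals)[:, :-1]
# packed -> [[7, 9, -1]]: the masked-out values collided in the dropped column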
def prune_neighbor_list_dense(position: Array, idx: Array, **kwargs) -> Array:
    d = partial(metric_sq, **kwargs)
    d = space.map_neighbor(d)

    N = position.shape[0]
    neigh_position = position[idx]
    dR = d(position, neigh_position)

    mask = (dR < cutoff_sq) & (idx < N)
    out_idx = N * jnp.ones(idx.shape, i32)

    cumsum = jnp.cumsum(mask, axis=1)
    index = jnp.where(mask, cumsum - 1, idx.shape[1] - 1)
    p_index = jnp.arange(idx.shape[0])[:, None]
    out_idx = out_idx.at[p_index, index].set(idx)
    max_occupancy = jnp.max(cumsum[:, -1])

    return out_idx[:, :-1], max_occupancy
def prune_neighbor_list(R, idx, **kwargs):
    d = partial(metric_sq, **kwargs)
    d = vmap(vmap(d, (None, 0)))

    N = R.shape[0]
    neigh_R = R[idx]
    dR = d(R, neigh_R)

    mask = np.logical_and(dR < cutoff_sq, idx < N)
    out_idx = N * np.ones(idx.shape, np.int32)

    cumsum = np.cumsum(mask, axis=1)
    index = np.where(mask, cumsum - 1, idx.shape[1] - 1)
    p_index = np.arange(idx.shape[0])[:, None]
    out_idx = ops.index_update(out_idx, ops.index[p_index, index], idx)
    max_occupancy = np.max(cumsum[:, -1])

    return out_idx, max_occupancy
def to_jraph(neighbor: NeighborList, mask: Array = None) -> jraph.GraphsTuple:
    """Convert a sparse neighbor list to a `jraph.GraphsTuple`.

    As in jraph, padding here is accomplished by adding a fictitious graph
    with a single node.

    Args:
      neighbor: A neighbor list that we will convert to the jraph format.
        Must be sparse.
      mask: An optional mask on the edges.

    Returns:
      A `jraph.GraphsTuple` that contains the topology of the neighbor list.
    """
    if not is_sparse(neighbor.format):
        raise ValueError(
            'Cannot convert a dense neighbor list to jraph format. '
            'Please use either NeighborListFormat.Sparse or '
            'NeighborListFormat.OrderedSparse.')

    receivers, senders = neighbor.idx
    N = len(neighbor.reference_position)

    _mask = neighbor_list_mask(neighbor)

    if mask is not None:
        _mask = _mask & mask
        cumsum = jnp.cumsum(_mask)
        index = jnp.where(_mask, cumsum - 1, len(receivers))
        ordered = N * jnp.ones((len(receivers) + 1,), i32)
        receivers = ordered.at[index].set(receivers)[:-1]
        senders = ordered.at[index].set(senders)[:-1]
        mask = receivers < N

    return jraph.GraphsTuple(
        nodes=None,
        edges=None,
        receivers=receivers,
        senders=senders,
        globals=None,
        n_node=jnp.array([N, 1]),
        n_edge=jnp.array([jnp.sum(_mask), jnp.sum(~_mask)]),
    )
def build_par_pack_and_unpack(model):
    """Build utility functions to pack and unpack parameter pytrees
    for the scipy optimizers."""
    value_flat, value_tree = tree_flatten(model.params)
    section_shapes = [item.shape for item in value_flat]
    section_sizes = jnp.cumsum(jnp.array([item.size for item in value_flat]))

    def par_from_array(arr):
        value_flat = jnp.split(arr, section_sizes)
        value_flat = [x.reshape(s) for x, s in zip(value_flat, section_shapes)]
        params = tree_unflatten(value_tree, value_flat)
        return params

    def array_from_par(params):
        value_flat, value_tree = tree_flatten(params)
        return jnp.concatenate([item.ravel() for item in value_flat])

    return par_from_array, array_from_par
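# Usage sketch with a minimal stand-in for `model` (hypothetical; only a
# `.params` pytree is required):
from types import SimpleNamespace
model = SimpleNamespace(params={'w': jnp.ones((2, 2)), 'b': jnp.zeros(2)})
par_from_array, array_from_par = build_par_pack_and_unpack(model)
flat = array_from_par(model.params)                 # shape (6,)
roundtrip = array_from_par(par_from_array(flat))    # == flat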
def sample_pdf(bins, weights, num_importance, perturbation, rng):
    """Hierarchical sampler.

    Sample `num_importance` rays from `bins` with distribution defined by
    `weights`.

    Args:
      bins: (num_rays, num_samples - 1) bins to sample from
      weights: (num_rays, num_samples - 2) weights assigned to each sampled
        color for the coarse model
      num_importance: the number of samples to draw from the distribution
      perturbation: whether to apply jitter on each ray or not
      rng: random key

    Returns:
      samples: (num_rays, num_importance) the sampled rays
    """
    # get pdf
    weights = jnp.clip(weights, 1e-5)  # prevent NaNs
    pdf = weights / jnp.sum(weights, axis=-1, keepdims=True)
    cdf = jnp.cumsum(pdf, axis=-1)
    cdf = jnp.concatenate([jnp.zeros_like(cdf[..., :1]), cdf], axis=-1)

    # take uniform samples
    samples_shape = [*cdf.shape[:-1], num_importance]
    if perturbation:
        uni_samples = random.uniform(rng, shape=samples_shape)
    else:
        uni_samples = jnp.linspace(0.0, 1.0, num_importance)
        uni_samples = jnp.broadcast_to(uni_samples, samples_shape)

    # invert CDF
    idx = jax.vmap(lambda x, y: jnp.searchsorted(x, y, side="right"))(
        cdf, uni_samples)

    below = jnp.maximum(0, idx - 1)
    above = jnp.minimum(cdf.shape[-1] - 1, idx)
    inds_g = jnp.stack([below, above], axis=-1)

    cdf_g = jnp.take_along_axis(cdf[..., None], inds_g, axis=1)
    bins_g = jnp.take_along_axis(bins[..., None], inds_g, axis=1)

    denom = cdf_g[..., 1] - cdf_g[..., 0]
    # denom = jnp.where(denom < 1e-5, jnp.ones_like(denom), denom)
    denom = lax.select(denom < 1e-5, jnp.ones_like(denom), denom)
    t = (uni_samples - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
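# Sketch: importance-sample 8 fine points per ray from coarse weights
# (toy shapes; assumes `jax`, `random`, and `lax` as used in the function above):
rng = random.PRNGKey(0)
bins = jnp.linspace(0., 1., 5)[None, :]    # (1, 4) bin positions
weights = jnp.array([[0.1, 0.7, 0.2]])     # (1, 3) coarse weights
samples = sample_pdf(bins, weights, num_importance=8, perturbation=True,
                     rng=rng)              # (1, 8), dense where weights are large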