def vectorize(self, vectorized):
    vec_args = [a for a in self.args if a in vectorized]
    excluded = [i for i, a in enumerate(self.args) if a not in vectorized]
    if not vec_args:
        return
    arg_sig = ",".join(f'({vectorized[a]})' for a in vec_args)
    sig = f"{arg_sig}->({self.core_shape})"
    self.fun = jnp.vectorize(self.fun, excluded=excluded, signature=sig)
    for wrt, d in self.derivatives.items():
        wrtsig = (vectorized[var] for var in reversed(wrt))
        if self.core_shape:
            outsig = ','.join((self.core_shape, *wrtsig))
        else:
            outsig = ','.join(wrtsig)
        dsig = f"{arg_sig}->({outsig})"
        vecd = jnp.vectorize(d, excluded=excluded, signature=dsig)
        self.derivatives[wrt] = vecd
    core_shape = self.core_shape
    out_core_ndim = len(core_shape.split(',')) if core_shape else 0
    self.core_ndim = {a: len(vectorized[a].split(',')) for a in vec_args}
    self.core_ndim[None] = out_core_ndim
    self.isvectorized = True
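The signature strings this method assembles can be hard to picture from the string manipulation alone. Below is a small standalone sketch, not taken from the class above: the `vectorized` dict contents, the core shape, and the `lambda` are made up for illustration, showing only the kind of gufunc signature produced and how `jnp.vectorize` consumes it.

import jax.numpy as jnp

core_shape = 'k'                      # hypothetical output core dims of the wrapped function
vectorized = {'x': 'n', 'w': 'n,k'}   # hypothetical per-argument core dims
arg_sig = ",".join(f'({s})' for s in vectorized.values())
sig = f"{arg_sig}->({core_shape})"    # "(n),(n,k)->(k)"

fun = jnp.vectorize(lambda x, w: x @ w, signature=sig)
print(fun(jnp.ones((5, 3)), jnp.ones((3, 2))).shape)  # (5, 2): 5 is a broadcast batch dim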
def test_exclude_errors(self):
  with self.assertRaisesRegex(
      TypeError, "jax.numpy.vectorize can only exclude"):
    jnp.vectorize(lambda x: x, excluded={'foo'})

  with self.assertRaisesRegex(
      ValueError, r"excluded=\{-1\} contains negative numbers"):
    jnp.vectorize(lambda x: x, excluded={-1})

  f = jnp.vectorize(lambda x: x, excluded={1})
  with self.assertRaisesRegex(
      ValueError, r"excluded=\{1\} is invalid for 1 argument\(s\)"):
    f(1.0)
def cosine_similarity(
    predictions: chex.Array,
    targets: chex.Array,
    epsilon: float = 0.,
) -> chex.Array:
  r"""Computes the cosine similarity between targets and predictions.

  The cosine **similarity** is a measure of similarity between vectors defined
  as the cosine of the angle between them, which is also the inner product of
  those vectors normalized to have unit norm.

  References:
    [Wikipedia, 2021](https://en.wikipedia.org/wiki/Cosine_similarity)

  Args:
    predictions: The predicted vector.
    targets: Ground truth target vector.
    epsilon: minimum norm for terms in the denominator of the cosine
      similarity.

  Returns:
    cosine similarity values.
  """
  chex.assert_equal_shape([targets, predictions])
  chex.assert_type([targets, predictions], float)
  # vectorize norm fn, to treat all dimensions except the last as batch dims.
  batched_norm_fn = jnp.vectorize(
      utils.safe_norm, signature='(k)->()', excluded={1})
  # normalise the last dimension of targets and predictions.
  unit_targets = targets / jnp.expand_dims(
      batched_norm_fn(targets, epsilon), axis=-1)
  unit_predictions = predictions / jnp.expand_dims(
      batched_norm_fn(predictions, epsilon), axis=-1)
  # return cosine similarity.
  return jnp.sum(unit_targets * unit_predictions, axis=-1)
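To illustrate the `signature='(k)->()'` plus `excluded={1}` combination above, here is a self-contained sketch. `safe_norm_sketch` is a hypothetical stand-in for `utils.safe_norm` (an L2 norm clipped from below), not the optax implementation; only the vectorization behavior is the point.

import jax.numpy as jnp

def safe_norm_sketch(x, min_norm):
  # Hypothetical stand-in: L2 norm of a single vector, clipped from below.
  return jnp.maximum(jnp.linalg.norm(x), min_norm)

batched_norm = jnp.vectorize(safe_norm_sketch, signature='(k)->()', excluded={1})
v = jnp.ones((4, 8, 3))              # all leading axes are treated as batch dims
print(batched_norm(v, 1e-6).shape)   # (4, 8): one norm per vector on the last axis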
def jax_funcify_Composite(op, vectorize=True, **kwargs):
    jax_impl = jax_funcify(op.fgraph)

    def composite(*args):
        return jax_impl(*args)[0]

    return jnp.vectorize(composite)
def logpdf(x, mean, cov, allow_singular=None):
  if allow_singular is not None:
    raise NotImplementedError(
        "allow_singular argument of multivariate_normal.logpdf")
  x, mean, cov = _promote_dtypes_inexact(x, mean, cov)
  if not mean.shape:
    return (-1 / 2 * jnp.square(x - mean) / cov
            - 1 / 2 * (np.log(2 * np.pi) + jnp.log(cov)))
  else:
    n = mean.shape[-1]
    if not np.shape(cov):
      y = x - mean
      return (-1 / 2 * jnp.einsum('...i,...i->...', y, y) / cov
              - n / 2 * (np.log(2 * np.pi) + jnp.log(cov)))
    else:
      if cov.ndim < 2 or cov.shape[-2:] != (n, n):
        raise ValueError("multivariate_normal.logpdf got incompatible shapes")
      L = lax.linalg.cholesky(cov)
      y = jnp.vectorize(
          partial(lax.linalg.triangular_solve, lower=True, transpose_a=True),
          signature="(n,n),(n)->(n)")(L, x - mean)
      return (-1 / 2 * jnp.einsum('...i,...i->...', y, y)
              - n / 2 * np.log(2 * np.pi)
              - jnp.log(L.diagonal(axis1=-1, axis2=-2)).sum(-1))
def forward(self, x: Array) -> Array:
  """Computes y = f(x)."""
  self._check_forward_input_shape(x)

  def unbatched(single_x, matrix, bias):
    return matrix @ single_x + bias

  batched = jnp.vectorize(unbatched, signature="(m),(m,m),(m)->(m)")
  return batched(x, self._matrix, self._bias)
def inverse(self, y: Array) -> Array:
  """Computes x = f^{-1}(y)."""
  self._check_inverse_input_shape(y)

  def unbatched(single_y, matrix, bias):
    return jnp.linalg.solve(matrix, single_y - bias)

  batched = jnp.vectorize(unbatched, signature="(m),(m,m),(m)->(m)")
  return batched(y, self._matrix, self._bias)
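The `"(m),(m,m),(m)->(m)"` signature used by both methods above broadcasts any leading batch axes across the three arguments. A minimal standalone sketch (the affine function and shapes are made up, not the distrax bijector): an unbatched matrix and bias are automatically broadcast against a batched input.

import jax.numpy as jnp

def affine(single_x, matrix, bias):
  return matrix @ single_x + bias

batched = jnp.vectorize(affine, signature="(m),(m,m),(m)->(m)")
x = jnp.ones((10, 3))    # batch of 10 vectors
A = 2.0 * jnp.eye(3)     # unbatched matrix: broadcast over the batch
b = jnp.zeros(3)
print(batched(x, A, b).shape)  # (10, 3)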
def covariance(self) -> Array:
  """Calculates the covariance."""
  probs = self.probs
  cov_matrix = -self._total_count[..., None, None] * (
      probs[..., None, :] * probs[..., :, None])
  chex.assert_shape(cov_matrix, probs.shape + self.event_shape)
  # Missing diagonal term in the covariance matrix.
  cov_matrix += jnp.vectorize(jnp.diag, signature='(k)->(k,k)')(
      self._total_count[..., None] * probs)
  return cov_matrix
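The `jnp.vectorize(jnp.diag, signature='(k)->(k,k)')` idiom recurs several times in this collection. A minimal standalone sketch of what it does to an arbitrary batch of vectors (shapes made up):

import jax.numpy as jnp

batched_diag = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')
p = jnp.ones((2, 5, 3))        # any number of leading batch dims
print(batched_diag(p).shape)   # (2, 5, 3, 3): one diagonal matrix per vector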
def entropy(self) -> Array:
  """Calculates the Shannon entropy (in nats)."""
  # The method `_entropy_scalar` does not work when `self.total_count` is an
  # array (instead of a scalar) or when we jit the function, so we default to
  # computing the entropy using an alternative method that uses a lax while
  # loop and does not create intermediate arrays whose shape depends on
  # `self.total_count`.
  entropy_fn = jnp.vectorize(
      self._entropy_scalar_with_lax, signature='(),(k),(k)->()')
  return entropy_fn(self.total_count, self.probs, self.log_of_probs)
def _vmap_2d(fn: Callable[[float, float, float], float],
             cov12: np.ndarray,
             var1: np.ndarray,
             var2: Optional[np.ndarray],
             diagonal_batch: bool,
             diagonal_spatial: bool) -> np.ndarray:
  """Effectively a "2D vmap" of `fn(cov12, var1, var2)`.

  Applicable for all possible kernel layouts.

  Args:
    fn: scalar-valued, elementwise `fn(cov12, var1, var2)` function to apply.
    cov12: covariance tensor (`q12`), `nngp`/`ntk`/`cov1`/`cov2`, of shape
      `(N1[, N2])`, `(N1[, N2], X, Y, ...)`, `(N1[, N2], X, X, Y, Y, ...)`
      depending on `diagonal_batch`, `diagonal_spatial`, and the number of
      spatial dimensions.
    var1: variance tensor (`q11`), has shape `(N1[, X, Y, ...])`.
    var2: variance tensor (`q22`), has shape `(N1[, X, Y, ...])`.
    diagonal_batch: `True` if `cov12` has only one batch dimension.
    diagonal_spatial: `True` if `cov12` has spatial dimensions appearing once
      (vs twice).

  Returns:
    Resulting array `[fn(cov12[i, j], var1[i], var2[j])]_{i j}`. Has the same
    shape as `cov12`.
  """
  batch_ndim = 1 if diagonal_batch else 2
  start = 2 - batch_ndim
  cov_end = batch_ndim if diagonal_spatial else cov12.ndim
  _cov12 = utils.make_2d(cov12, start, cov_end)

  var_end = 1 if diagonal_spatial else var1.ndim
  var1 = var1.reshape(var1.shape[:start] + (-1,) + var1.shape[var_end:])
  var2 = var1 if var2 is None else var2.reshape(
      var2.shape[:start] + (-1,) + var2.shape[var_end:])

  fn = vmap(
      vmap(
          np.vectorize(fn),
          in_axes=(start, None, start),
          out_axes=start),
      in_axes=(start, start, None),
      out_axes=start)
  out = fn(_cov12, var1, var2)  # type: np.ndarray

  out_shape = (cov12.shape[:start] +
               cov12.shape[start:cov_end:2] +
               cov12.shape[start + 1:cov_end:2] +
               cov12.shape[cov_end:])
  out = out.reshape(out_shape)
  out = utils.zip_axes(out, start, cov_end)
  return out
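A toy sketch of the vmap-of-vmap-of-`np.vectorize` composition above, with made-up shapes, a made-up `fn`, and none of the `utils` helpers (here `np` is `jax.numpy`, matching the module's aliasing): the inner `np.vectorize` makes the scalar `fn` elementwise, while the two `vmap`s pair every row of `cov12` with `var1` and every column with `var2`.

import jax.numpy as np
from jax import vmap

fn = lambda cov12, var1, var2: cov12 / np.sqrt(var1 * var2)
pairwise = vmap(vmap(np.vectorize(fn), in_axes=(0, None, 0), out_axes=0),
                in_axes=(0, 0, None), out_axes=0)
cov12 = np.ones((3, 4))
var1 = np.full((3,), 4.)
var2 = np.full((4,), 9.)
print(pairwise(cov12, var1, var2).shape)  # (3, 4); every entry is 1/6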
def test_bad_inputs(self):
  matmat = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,k)')
  with self.assertRaisesRegex(
      TypeError, "wrong number of positional arguments"):
    matmat(jnp.zeros((3, 2)))
  with self.assertRaisesRegex(
      ValueError,
      r"input with shape \(2,\) does not have enough dimensions"):
    matmat(jnp.zeros((2,)), jnp.zeros((2, 2)))
  with self.assertRaisesRegex(
      ValueError, r"inconsistent size for core dimension 'm'"):
    matmat(jnp.zeros((2, 3)), jnp.zeros((4, 5)))
def covariance(self) -> Array:
  """Calculates the covariance.

  Constructs a diagonal matrix with the variance vector as diagonal. Note that
  TFP would drop leading dimensions in the covariance if
  `self._scale_diag.ndims < self._loc.ndims`. To keep things simple and
  predictable, and for consistency with other distributions, in Distrax the
  `covariance` has shape `batch_shape + (num_dims, num_dims)`.

  Returns:
    Diagonal covariance matrix.
  """
  return jnp.vectorize(jnp.diag, signature='(k)->(k,k)')(self.variance())
def get_spin_spherical(transformer, shape, spins):
  """Returns set of spin-weighted spherical functions.

  Args:
    transformer: SpinSphericalFourierTransformer instance.
    shape: Desired shape (batch, latitude, longitude, spins, channels).
    spins: Desired spins.

  Returns:
    Array of spherical functions and array of their spectral coefficients.
  """
  # Make some arbitrary reproducible complex inputs.
  batch_size, resolution, _, num_spins, num_channels = shape
  if len(spins) != num_spins:
    raise ValueError('len(spins) must match desired shape.')
  ell_max = sphere_utils.ell_max_from_resolution(resolution)
  num_coefficients = np_spin_spherical_harmonics.n_coeffs_from_ell_max(ell_max)
  shape_coefficients = (batch_size, num_spins, num_channels, num_coefficients)
  # These numbers are chosen arbitrarily, but not randomly, since random
  # coefficients make for hard to visually interpret functions. Something
  # simpler like linspace(-1-1j, 1+1j) would have the same phase for all
  # complex numbers, which is also undesirable.
  coefficients = (jnp.linspace(-0.5, 0.7 + 0.5j, np.prod(shape_coefficients))
                  .reshape(shape_coefficients))

  # Broadcast `coefficients_to_matrix` over the leading dimensions.
  to_matrix = jnp.vectorize(spin_spherical_harmonics.coefficients_to_matrix,
                            signature='(i)->(j,k)')
  coefficients = to_matrix(coefficients)

  # Transpose back to (batch, ell, m, spin, channel) format.
  coefficients = jnp.transpose(coefficients, (0, 3, 4, 1, 2))

  # Coefficients for ell < |spin| are always zero.
  for i, spin in enumerate(spins):
    coefficients = coefficients.at[:, :abs(spin), :, i].set(0.0)

  # Convert to spatial domain.
  batched_backward_transform = jax.vmap(
      transformer.swsft_backward_spins_channels, in_axes=(0, None))
  sphere = batched_backward_transform(coefficients, spins)

  return sphere, coefficients
def _get_localized_kernel(self, ell_max, num_channels_in):
  # We interpolate along ell to obtain all weights from the learnable weights,
  # hence it doesn't make sense to have more parameters than num_ell.
  if self.num_filter_params > ell_max + 1:
    raise ValueError("num_filter_params must be <= ell_max + 1")
  ell_in = jnp.linspace(0, 1, self.num_filter_params)
  ell_out = jnp.linspace(0, 1, ell_max + 1)
  # `vectorize` is over leading dimensions, so we put ell as the last
  # dimension and transpose it to the first later.
  learnable_shape = (len(self.spins_in), len(self.spins_out), num_channels_in,
                     self.features, self.num_filter_params)
  learnable_weights = self.param("kernel", self.initializer, learnable_shape)
  # `jnp.interp` works on 1D inputs; we vectorize it to interpolate over a
  # single dimension of n-D inputs.
  vectorized_interp = jnp.vectorize(jnp.interp, signature="(m),(n),(n)->(m)")
  weights = vectorized_interp(ell_out, ell_in, learnable_weights)
  # Make ell the first dimension.
  return weights.transpose((4, 0, 1, 2, 3))
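A standalone sketch of the vectorized `jnp.interp` call above, with made-up shapes: the query and sample grids are shared (no batch dims), while the values array carries the batch dimensions over which the interpolation is repeated.

import jax.numpy as jnp

vectorized_interp = jnp.vectorize(jnp.interp, signature="(m),(n),(n)->(m)")
ell_out = jnp.linspace(0., 1., 17)     # (m,) query points, shared
ell_in = jnp.linspace(0., 1., 4)       # (n,) sample points, shared
weights = jnp.ones((2, 3, 4))          # batched values; last axis matches n
print(vectorized_interp(ell_out, ell_in, weights).shape)  # (2, 3, 17)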
def __init__(self, connections, input_current=2, V_rest=-0.07, V_reset=-0.07,
             V_th=-0.054, V_is=-0.08, V_es=0, tm=0.02, ts=0.01, P_max=1,
             rm_gs=0.5, Rm=0.010):
    self.V_es = V_es
    self.V_is = V_is
    self.V_th = V_th
    self.V_rest = V_rest
    self.V_reset = V_reset
    self.Rm = Rm
    self.input_current = input_current
    self.P_max = P_max
    self.rm_gs = rm_gs
    self.ts = ts
    self.tm = tm
    self.number_of_neurons = len(connections)
    self.connectivity_matrix = np.fromfunction(
        np.vectorize(lambda i, j: connections.get(i, {}).get(j, 0)),
        (self.number_of_neurons, self.number_of_neurons)).T
    self.neurons = None
    self.P_synapses = None
    self.dt = None
    self.Ie = None
def make_graph(data, save_name):
    prior = beta.pdf(x, a=data["prior"]["a"], b=data["prior"]["b"])
    n_0 = data["likelihood"]["n_0"]
    n_1 = data["likelihood"]["n_1"]
    samples = jnp.concatenate([jnp.zeros(n_0), jnp.ones(n_1)])
    likelihood_function = jnp.vectorize(
        lambda p: jnp.exp(bernoulli.logpmf(samples, p).sum()))
    likelihood = likelihood_function(x)
    posterior = beta.pdf(x, a=data["posterior"]["a"], b=data["posterior"]["b"])

    fig, ax = plt.subplots()
    axt = ax.twinx()
    fig1 = ax.plot(
        x,
        prior,
        "k",
        label=f"prior Beta({data['prior']['a']}, {data['prior']['b']})",
        linewidth=2.0,
    )
    fig2 = axt.plot(x, likelihood, "r:", label="likelihood Bernoulli",
                    linewidth=2.0)
    fig3 = ax.plot(
        x,
        posterior,
        "b-.",
        label=f"posterior Beta({data['posterior']['a']}, {data['posterior']['b']})",
        linewidth=2.0,
    )
    fig_list = fig1 + fig2 + fig3
    labels = [fig.get_label() for fig in fig_list]
    ax.legend(fig_list, labels, loc="upper left", shadow=True)
    axt.set_ylabel("Likelihood")
    ax.set_ylabel("Prior/Posterior")
    ax.set_title(f"$N_0$:{n_0}, $N_1$:{n_1}")
    pml.savefig(save_name)
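The `jnp.vectorize` call above wraps a scalar-to-scalar likelihood so it can be evaluated on a whole grid of parameter values without an explicit loop. A self-contained sketch (the sample counts and grid are made up; `x` in the original is a module-level grid):

import jax.numpy as jnp
from jax.scipy.stats import bernoulli

samples = jnp.concatenate([jnp.zeros(3), jnp.ones(2)])   # 3 failures, 2 successes
likelihood_fn = jnp.vectorize(
    lambda p: jnp.exp(bernoulli.logpmf(samples, p).sum()))
grid = jnp.linspace(0.01, 0.99, 99)
print(likelihood_fn(grid).shape)  # (99,): one likelihood value per grid point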
def test_wrong_output_type(self):
  f = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,k),()')
  with self.assertRaisesRegex(
      TypeError, "output must be a tuple"):
    f(jnp.zeros((2, 2)), jnp.zeros((2, 2)))
def test_matmat(self, left_shape, right_shape, result_shape):
  matmat = np.vectorize(np.dot, signature='(n,m),(m,k)->(n,k)')
  self.assertEqual(
      matmat(np.zeros(left_shape), np.zeros(right_shape)).shape,
      result_shape)
def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]:
  """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|."""
  fn = jnp.vectorize(_rational_quadratic_spline_inv,
                     signature='(),(n),(n),(n)->(),()')
  x, logdet = fn(y, self._x_pos, self._y_pos, self._knot_slopes)
  return x, logdet
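`_rational_quadratic_spline_inv` is private to the library, but the multi-output signature pattern it relies on (a `'... -> (),()'` right-hand side returning a tuple) can be sketched with a toy scalar function, made up here purely for illustration:

import jax.numpy as jnp

def toy_inverse_and_logdet(y, scale):
  # Toy stand-in: invert y = scale * x and report log|dx/dy|.
  return y / scale, -jnp.log(jnp.abs(scale))

fn = jnp.vectorize(toy_inverse_and_logdet, signature='(),()->(),()')
x, logdet = fn(jnp.arange(6.), 2.0)   # scalar `scale` broadcasts over the batch
print(x.shape, logdet.shape)          # (6,) (6,)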
def test_vecmat(self, left_shape, right_shape, result_shape):
  vecvec = jnp.vectorize(jnp.dot, signature='(m),(m)->()')
  self.assertEqual(vecvec(jnp.zeros(left_shape),
                          jnp.zeros(right_shape)).shape,
                   result_shape)
def test_inconsistent_output_size(self):
  f = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,n)')
  with self.assertRaisesRegex(
      ValueError, r"inconsistent size for core dimension 'n'"):
    f(jnp.zeros((2, 3)), jnp.zeros((3, 4)))
def inverse_and_log_det(self, inputs: Array) -> Tuple[Array, Array]:
  fn = jnp.vectorize(_rational_quadratic_spline_inv,
                     signature="(),(n),(n),(n)->(),()")
  outputs, log_det = fn(inputs, self.x_pos, self.y_pos, self.knot_slopes)
  return outputs, log_det
def test_wrong_num_outputs(self):
  f = jnp.vectorize(lambda *args: args, signature='(),()->(),(),()')
  with self.assertRaisesRegex(
      TypeError, "wrong number of output arguments"):
    f(1, 2)
def ElementwiseNumerical(
    fn: Callable[[float], float],
    deg: int,
    df: Optional[Callable[[float], float]] = None) -> InternalLayer:
  """Activation function using numerical integration.

  Supports general activation functions using Gauss-Hermite quadrature.

  Args:
    fn: activation function.
    deg: number of sample points and weights for quadrature. It must be >= 1.
      We observe for smooth activations deg=25 is a good place to start. For
      non-smooth activation functions (e.g. ReLU, Abs) quadrature is not
      recommended (for now use `nt.monte_carlo_kernel_fn`). Due to bivariate
      integration, compute time and memory scale as O(deg**2) for more
      precision. See eq (13) in
      https://mathworld.wolfram.com/Hermite-GaussQuadrature.html for error
      estimates in the case of 1d Gauss-Hermite quadrature.
    df: optional, derivative of the activation function (`fn`). If not
      provided, it is computed by `jax.grad`. Providing analytic derivative
      can speed up the NTK computations.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
  warnings.warn(
      f'Numerical Activation Layer with fn={fn}, deg={deg} used! '
      'Note that numerical error is controlled by `deg` and for a given '
      'tolerance level, required `deg` will be highly dependent on the '
      'choice of `fn`.')

  quad_points = osp.special.roots_hermite(deg)

  if df is None:
    warnings.warn(
        'Using JAX autodiff to compute the `fn` derivative for NTK. Beware of '
        'https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where.')
    df = np.vectorize(grad(fn))

  def kernel_fn(k: Kernel) -> Kernel:
    """Kernel transformation of activation function using quadrature."""
    cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

    d1 = get_diagonal(cov1, k.diagonal_batch, k.diagonal_spatial)
    d2 = get_diagonal(cov2, k.diagonal_batch, k.diagonal_spatial)

    end_axis = 1 if k.diagonal_spatial else cov1.ndim
    q11 = utils.interleave_ones(d1, 0, end_axis, True)
    q22 = utils.interleave_ones(d1 if d2 is None else d2, 0, end_axis, False)

    def nngp_ntk_fn(nngp, q11, q22, ntk=None):
      """Simple Gauss-Hermite quadrature routine."""
      xs, ws = quad_points
      grid = np.outer(ws, ws)
      x = xs.reshape((xs.shape[0],) + (1,) * (nngp.ndim + 1))
      y = xs.reshape((1, xs.shape[0]) + (1,) * nngp.ndim)
      xy_axes = (0, 1)

      nngp = np.expand_dims(nngp, xy_axes)
      q11, q22 = np.expand_dims(q11, xy_axes), np.expand_dims(q22, xy_axes)

      def integrate(f):
        fvals = f(_sqrt(2 * q11) * x) * f(
            nngp / _sqrt(q11 / 2, 1e-30) * x +
            _sqrt(2 * (q22 - nngp**2 / q11)) * y)
        return np.tensordot(grid, fvals, (xy_axes, xy_axes)) / np.pi

      if ntk is not None:
        ntk *= integrate(df)
      nngp = integrate(fn)
      return nngp, ntk

    def nngp_fn_diag(nngp):
      xs, ws = quad_points
      x = xs.reshape((xs.shape[0],) + (1,) * nngp.ndim)
      x_axes = (0,)
      nngp = np.expand_dims(nngp, x_axes)
      fval = fn(_sqrt(2 * nngp) * x) ** 2
      return np.tensordot(ws, fval, (x_axes, x_axes)) / np.sqrt(np.pi)

    nngp, ntk = nngp_ntk_fn(nngp, q11, q22, ntk)

    if k.diagonal_batch and k.diagonal_spatial:
      cov1 = nngp_fn_diag(cov1)
      if cov2 is not None:
        cov2 = nngp_fn_diag(cov2)
    else:
      start_axis = 1 if k.diagonal_batch else 0
      q11 = utils.interleave_ones(d1, start_axis, end_axis, True)
      q22 = utils.interleave_ones(d1, start_axis, end_axis, False)
      cov1, _ = nngp_ntk_fn(cov1, q11, q22)

      if cov2 is not None:
        q11 = utils.interleave_ones(d2, start_axis, end_axis, True)
        q22 = utils.interleave_ones(d2, start_axis, end_axis, False)
        cov2, _ = nngp_ntk_fn(cov2, q11, q22)

    return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, f'ElementwiseNumerical({fn},deg={deg})', kernel_fn)
def test_wrong_output_shape(self):
  f = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n)')
  with self.assertRaisesRegex(
      ValueError, r"output shape \(2, 2\) does not match"):
    f(jnp.zeros((2, 2)), jnp.zeros((2, 2)))
def Elementwise(
    fn: Optional[Callable[[float], float]] = None,
    nngp_fn: Optional[Callable[[float, float, float], float]] = None,
    d_nngp_fn: Optional[Callable[[float, float, float], float]] = None
) -> InternalLayer:
  """Elementwise application of `fn` using provided `nngp_fn`.

  Constructs a layer given only scalar-valued nonlinearity / activation `fn`
  and the 2D integral `nngp_fn`. NTK function is derived automatically in
  closed form from `nngp_fn`.

  If you cannot provide the `nngp_fn`, see `nt.stax.ElementwiseNumerical` to
  use numerical integration or `nt.monte_carlo.monte_carlo_kernel_fn` to use
  Monte Carlo sampling.

  If your function is implemented separately (e.g. `nt.stax.Relu` etc) it's
  best to use the custom implementation, since it uses symbolically simplified
  expressions that are more precise and numerically stable.

  Example:
    >>> fn = jax.scipy.special.erf  # type: Callable[[float], float]
    >>>
    >>> def nngp_fn(cov12: float, var1: float, var2: float) -> float:
    >>>   prod = (1 + 2 * var1) * (1 + 2 * var2)
    >>>   return np.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi
    >>>
    >>> # Use autodiff and vectorization to construct the layer:
    >>> _, _, kernel_fn_auto = stax.Elementwise(fn, nngp_fn)
    >>>
    >>> # Use custom pre-derived expressions
    >>> # (should be faster and more numerically stable):
    >>> _, _, kernel_fn_stax = stax.Erf()
    >>>
    >>> kernel_fn_auto(x1, x2) == kernel_fn_stax(x1, x2)  # usually `True`.

  Args:
    fn: a scalar-input/valued function `fn : R -> R`, the activation /
      nonlinearity. If `None`, invoking the finite width `apply_fn` will raise
      an exception.
    nngp_fn: a scalar-valued function
      `nngp_fn : (cov12, var1, var2) |-> E[fn(x_1) * fn(x_2)]`, where the
      expectation is over bivariate normal `x1, x2` with variances `var1`,
      `var2` and covariance `cov12`. Needed for both NNGP and NTK calculation.
      If `None`, invoking infinite width `kernel_fn` will raise an exception.
    d_nngp_fn: an optional scalar-valued function
      `d_nngp_fn : (cov12, var1, var2) |-> E[fn'(x_1) * fn'(x_2)]` with the
      same `x1, x2` distribution as in `nngp_fn`. If `None`, will be computed
      using automatic differentiation as `d_nngp_fn = d(nngp_fn)/d(cov12)`,
      which may lead to worse precision or numerical stability. `nngp_fn` and
      `d_nngp_fn` are used to derive the closed-form expression for the NTK.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.

  Raises:
    NotImplementedError: if a `fn`/`nngp_fn` is not provided, but `apply_fn`/
      `kernel_fn` is called respectively.
  """
  if fn is not None:
    name = fn.__name__
  elif nngp_fn is not None:
    name = nngp_fn.__name__
  else:
    raise ValueError('No finite (`fn`) or infinite (`nngp_fn`) functions '
                     'provided, the layer will not do anything.')

  if nngp_fn is None:
    kernel_fn = None

  else:
    if d_nngp_fn is None:
      warnings.warn(
          'Using JAX autodiff to compute the `fn` derivative for NTK. Beware '
          'of '
          'https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where.')
      d_nngp_fn = np.vectorize(grad(nngp_fn))

    def kernel_fn(k: Kernel) -> Kernel:
      cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

      var1 = get_diagonal(cov1, k.diagonal_batch, k.diagonal_spatial)
      var2 = get_diagonal(cov2, k.diagonal_batch, k.diagonal_spatial)

      if ntk is not None:
        ntk *= _vmap_2d(d_nngp_fn, nngp, var1, var2, False, k.diagonal_spatial)

      nngp = _vmap_2d(nngp_fn, nngp, var1, var2, False, k.diagonal_spatial)
      cov1 = _vmap_2d(
          nngp_fn, cov1, var1, None, k.diagonal_batch, k.diagonal_spatial)
      if cov2 is not None:
        cov2 = _vmap_2d(
            nngp_fn, cov2, var2, None, k.diagonal_batch, k.diagonal_spatial)
      return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

  return _elementwise(fn, name, kernel_fn)
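Both neural-tangents layers above fall back to `np.vectorize(grad(...))` when no analytic derivative is supplied. A minimal sketch of that pattern in isolation (here `np` is `jax.numpy`, matching the module's aliasing, and `tanh` is just a convenient example activation): `jax.grad` produces a scalar-to-scalar derivative, and vectorizing it lets it apply elementwise to arrays of any shape.

import jax.numpy as np
from jax import grad

fn = np.tanh
df = np.vectorize(grad(fn))
x = np.linspace(-1., 1., 7).reshape(7, 1)
print(np.allclose(df(x), 1 - np.tanh(x) ** 2))  # True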
def test_matvec(self, left_shape, right_shape, result_shape):
  matvec = jnp.vectorize(jnp.dot, signature='(n,m),(m)->(n)')
  self.assertEqual(matvec(jnp.zeros(left_shape),
                          jnp.zeros(right_shape)).shape,
                   result_shape)
def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]:
  """Computes y = f(x) and log|det J(f)(x)|."""
  fn = jnp.vectorize(_rational_quadratic_spline_fwd,
                     signature='(),(n),(n),(n)->(),()')
  y, logdet = fn(x, self._x_pos, self._y_pos, self._knot_slopes)
  return y, logdet
def test_mean(self, shape, result_shape):
  mean = jnp.vectorize(jnp.mean, signature='(n)->()')
  self.assertEqual(mean(jnp.zeros(shape)).shape, result_shape)
def _conv_general_dilated_j(eqn: JaxprEqn,
                            idx: int,
                            invals: List[ShapedArray],
                            cts_in: ShapedArray) -> np.ndarray:
  if idx != 1:
    raise NotImplementedError(eqn, idx)

  lhs = invals[1 if idx == 0 else 0]
  rhs = invals[idx]

  ndim = cts_in.ndim
  lhs_spec, rhs_spec, out_spec = eqn.params['dimension_numbers']
  precision = eqn.params['precision']

  n_groups_f = eqn.params['feature_group_count']
  n_groups_b = eqn.params['batch_group_count']

  n_channels_in = lhs.shape[lhs_spec[1]]
  n_batch_in = lhs.shape[lhs_spec[0]]

  group_size_out = rhs.shape[rhs_spec[0]] // (n_groups_f * n_groups_b)
  group_size_in = n_channels_in // n_groups_f
  batch_size_in = n_batch_in // n_groups_b

  if isinstance(precision, tuple):
    if precision[0] == precision[1]:
      precision = precision[0]
    else:
      raise NotImplementedError(precision)

  filter_shape = tuple(rhs.shape[i] for i in range(ndim) if i in rhs_spec[2:])

  j = lax.conv_general_dilated_patches(
      lhs=lhs,
      filter_shape=filter_shape,
      window_strides=eqn.params['window_strides'],
      padding=eqn.params['padding'],
      lhs_dilation=eqn.params['lhs_dilation'],
      rhs_dilation=eqn.params['rhs_dilation'],
      dimension_numbers=eqn.params['dimension_numbers'],
      precision=precision,
      preferred_element_type=eqn.params['preferred_element_type'])

  if n_groups_b > 1:
    j = np.moveaxis(j, (out_spec[0], out_spec[1]), (-1, -2))
    j = j.reshape(j.shape[:-2] +
                  (n_channels_in, *filter_shape, n_groups_b, batch_size_in))
    j = np.moveaxis(j, (-1, -2), (-2, -1))
  else:
    j = np.moveaxis(j, out_spec[1], -1)
    rhs_shape = (n_groups_f, group_size_in) + filter_shape
    j = j.reshape(j.shape[:ndim - 1] + rhs_shape)
    j = np.moveaxis(j, (ndim - 1, ndim), (-1, -2))

  j = np.vectorize(np.diag, signature='(k)->(k,k)')(j)

  if n_groups_b > 1:
    j = np.moveaxis(
        j,
        tuple(range(ndim - 2, j.ndim)),
        [ndim + rhs_spec[1]] +
        [ndim + i for i in sorted(rhs_spec[2:])] +
        [out_spec[0], out_spec[1], ndim + rhs_spec[0]])
  else:
    j = np.moveaxis(
        j,
        tuple(range(ndim - 1, j.ndim)),
        [ndim + i for i in sorted(rhs_spec[2:])] +
        [ndim + rhs_spec[1], out_spec[1], ndim + rhs_spec[0]])

  eye = np.eye(group_size_out, dtype=lhs.dtype)
  eye = np.expand_dims(eye, [
      i for i in range(j.ndim) if i not in (out_spec[1], ndim + rhs_spec[0])
  ])
  j = np.kron(j, eye)
  return j