def _Kxy(X1: GeodesicTuple, X2: GeodesicTuple):
    # Shift both coordinate fields into the Earth-centred frame, broadcast
    # all tuple fields to a common shape, then evaluate the kernel.
    X1 = X1._replace(x=X1.x - self.earth_centre,
                     ref_x=X1.ref_x - self.earth_centre)
    X2 = X2._replace(x=X2.x - self.earth_centre,
                     ref_x=X2.ref_x - self.earth_centre)
    X1 = GeodesicTuple(*jnp.broadcast_arrays(*X1))
    X2 = GeodesicTuple(*jnp.broadcast_arrays(*X2))
    return Kxy(X1, X2)

def __init__(
    self,
    *,
    S0: Optional[Array] = None,
    w0: Optional[Array] = None,
    Q: Optional[Array] = None,
    sigma: Optional[Array] = None,
    rho: Optional[Array] = None,
    tau: Optional[Array] = None,
) -> None:
    if w0 is None:
        if rho is None:
            raise ValueError("Either 'w0' or 'rho' is required")
        w0 = 2 * np.pi / rho
    if Q is None:
        if tau is None:
            raise ValueError("Either 'Q' or 'tau' is required")
        Q = 0.5 * w0 * tau
    if S0 is None:
        if sigma is None:
            raise ValueError("Either 'S0' or 'sigma' is required")
        S0 = jnp.square(sigma) / (w0 * Q)
    self.S0 = S0
    self.w0 = w0
    self.Q = Q
    a, b, c, d = SHOTerm.get_parameters(*jnp.broadcast_arrays(
        *map(jnp.atleast_1d, (self.S0, self.w0, self.Q))))
    super().__init__(a=a, b=b, c=c, d=d)

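# A minimal numeric sketch (plain numpy, hypothetical values) of the
# parameter conversions used above: sigma/rho/tau map onto S0/w0/Q via
# w0 = 2*pi/rho, Q = 0.5*w0*tau, S0 = sigma**2 / (w0 * Q), so the two
# parameterizations of an SHO term should describe the same kernel.
import numpy as np

sigma, rho, tau = 1.3, 4.0, 9.0
w0 = 2 * np.pi / rho
Q = 0.5 * w0 * tau
S0 = sigma**2 / (w0 * Q)
# e.g. SHOTerm(sigma=sigma, rho=rho, tau=tau) and SHOTerm(S0=S0, w0=w0, Q=Q)
# should then construct equivalent terms.
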
def __init__(self, *, sigma: Array, period: Array, Q0: Array, dQ: Array,
             f: Array):
    sigma, period, Q0, dQ, f = jnp.broadcast_arrays(
        *map(jnp.atleast_1d, (sigma, period, Q0, dQ, f)))
    amp = jnp.square(sigma) / (1 + f)

    # One term at the full period
    Q1 = 0.5 + Q0 + dQ
    w1 = 4 * np.pi * Q1 / (period * jnp.sqrt(4 * jnp.square(Q1) - 1))
    S1 = amp / (w1 * Q1)

    # Another term at half the period
    Q2 = 0.5 + Q0
    w2 = 8 * np.pi * Q2 / (period * jnp.sqrt(4 * jnp.square(Q2) - 1))
    S2 = f * amp / (w2 * Q2)

    a1, b1, c1, d1 = SHOTerm.get_parameters(S1, w1, Q1)
    a2, b2, c2, d2 = SHOTerm.get_parameters(S2, w2, Q2)
    super().__init__(
        a=jnp.concatenate((a1, a2)),
        b=jnp.concatenate((b1, b2)),
        c=jnp.concatenate((c1, c2)),
        d=jnp.concatenate((d1, d2)),
    )

def _ndim_coords_from_arrays(points, ndim):
    """Convert a tuple of coordinate arrays to a (..., ndim)-shaped array."""
    if isinstance(points, tuple) and len(points) == 1:
        # handle argument tuple
        points = points[0]
    if isinstance(points, tuple):
        p = jnp.broadcast_arrays(*points)
        # JAX arrays are immutable, so build the result by stacking the
        # broadcast coordinate arrays along a trailing axis rather than
        # assigning into an empty array (which would fail under JAX).
        points = jnp.stack(p, axis=-1).astype(float)
    else:
        points = jnp.asarray(points)
        if points.ndim == 1:
            if ndim is None:
                points = points.reshape(-1, 1)
            else:
                points = points.reshape(-1, ndim)
    return points

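# Usage sketch for _ndim_coords_from_arrays: a tuple of broadcastable
# coordinate arrays is broadcast and stacked along a trailing axis.
import jax.numpy as jnp

xs = jnp.linspace(0.0, 1.0, 3)[:, None]  # shape (3, 1)
ys = jnp.linspace(0.0, 1.0, 4)[None, :]  # shape (1, 4)
coords = _ndim_coords_from_arrays((xs, ys), ndim=2)
assert coords.shape == (3, 4, 2)  # an (x, y) pair for every grid point
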
def compute_integration_limits(self, x, k, bottom, width):
    """
    Compute the integration limits of the curved-layer ionosphere.

    Args:
        x: [3] or [N, 3]
        k: [3] or [N, 3]
        bottom: scalar height of the bottom of the layer
        width: scalar width of the layer

    Returns:
        s_min, s_max with shapes:
            - scalars if x and k are [3]
            - arrays of [N] if x or k is [N, 3]
    """
    if (len(x.shape) == 2) or (len(k.shape) == 2):
        x, k = jnp.broadcast_arrays(x, k)
        return vmap(lambda x, k: self.compute_integration_limits(
            x, k, bottom, width))(x, k)
    x0_hat = self.x0 / jnp.linalg.norm(self.x0)
    bottom_radius2 = jnp.sum(jnp.square(self.x0 + bottom * x0_hat))
    top_radius2 = jnp.sum(jnp.square(self.x0 + (bottom + width) * x0_hat))
    xk = x @ k
    x2 = x @ x
    smin = -xk + jnp.sqrt(xk**2 + (bottom_radius2 - x2))
    smax = -xk + jnp.sqrt(xk**2 + (top_radius2 - x2))
    return smin, smax

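# Standalone sketch of the curved-layer geometry above, with hypothetical
# numbers: for a unit ray x + s*k, hitting the sphere of squared radius R2
# (centred at the origin) means solving s**2 + 2*s*(x.k) + (x.x - R2) = 0,
# whose positive root is -x.k + sqrt((x.k)**2 + (R2 - x.x)).
import jax.numpy as jnp

x = jnp.array([0.0, 0.0, 10.0])  # ray origin on a sphere of radius 10
k = jnp.array([0.0, 0.0, 1.0])   # radial ray direction
R2 = 12.0 ** 2                   # squared radius of the shell boundary
xk = x @ k
x2 = x @ x
s = -xk + jnp.sqrt(xk ** 2 + (R2 - x2))
assert jnp.allclose(s, 2.0)      # the boundary is 2 units along the ray
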
def elemwise(*args):
    if len(args) > op.nfunc_spec[1]:
        return jax_variadic_func(
            jnp.stack(jnp.broadcast_arrays(*args), axis=0), axis=0)
    else:
        return jnp_func(*args)

def slicey_slice(dist: Distribution) -> Distribution:
    # For annoying reasons, it's not possible to have Distribution subtype
    # NamedTuple even though all distributions are also NamedTuples. But we
    # need the input here to be Iterable. There's not a clear way to specify
    # that the input must be a Distribution *and* a NamedTuple AFAICT.
    params_broadcasted = jp.broadcast_arrays(*cast(Iterable, dist))
    return dist.__class__(*[arr[batch_slice] for arr in params_broadcasted])

def _lax_max_taylor_rule(primal_in, series_in):
    x, y = jnp.broadcast_arrays(*primal_in)

    xgy = x > y   # greater-than mask
    xey = x == y  # equal-to mask

    primal_out = lax.select(xgy, x, y)

    def select_max_and_avg_eq(x_i, y_i):
        """Select x where x > y, or average when x == y."""
        max_i = lax.select(xgy, x_i, y_i)
        max_i = lax.select(xey, (x_i + y_i) / 2, max_i)
        return max_i

    series_out = [
        select_max_and_avg_eq(*jnp.broadcast_arrays(*terms_in))
        for terms_in in zip(*series_in)
    ]
    return primal_out, series_out

def kepler(mean_anom, ecc):
    # We're going to apply array broadcasting here, since the logic of our op
    # is much simpler if we require the inputs to all have the same shapes.
    mean_anom_, ecc_ = jnp.broadcast_arrays(mean_anom, ecc)

    # Then we need to wrap into the range [0, 2*pi)
    M_mod = jnp.mod(mean_anom_, 2 * np.pi)

    return _kepler_prim.bind(M_mod, ecc_)

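# Minimal sketch of the broadcasting/wrapping step in kepler, independent of
# the custom primitive: a scalar eccentricity is broadcast against an array
# of mean anomalies, which are then wrapped into [0, 2*pi).
import numpy as np
import jax.numpy as jnp

mean_anom = jnp.array([-0.5, 3.0, 7.0])
ecc = 0.1
mean_anom_, ecc_ = jnp.broadcast_arrays(mean_anom, jnp.asarray(ecc))
M_mod = jnp.mod(mean_anom_, 2 * np.pi)
assert ecc_.shape == M_mod.shape == (3,)
assert bool(jnp.all((0 <= M_mod) & (M_mod < 2 * np.pi)))
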
def lse(a, axis=None, b=None, keepdims=False, return_sign=False):
    if b is not None:
        a, b = jnp.broadcast_arrays(a, b)
        a = a + jnp.where(b, jnp.log(jnp.abs(b)), -jnp.inf)
        b = jnp.sign(b)
    return jax.scipy.special.logsumexp(
        a, axis=axis, b=b, keepdims=keepdims, return_sign=return_sign)

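# Usage sketch for lse with signed weights: the magnitudes of b are folded
# into a (zero weights become -inf and drop out) and the signs are handled
# by logsumexp via return_sign=True.
import jax.numpy as jnp

a = jnp.array([0.0, 1.0, 2.0])
b = jnp.array([1.0, -2.0, 0.0])
val, sign = lse(a, b=b, return_sign=True)
expected = jnp.log(jnp.abs(jnp.sum(b * jnp.exp(a))))  # dense reference
assert jnp.allclose(val, expected) and sign == -1.0
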
def gravitational_potential_pairwise_(t, x, x_t):
    n, *rest = x.shape
    x1 = x[None]
    x2 = x[:, None]
    # Use jnp (not np) so this stays traceable under jit/grad.
    x1, x2 = jnp.broadcast_arrays(x1, x2)
    # Drop the self-interaction pairs: flatten the n*n grid, strip the last
    # element, and reshape to (n - 1, n + 1) so that every diagonal entry
    # lands in column 0, which is then sliced off.
    x1 = x1.reshape(n * n, *rest)[:-1].reshape(n - 1, n + 1, *rest)[:, 1:]
    x2 = x2.reshape(n * n, *rest)[:-1].reshape(n - 1, n + 1, *rest)[:, 1:]
    r = vmap(vmap(_dist_one_one, (0, 0), 1), (0, 0), 0)(x1, x2)
    mm = m[None] * m[:, None]
    mm = mm.reshape(-1)[:-1].reshape(n - 1, n + 1)[:, 1:]
    p = mm / r
    return -.5 * G * p.sum()

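# Standalone sketch of the diagonal-dropping reshape used above: flattening
# an (n, n) pair grid, removing the last element, and reshaping to
# (n - 1, n + 1) pushes every diagonal entry into column 0, which the
# [:, 1:] slice then removes.
import jax.numpy as jnp

n = 4
idx = jnp.arange(n * n).reshape(n, n)
off_diag = idx.reshape(-1)[:-1].reshape(n - 1, n + 1)[:, 1:]
assert off_diag.shape == (n - 1, n)  # exactly the n*(n-1) off-diagonal pairs
assert not bool(jnp.isin(jnp.diagonal(idx), off_diag).any())
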
def cartesian_product(*arrays) -> jnp.ndarray:
    """Calculate the n-dimensional Cartesian product, i.e. create all
    possible combinations of all elements in a given collection of arrays.

    Args:
        *arrays: the arrays to calculate the Cartesian product for

    Returns:
        the Cartesian product.
    """
    ixarrays = jnp.ix_(*arrays)
    barrays = jnp.broadcast_arrays(*ixarrays)
    sarrays = jnp.stack(barrays, -1)
    product = sarrays.reshape(-1, sarrays.shape[-1])
    return product

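# Usage sketch for cartesian_product: one row per combination, in row-major
# order over the inputs.
import jax.numpy as jnp

pts = cartesian_product(jnp.array([0, 1]), jnp.array([10, 20, 30]))
assert pts.shape == (6, 2)
# pts == [[0, 10], [0, 20], [0, 30], [1, 10], [1, 20], [1, 30]]
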
def test_batched_mask(self, mask_batch_shape, input_batch_shape):
    def create_bijector(mask):
        return masked_coupling.MaskedCoupling(
            mask=mask,
            conditioner=lambda x: x**2,
            bijector=lambda _: lambda x: 2. * x + 3.,
            event_ndims=2)

    k1, k2 = jax.random.split(jax.random.PRNGKey(42))
    mask = jax.random.choice(k1, jnp.array([True, False]),
                             mask_batch_shape + (5, 6))
    bijector = create_bijector(mask)

    x = jax.random.uniform(k2, input_batch_shape + (5, 6))
    y, logdet_fwd = self.variant(bijector.forward_and_log_det)(x)
    z, logdet_inv = self.variant(bijector.inverse_and_log_det)(x)

    output_batch_shape = jnp.broadcast_arrays(mask[..., 0, 0],
                                              x[..., 0, 0])[0].shape

    self.assertEqual(y.shape, output_batch_shape + (5, 6))
    self.assertEqual(z.shape, output_batch_shape + (5, 6))
    self.assertEqual(logdet_fwd.shape, output_batch_shape)
    self.assertEqual(logdet_inv.shape, output_batch_shape)

    mask = jnp.broadcast_to(mask,
                            output_batch_shape + (5, 6)).reshape((-1, 5, 6))
    x = jnp.broadcast_to(x, output_batch_shape + (5, 6)).reshape((-1, 5, 6))
    y = y.reshape((-1, 5, 6))
    z = z.reshape((-1, 5, 6))
    logdet_fwd = logdet_fwd.flatten()
    logdet_inv = logdet_inv.flatten()
    for i in range(np.prod(output_batch_shape)):
        bijector = create_bijector(mask[i])
        this_y, this_logdet_fwd = self.variant(
            bijector.forward_and_log_det)(x[i])
        this_z, this_logdet_inv = self.variant(
            bijector.inverse_and_log_det)(x[i])
        np.testing.assert_allclose(this_y, y[i], atol=1e-7)
        np.testing.assert_allclose(this_z, z[i], atol=1e-7)
        np.testing.assert_allclose(this_logdet_fwd, logdet_fwd[i], atol=1e-5)
        np.testing.assert_allclose(this_logdet_inv, logdet_inv[i], atol=1e-5)

def test_batched_parameters(self, matrix_batch_shape, bias_batch_shape,
                            input_batch_shape):
    prng = hk.PRNGSequence(jax.random.PRNGKey(42))
    matrix = jax.random.uniform(next(prng),
                                matrix_batch_shape + (4, 4)) + jnp.eye(4)
    bias = jax.random.normal(next(prng), bias_batch_shape + (4,))
    bijector = LowerUpperTriangularAffine(matrix, bias)

    x = jax.random.normal(next(prng), input_batch_shape + (4,))
    y, logdet_fwd = self.variant(bijector.forward_and_log_det)(x)
    z, logdet_inv = self.variant(bijector.inverse_and_log_det)(x)

    output_batch_shape = jnp.broadcast_arrays(matrix[..., 0, 0],
                                              bias[..., 0],
                                              x[..., 0])[0].shape

    self.assertEqual(y.shape, output_batch_shape + (4,))
    self.assertEqual(z.shape, output_batch_shape + (4,))
    self.assertEqual(logdet_fwd.shape, output_batch_shape)
    self.assertEqual(logdet_inv.shape, output_batch_shape)

    matrix = jnp.broadcast_to(matrix,
                              output_batch_shape + (4, 4)).reshape((-1, 4, 4))
    bias = jnp.broadcast_to(bias, output_batch_shape + (4,)).reshape((-1, 4))
    x = jnp.broadcast_to(x, output_batch_shape + (4,)).reshape((-1, 4))
    y = y.reshape((-1, 4))
    z = z.reshape((-1, 4))
    logdet_fwd = logdet_fwd.flatten()
    logdet_inv = logdet_inv.flatten()
    for i in range(np.prod(output_batch_shape)):
        bijector = LowerUpperTriangularAffine(matrix[i], bias[i])
        this_y, this_logdet_fwd = self.variant(
            bijector.forward_and_log_det)(x[i])
        this_z, this_logdet_inv = self.variant(
            bijector.inverse_and_log_det)(x[i])
        np.testing.assert_allclose(this_y, y[i], atol=9e-3)
        np.testing.assert_allclose(this_z, z[i], atol=7e-6)
        np.testing.assert_allclose(this_logdet_fwd, logdet_fwd[i], atol=1e-7)
        np.testing.assert_allclose(this_logdet_inv, logdet_inv[i], atol=7e-6)

def test_batched_parameters(self, params_batch_shape, input_batch_shape):
    k1, k2 = jax.random.split(jax.random.PRNGKey(42), 2)
    num_bins = 4
    param_dim = 3 * num_bins + 1
    params = jax.random.normal(k1, params_batch_shape + (param_dim,))
    bijector = rational_quadratic_spline.RationalQuadraticSpline(
        params, range_min=0., range_max=1.)

    x = jax.random.uniform(k2, input_batch_shape)
    y, logdet_fwd = self.variant(bijector.forward_and_log_det)(x)
    z, logdet_inv = self.variant(bijector.inverse_and_log_det)(x)

    output_batch_shape = jnp.broadcast_arrays(params[..., 0], x)[0].shape

    self.assertEqual(y.shape, output_batch_shape)
    self.assertEqual(z.shape, output_batch_shape)
    self.assertEqual(logdet_fwd.shape, output_batch_shape)
    self.assertEqual(logdet_inv.shape, output_batch_shape)

    params = jnp.broadcast_to(
        params, output_batch_shape + (param_dim,)).reshape((-1, param_dim))
    x = jnp.broadcast_to(x, output_batch_shape).flatten()
    y = y.flatten()
    z = z.flatten()
    logdet_fwd = logdet_fwd.flatten()
    logdet_inv = logdet_inv.flatten()
    for i in range(np.prod(output_batch_shape)):
        bijector = rational_quadratic_spline.RationalQuadraticSpline(
            params[i], range_min=0., range_max=1.)
        this_y, this_logdet_fwd = self.variant(
            bijector.forward_and_log_det)(x[i])
        this_z, this_logdet_inv = self.variant(
            bijector.inverse_and_log_det)(x[i])
        np.testing.assert_allclose(this_y, y[i], atol=1e-7)
        np.testing.assert_allclose(this_z, z[i], atol=1e-6)
        np.testing.assert_allclose(this_logdet_fwd, logdet_fwd[i], atol=1e-5)
        np.testing.assert_allclose(this_logdet_inv, logdet_inv[i], atol=1e-5)

def test_batched_parameters(self, scale_batch_shape, shift_batch_shape,
                            input_batch_shape):
    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(42), 3)
    log_scale = jax.random.normal(k1, scale_batch_shape)
    shift = jax.random.normal(k2, shift_batch_shape)
    bijector = scalar_affine.ScalarAffine(shift, log_scale=log_scale)

    x = jax.random.normal(k3, input_batch_shape)
    y, logdet_fwd = self.variant(bijector.forward_and_log_det)(x)
    z, logdet_inv = self.variant(bijector.inverse_and_log_det)(x)

    output_batch_shape = jnp.broadcast_arrays(log_scale, shift, x)[0].shape

    self.assertEqual(y.shape, output_batch_shape)
    self.assertEqual(z.shape, output_batch_shape)
    self.assertEqual(logdet_fwd.shape, output_batch_shape)
    self.assertEqual(logdet_inv.shape, output_batch_shape)

    log_scale = jnp.broadcast_to(log_scale, output_batch_shape).flatten()
    shift = jnp.broadcast_to(shift, output_batch_shape).flatten()
    x = jnp.broadcast_to(x, output_batch_shape).flatten()
    y = y.flatten()
    z = z.flatten()
    logdet_fwd = logdet_fwd.flatten()
    logdet_inv = logdet_inv.flatten()
    for i in range(np.prod(output_batch_shape)):
        bijector = scalar_affine.ScalarAffine(shift[i],
                                              jnp.exp(log_scale[i]))
        this_y, this_logdet_fwd = self.variant(
            bijector.forward_and_log_det)(x[i])
        this_z, this_logdet_inv = self.variant(
            bijector.inverse_and_log_det)(x[i])
        np.testing.assert_allclose(this_y, y[i], atol=1e-7)
        np.testing.assert_allclose(this_z, z[i], atol=1e-5)
        np.testing.assert_allclose(this_logdet_fwd, logdet_fwd[i], atol=1e-4)
        np.testing.assert_allclose(this_logdet_inv, logdet_inv[i], atol=1e-4)

def simple_broadcast(self, *args):
    """
    Broadcast a sequence of 1-dimensional arrays.

    Example:
        >>> import pyhf
        >>> pyhf.set_backend("jax")
        >>> pyhf.tensorlib.simple_broadcast(
        ...     pyhf.tensorlib.astensor([1]),
        ...     pyhf.tensorlib.astensor([2, 3, 4]),
        ...     pyhf.tensorlib.astensor([5, 6, 7]))
        [DeviceArray([1., 1., 1.], dtype=float64), DeviceArray([2., 3., 4.], dtype=float64), DeviceArray([5., 6., 7.], dtype=float64)]

    Args:
        args (Array of Tensors): Sequence of arrays

    Returns:
        list of Tensors: The sequence broadcast together.
    """
    return jnp.broadcast_arrays(*args)

def compute_integration_limits_flat(self, x, k, bottom, width):
    """
    Compute the integration limits of the flat-layer ionosphere.

    Args:
        x: [3] or [N, 3]
        k: [3] or [N, 3]
        bottom: scalar height of the bottom of the layer
        width: scalar width of the layer

    Returns:
        s_min, s_max with shapes:
            - scalars if x and k are [3]
            - arrays of [N] if x or k is [N, 3]
    """
    if (len(x.shape) == 2) or (len(k.shape) == 2):
        x, k = jnp.broadcast_arrays(x, k)
        return vmap(lambda x, k: self.compute_integration_limits_flat(
            x, k, bottom, width))(x, k)
    smin = (bottom - (x[2] - self.x0[2])) / k[2]
    smax = (bottom + width - (x[2] - self.x0[2])) / k[2]
    return smin, smax

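# Standalone sketch of the flat-layer geometry above, with a hypothetical
# reference point x0: along the ray x + s*k, the height condition
# x[2] + s*k[2] = x0[2] + h gives s = (h - (x[2] - x0[2])) / k[2].
import jax.numpy as jnp

x0 = jnp.array([0.0, 0.0, 0.0])  # reference point (assumed)
x = jnp.array([0.0, 0.0, 1.0])   # ray origin, 1 unit above x0
k = jnp.array([0.0, 0.0, 1.0])   # vertical ray direction
bottom, width = 100.0, 50.0
smin = (bottom - (x[2] - x0[2])) / k[2]
smax = (bottom + width - (x[2] - x0[2])) / k[2]
assert smin == 99.0 and smax == 149.0
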
def normal_sample(*, rng, mean, logvar):
    mean, logvar = jnp.broadcast_arrays(mean, logvar)
    return mean + jnp.exp(0.5 * logvar) * jax.random.normal(
        rng, shape=logvar.shape)

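# Usage sketch for normal_sample: a scalar mean broadcasts against a batched
# log-variance, and the reparameterised sample takes the broadcast shape.
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
sample = normal_sample(rng=rng, mean=0.0, logvar=jnp.zeros((3, 2)))
assert sample.shape == (3, 2)
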
def __init__(self, *, a: Array, b: Array, c: Array, d: Array):
    self.a, self.b, self.c, self.d = jnp.broadcast_arrays(
        *map(jnp.atleast_1d, (a, b, c, d)))
    self.is_real = self.b < 0
    self.abs_b = jnp.sqrt(jnp.abs(self.b))

def broadcast_arrays(*args):
    args = [(a.value if isinstance(a, JaxArray) else a) for a in args]
    # Unpack the list: jnp.broadcast_arrays expects the arrays as separate
    # positional arguments, not a single list.
    return jnp.broadcast_arrays(*args)

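# Usage sketch (with the unpacking fix above): plain jnp arrays pass through
# unchanged, while JaxArray instances, the library's wrapper type, are first
# unwrapped to their .value before broadcasting.
import jax.numpy as jnp

a, b = broadcast_arrays(jnp.zeros((3, 1)), jnp.ones((1, 4)))
assert a.shape == b.shape == (3, 4)
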
def _mean_func(X1: GeodesicTuple):
    # Same Earth-centred shift and broadcast as in _Kxy, for the mean.
    X1 = X1._replace(x=X1.x - self.earth_centre,
                     ref_x=X1.ref_x - self.earth_centre)
    X1 = GeodesicTuple(*jnp.broadcast_arrays(*X1))
    return mean_func(X1)