def loop_body(inputs):
    """Run one batch of simulations and store it in the preallocated buffers.

    Loop body over the carry ``(rng, parameters, summaries, distances,
    n_accepted, iteration)``. It draws `n_simulations` parameters from
    `self.prior`, simulates and compresses them, and measures each compressed
    summary against every target summary. The batch is then written into
    rows ``[n_simulations * iteration, n_simulations * (iteration + 1))`` of
    the buffers.

    Args:
        inputs: the carry tuple described above.

    Returns:
        The updated carry with the same structure. `iteration` is incremented
        by one. `n_accepted` holds, per target row of `distances`, the number
        of entries below `ϵ`.
    """
    rng, parameters, summaries, distances, n_accepted, iteration = inputs

    rng, key = jax.random.split(rng)
    parameter_samples = self.prior.sample(n_simulations, seed=key)
    rng, key = jax.random.split(rng)
    summary_samples = self.compressor(self.simulator(key, parameter_samples))
    # One row of distances per (target summary, F) pair.
    distance_samples = jax.vmap(
        lambda target, F: self.distance_measure(summary_samples, target, F))(
            self.target_summaries, self.F)

    # Slice out this iteration's slot of the flat buffer index range.
    indices = jax.lax.dynamic_slice(
        np.arange(n_simulations * max_iterations),
        [n_simulations * iteration],
        [n_simulations])
    # `jax.ops.index_update` / `jax.ops.index` were removed from JAX;
    # `.at[...].set(...)` is the supported functional update.
    parameters = parameters.at[indices].set(parameter_samples)
    summaries = summaries.at[indices].set(summary_samples)
    distances = distances.at[:, indices].set(distance_samples)
    n_accepted = np.int32(np.less(distances, ϵ).sum(1))
    return (rng, parameters, summaries, distances, n_accepted,
            iteration + np.int32(1))
def testEnumPromotion(self): class AnEnum(enum.IntEnum): A = 42 B = 101 np.testing.assert_equal(np.array(42), np.array(AnEnum.A)) np.testing.assert_equal(jnp.array(42), jnp.array(AnEnum.A)) np.testing.assert_equal(np.int32(101), np.int32(AnEnum.B)) np.testing.assert_equal(jnp.int32(101), jnp.int32(AnEnum.B))
def init(θ, data):
    """Build the initial LMBTrainingState for parameters θ.

    Only ``data[1]`` (the targets) is read from `data`. The inputs passed to
    `errors` come from the enclosing scope as `x`.
    # NOTE(review): confirm `x` is meant to be a closure variable rather than data[0].

    The hyperparameter slot is filled with ``(α, β)``. The iteration counter
    starts at 0 and the convergence counter at 1.
    """
    targets = data[1]
    residuals = errors(θ, x, targets)
    iteration = np.int32(0)
    converge_count = np.int32(1)
    cost = obj.cost(residuals)
    reg = obj.regularizer(θ)
    objective = np.float32(β * cost + α * reg)
    return LMBTrainingState(
        θ, residuals, objective, cost, reg, (α, β), μi, τ,
        iteration, converge_count)
def testEnumPromotion(self):
    """IntEnum members must convert like plain ints in both onp and np."""

    class AnEnum(enum.IntEnum):
        A = 42
        B = 101

    expect_equal = onp.testing.assert_equal
    expect_equal(onp.array(42), onp.array(AnEnum.A))
    with core.skipping_checks():
        # Passing AnEnum.A to np.array fails the type check in bind
        expect_equal(np.array(42), np.array(AnEnum.A))
    expect_equal(onp.int32(101), onp.int32(AnEnum.B))
    expect_equal(np.int32(101), np.int32(AnEnum.B))
def rescale_jax(x: ep.JAXTensor, target_shape: List[int]) -> ep.JAXTensor:
    """Resize `x` spatially with bilinear interpolation.

    Let ``nrows, ncols = x.raw.shape[-3:-1]``. The output has
    ``round(target_shape[1] / x.shape[1] * nrows)`` rows and
    ``round(target_shape[2] / x.shape[2] * ncols)`` columns. For a 4-D
    channel-last input this is ``target_shape[1:3]``. Sample points sit at
    pixel centres, and reads outside the image are clamped to the border.

    Args:
        x: tensor whose `.raw` array is channel-last (presumably NHWC, since
            `target_shape` is indexed at 1 and 2 — TODO confirm).
        target_shape: desired shape; only entries 1 and 2 are used.

    Returns:
        ``ep.JAXTensor`` with resized spatial dims. Leading and channel dims
        are unchanged.
    """
    # img must be in channel_last format
    # modified according to https://github.com/google/jax/issues/862
    import jax.numpy as np

    img = x.raw
    resize_rates = (target_shape[1] / x.shape[1], target_shape[2] / x.shape[2])

    def interpolate_bilinear(  # type: ignore
        im: np.ndarray, rows: np.ndarray, cols: np.ndarray
    ) -> np.ndarray:
        # based on http://stackoverflow.com/a/12729229
        col_lo = np.floor(cols).astype(int)
        col_hi = col_lo + 1
        row_lo = np.floor(rows).astype(int)
        row_hi = row_lo + 1

        nrows, ncols = im.shape[-3:-1]

        def cclip(cols: np.ndarray) -> np.ndarray:  # type: ignore
            return np.clip(cols, 0, ncols - 1)

        def rclip(rows: np.ndarray) -> np.ndarray:  # type: ignore
            return np.clip(rows, 0, nrows - 1)

        # The four neighbouring pixels of every sample point, clamped to the image.
        Ia = im[..., rclip(row_lo), cclip(col_lo), :]
        Ib = im[..., rclip(row_hi), cclip(col_lo), :]
        Ic = im[..., rclip(row_lo), cclip(col_hi), :]
        Id = im[..., rclip(row_hi), cclip(col_hi), :]

        wa = np.expand_dims((col_hi - cols) * (row_hi - rows), -1)
        wb = np.expand_dims((col_hi - cols) * (rows - row_lo), -1)
        wc = np.expand_dims((cols - col_lo) * (row_hi - rows), -1)
        wd = np.expand_dims((cols - col_lo) * (rows - row_lo), -1)

        return wa * Ia + wb * Ib + wc * Ic + wd * Id

    nrows, ncols = img.shape[-3:-1]
    deltas = (0.5 / resize_rates[0], 0.5 / resize_rates[1])
    # Round instead of truncating. (target / size) * size can come out just
    # below the intended integer in floating point (1 / 49 * 49 ==
    # 0.9999999999999999), and truncation would then drop a row/column.
    out_rows = int(round(resize_rates[0] * nrows))
    out_cols = int(round(resize_rates[1] * ncols))
    rows = np.linspace(deltas[0], nrows - deltas[0], out_rows)
    cols = np.linspace(deltas[1], ncols - deltas[1], out_cols)
    rows_grid, cols_grid = np.meshgrid(rows - 0.5, cols - 0.5, indexing="ij")

    img_resize_vec = interpolate_bilinear(img, rows_grid.flatten(), cols_grid.flatten())
    img_resize = img_resize_vec.reshape(
        img.shape[:-3] + (len(rows), len(cols)) + img.shape[-1:]
    )
    return ep.JAXTensor(img_resize)
def test_pushes_and_pops(self):
    """Stack push/pop follows LIFO order and reports emptiness correctly."""
    stack = Stack.create(7, jnp.zeros((), jnp.int32))

    for value in (7, 8):
        stack = stack.push(jnp.int32(value))
        self.assertFalse(stack.empty())

    top, stack = stack.pop()
    self.assertFalse(stack.empty())
    self.assertEqual(8, top)

    stack = stack.push(jnp.int32(9))
    top, stack = stack.pop()
    self.assertFalse(stack.empty())
    self.assertEqual(9, top)

    top, stack = stack.pop()
    self.assertTrue(stack.empty())
    self.assertEqual(7, top)
def test_check_jaxpr_eqn_mismatch(self):
    """check_jaxpr rejects an equation whose outvar aval disagrees with its use."""

    def f(x):
        return jnp.sin(x) + jnp.cos(x)

    def fresh_jaxpr():
        return make_jaxpr(f)(jnp.float32(1.)).jaxpr

    # jaxpr is:
    #
    # { lambda ; a.
    #   let b = sin a
    #       c = cos a
    #       d = add b c
    #   in (d,) }
    #
    # NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b'
    bad_avals_and_messages = [
        # int, not float!
        (lambda: make_shaped_array(jnp.int32(2)),
         r"Variable 'b' inconsistently typed as f32\[\], "
         r"bound as i32\[\]\n\nin equation:\n\nb:i32\[\] = sin a"),
        (lambda: make_shaped_array(np.ones((2, 3), dtype=jnp.float32)),
         r"Variable 'b' inconsistently typed as f32\[\], "
         r"bound as f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin a"),
    ]
    for make_bad_aval, message in bad_avals_and_messages:
        jaxpr = fresh_jaxpr()
        jaxpr.eqns[0].outvars[0].aval = make_bad_aval()
        self.assertRaisesRegex(
            core.JaxprTypeError, message, lambda: core.check_jaxpr(jaxpr))
def multinomial_mode(
    distribution_or_probs: Union[tfd.Distribution, jnp.DeviceArray]
) -> jnp.DeviceArray:
    """Return the (tie-normalised) one-hot mode of a multinomial distribution.

    Args:
      distribution_or_probs: either a `tfd.Distribution`, whose
        `probs_parameter()` is used, or an array of category probabilities.
        Probabilities are expected with shape (batch_size, dim); the maximum
        is taken over axis 1.

    Returns:
      float32 `DeviceArray` of shape (batch_size, dim). Each row marks the
      highest-probability category. When k categories tie for the maximum,
      each of them gets 1/k, so every row sums to 1.
    """
    probs = (distribution_or_probs.probs_parameter()
             if isinstance(distribution_or_probs, tfd.Distribution)
             else distribution_or_probs)
    row_max = jnp.max(probs, axis=1, keepdims=True)
    is_max = jnp.int32(jnp.equal(probs, row_max))
    n_max = jnp.sum(is_max, axis=1, keepdims=True)
    return jnp.float32(is_max / n_max)
def update(data, state):
    """Run one outer LMB training step and return the next LMBTrainingState.

    Runs an inner Levenberg-Marquardt `while_loop` starting from `state` and
    adjusts the damping μ. Then, via `cond`, it either restarts
    (`_bl_restart`) when the objective G increased, or applies the Bayesian
    hyperparameter update (`_bl_update`) otherwise.

    Args:
        data: indexable; only ``data[1]`` (the targets y) is read here.
        state: current LMBTrainingState.

    Returns:
        LMBTrainingState with the new parameters, residuals, objective terms
        (G, C, R), hyperparameters (Λ, μ, τ), iteration counter
        ``state.i + 1`` and convergence counter k.
    """
    y = data[1]
    # NOTE(review): H / Je presumably are the (approximate) Hessian and J^T e
    # from `differentiate` — confirm against its definition.
    H, Je = differentiate(state.θ, state.e, y)
    # Inner Levenberg-Marquardt update
    lm_state = LMState(state.θ, state.e, state.G, state.C, state.R, state.μ)
    lm_cond = partial(_lm_cond, state.G)
    lm_update = partial(_lm_update, state.θ, H, Je, y, state.Λ)
    θ, e, G, C, R, μ = while_loop(
        lm_cond, lm_update, lm_state
    )
    # Divide μ by μs while it is below μmax, then clamp it from below at μmin.
    μ = np.where(μ < μmax, μ / μs, μ)
    μ = np.where(μmin < μ, μ, μmin)
    # Bayesian hyperparameter learning
    bl_state = (G, state.Λ, μ, state.τ)
    bl_update = partial(_bl_update, H, C, R)
    bl_restart = partial(_bl_restart, state.G)
    # Presumably the 5-argument lax.cond form
    # (pred, true_operand, true_fun, false_operand, false_fun): restart when
    # the objective got worse, otherwise learn the hyperparameters.
    G, Λ, μ, τ = cond(
        G > state.G, bl_state, bl_restart, bl_state, bl_update
    )
    # k is incremented when G did not decrease and reset to 1 otherwise.
    k = np.where(G >= state.G, state.k + 1, np.int32(1))
    return LMBTrainingState(θ, e, G, C, R, Λ, μ, τ, state.i + 1, k)
def hsv_to_rgb(hsv_image):
    """Convert a channels-last HSV image to RGB.

    Hue is scaled by 6 to select one of six colour sectors, so channels are
    presumably in [0, 1]. Pixels with zero saturation become grey
    (v, v, v). The result has the same shape as the input.
    """
    # Adapted from the numpy implementation here: https://gist.github.com/PolarNick239/691387158ff1c41ad73c#file-rgb_to_hsv_np-py
    original_shape = hsv_image.shape
    pixels = hsv_image.reshape(-1, 3)
    hue, sat, val = pixels[:, 0], pixels[:, 1], pixels[:, 2]

    scaled_hue = hue * 6.0
    sector = jnp.int32(scaled_hue)
    frac = scaled_hue - sector
    p = val * (1.0 - sat)
    q = val * (1.0 - sat * frac)
    t = val * (1.0 - sat * (1.0 - frac))
    sector = sector % 6

    def as_column(arr):
        return arr.reshape(-1, 1)

    v_col, t_col, p_col, q_col = (as_column(a) for a in (val, t, p, q))
    sector_cols = jnp.tile(as_column(sector), (1, 3))

    # (r, g, b) channel sources for hue sectors 0..5.
    sector_channels = (
        (v_col, t_col, p_col),
        (q_col, v_col, p_col),
        (p_col, v_col, t_col),
        (p_col, q_col, v_col),
        (t_col, p_col, v_col),
        (v_col, p_col, q_col),
    )
    rgb = jnp.zeros_like(pixels)
    for idx, channels in enumerate(sector_channels):
        rgb = jnp.where(sector_cols == idx, jnp.hstack(list(channels)), rgb)

    sat_cols = jnp.tile(as_column(sat), (1, 3))
    rgb = jnp.where(sat_cols == 0, jnp.hstack([v_col, v_col, v_col]), rgb)
    return rgb.reshape(original_shape)
def interpolate1d(x, values, tangents):
    r"""Cubic Hermite spline interpolation on a 1D spline.

    Knots sit at x = 0, 1, ..., len(values) - 1. Queries outside that range
    are linearly extrapolated from the first or last knot using its tangent.
    See https://en.wikipedia.org/wiki/Cubic_Hermite_spline; there "x" is `x`,
    "p" is `values` and "m" is `tangents`.

    Args:
      x: A tensor of query positions.
      values: A vector with the value at each knot. Must be the same length
        as `tangents`.
      tangents: A vector with the tangent (derivative) at each knot. Must be
        the same length as `values` and the same type as `x`.

    Returns:
      The spline evaluated (or extrapolated) at `x`, with the same shape as
      `x`.
    """
    assert len(values.shape) == 1
    assert len(tangents.shape) == 1
    assert values.shape[0] == tangents.shape[0]

    # Knot at or below each query, clamped so that lo + 1 is always a valid knot.
    lo = jnp.int32(jnp.floor(jnp.clip(x, 0., values.shape[0] - 2)))
    hi = lo + 1
    # Offset from the lower knot; t < 0 or t > 1 means extrapolation.
    t = x - lo

    t2 = t**2
    t3 = t * t2
    # Cubic Hermite basis functions.
    h01 = -2 * t3 + 3 * t2
    h00 = 1 - h01
    h11 = t3 - t2
    h10 = h11 - t2 + t

    below = tangents[0] * t + values[0]
    above = tangents[-1] * (t - 1) + values[-1]

    interpolated = (
        jnp.take(values, lo) * h00
        + jnp.take(values, hi) * h01
        + jnp.take(tangents, lo) * h10
        + jnp.take(tangents, hi) * h11)

    return jnp.where(t < 0., below, jnp.where(t > 1., above, interpolated))
def apply_scatter_op(
    scatter_agg_op,
    n: int,
    values: jnp.ndarray,
    targets: jnp.ndarray,
    active: jnp.ndarray = None,
) -> jnp.ndarray:
    """
    Scatter-aggregate `values` into an array of length `n` at `targets`.

    `scatter_agg_op` is one of the `jax.lax.scatter_*` functions. Slots that
    receive nothing keep the op's neutral value (from `_op_neutral`). Target
    indices outside [0, n) are dropped (``mode="drop"``). If `active` is
    given, only positions with ``active[i] == True`` contribute.
    Boolean `values` are cast to int32 for `scatter_add` / `scatter_mul`.
    """
    is_bool_arith = (
        np.issubdtype(values.dtype, np.bool_)
        and scatter_agg_op in (jax.lax.scatter_add, jax.lax.scatter_mul))
    if is_bool_arith:
        values = jnp.int32(values)

    base = jnp.full(
        (n, ) + values.shape[1:],
        _op_neutral(scatter_agg_op, values.dtype),
        dtype=values.dtype)

    if active is not None:
        # Send inactive entries past the end so mode="drop" discards them.
        targets = jnp.where(active, targets, n + 1)

    dnums = jax.lax.ScatterDimensionNumbers(
        tuple(range(1, len(values.shape))), (0, ), (0, ))
    return scatter_agg_op(
        base, jnp.expand_dims(targets, 1), values, dnums, mode="drop")
def epi_demo(edge_beta, gamma, infect, nodes, steps):
    """Time an SIR epidemic on a Barabasi-Albert graph and log throughput.

    Builds a graph with `nodes` nodes (k=3), seeds infections with
    probability ``infect / nodes`` per node, runs `steps` steps, and logs
    the elapsed time along with steps/s, edge_ops/s and node_ops/s.
    """
    k = 3
    # Named `process` (not `np`) so the numpy convention is not shadowed.
    process = NetworkProcess([
        epidemics.SIRUpdateOp(),
    ])
    params = {"edge_infection_rate": edge_beta, "recovery_rate": gamma}
    log.info(
        f"Network: Barabasi-Albert. n={nodes}, k={k}, cca {nodes*k*2:.2e} directed edges"
    )

    with utils.logged_time(" Creating graph", logger=log):
        graph = nx.random_graphs.barabasi_albert_graph(nodes, k)

    with utils.logged_time(" Creating state", logger=log):
        net = Network.from_graph(graph)
        state = process.new_state(net, props=params, seed=42)
        infected = jax.random.bernoulli(
            jax.random.PRNGKey(43), infect / nodes, shape=[nodes])
        state.node["compartment"] = jnp.int32(infected)

    with utils.logged_time(" Running model", logger=log):
        start = time.time()
        final_state = process.run(state, steps=steps)
        final_state.block_on_all()
        stop = time.time()

    log.info(process.trace_log())
    sps = steps / (stop - start)
    log.info(f"{steps} steps took {stop-start:.2g} s, {sps:.3g} steps/s, " +
             f"{sps*state.m:.3g} edge_ops/s, {sps * state.n:.3g} node_ops/s")
def visualize_coord_fix(coords, acc, percentile=99.):
    """Visualize the "cell" each coordinate lives in, and highlight its edges."""
    cells = jnp.int32(jnp.fix(coords))  # round towards zero
    padded = jnp.pad(cells, [(1, 1), (1, 1), (0, 0)], 'edge')

    # A very hacky plus-shaped edge detector: True where all four
    # neighbours fall in the same cell.
    interior = ((cells == padded[2:, 1:-1, :])
                & (cells == padded[:-2, 1:-1, :])
                & (cells == padded[1:-1, 2:, :])
                & (cells == padded[1:-1, :-2, :]))

    # Scale according to `acc` and clip to lie in [-1, 1].
    extent = jnp.max(jnp.abs(cells), axis=2)
    scale = jnp.maximum(1, math.weighted_percentile(extent, acc, percentile))
    unit = jnp.clip(cells / scale, -1, 1)

    # The [-1, 1] center cube is gray, and every other integer boundary gets
    # colored with xyz \propto rgb - gray. Edge pixels are highlighted.
    colored = jnp.where(interior, (unit + 1) / 2, 1 - jnp.abs(unit))
    return matte(colored, acc)
def _generate_partition(decays, decay_distribution, length): # Generates length-sized array split according to decay_distribution. decays = jnp.array(decays) decay_distribution = jnp.array(decay_distribution) multiples = jnp.int32(jnp.floor(decay_distribution * length)) multiples = multiples.at[-1].set(multiples[-1] + length - jnp.sum(multiples)) return jnp.repeat(decays, multiples)
def fill_triangular_inverse(x, upper=False):
    """Flatten the triangular part of `x` into a vector (inverse of fill_triangular).

    The last two dims of `x` are assumed to form an n x n triangular matrix.
    The result has n(n+1)/2 entries in the last dim, with leading (batch)
    dims preserved. Example: the lower matrix [[4,0,0],[6,5,0],[3,2,1]] and
    the upper matrix [[1,2,3],[0,5,6],[0,0,4]] (with ``upper=True``) both
    map back to [1, 2, 3, 4, 5, 6].

    Args:
        x: array with shape (..., n, n).
        upper: whether `x` is upper (True) or lower (False) triangular.

    Returns:
        Array with shape (..., n * (n + 1) // 2).
    """
    n = int(x.shape[-1])
    m = n * (n + 1) // 2
    if upper:
        initial_elements = x[..., 0, :]
        triangular_portion = x[..., 1:, :]
    else:
        # Reverse the last row along its own (last) axis. `axis=-2` would hit
        # a batch dimension, or not exist at all for a single matrix.
        initial_elements = np.flip(x[..., -1, :], axis=-1)
        triangular_portion = x[..., :-1, :]
    # Rotating the remaining rows by 180 degrees and adding them to the
    # unrotated rows packs both triangles' entries into full rows.
    rotated_triangular_portion = np.flip(
        np.flip(triangular_portion, axis=-1), axis=-2)
    consolidated_matrix = triangular_portion + rotated_triangular_portion
    end_sequence = np.reshape(
        consolidated_matrix, list(x.shape[:-2]) + [n * (n - 1)])
    return np.concatenate([initial_elements, end_sequence[..., :m - n]], axis=-1)
def test_primitive_compilation_cache(self):
    """Repeated lax.add on a committed array compiles once and stays on-device."""
    target = self.get_devices()[1]
    x = jax.device_put(jnp.int32(1), target)
    with jtu.count_primitive_compiles() as count:
        doubled = lax.add(x, x)
        quadrupled = lax.add(doubled, doubled)
    self.assertEqual(count[0], 1)
    for result in (doubled, quadrupled):
        self.assert_committed_to_device(result, target)
def fill_triangular(x, upper=False):
    """Pack a length-n(n+1)/2 vector into an n x n triangular matrix.

    Example with x = [1, 2, 3, 4, 5, 6]:
      lower (default): [[4, 0, 0], [6, 5, 0], [3, 2, 1]]
      upper=True:      [[1, 2, 3], [0, 5, 6], [0, 0, 4]]

    Args:
        x: 1-D array whose length is a triangular number.
        upper: build an upper (True) or lower (False) triangular matrix.

    Returns:
        The (n, n) triangular matrix.

    Raises:
        ValueError: if `x` is not 1-D or its length is not n(n+1)/2 for an
            integer n.
    """
    m = int(x.shape[-1])
    if len(x.shape) != 1:
        raise ValueError("Only handles 1D to 2D transformation, because tril/u")
    # Solve n(n+1)/2 == m exactly on the static Python int. A float sqrt on
    # an array scalar loses precision in float32 and cannot be checked under jit.
    n = int(round(((8 * m + 1) ** 0.5 - 1) / 2))
    if n * (n + 1) // 2 != m:
        raise ValueError('Input right-most shape ({}) does not '
                         'correspond to a triangular matrix.'.format(m))
    final_shape = list(x.shape[:-1]) + [n, n]
    if upper:
        x_list = [x, np.flip(x[..., n:], -1)]
    else:
        x_list = [x[..., n:], np.flip(x, -1)]
    x = np.reshape(np.concatenate(x_list, axis=-1), final_shape)
    return np.triu(x) if upper else np.tril(x)
def _for_impl(*args, jaxpr, nsteps, reverse, which_linear):
    """Evaluate a stateful for-loop jaxpr as a `lax.while_loop`.

    The jaxpr's state effects are discharged first. The loop then runs
    `nsteps` iterations, passing index ``nsteps - step - 1`` when `reverse`
    is set and ``step`` otherwise. Returns the final state list.
    """
    del which_linear  # not needed for evaluation
    discharged_jaxpr, consts = discharge_state(jaxpr, ())

    def keep_going(carry):
        step, _ = carry
        return step < nsteps

    def step_once(carry):
        step, loop_state = carry
        index = nsteps - step - 1 if reverse else step
        new_state = core.eval_jaxpr(discharged_jaxpr, consts, index, *loop_state)
        return step + 1, new_state

    _, final_state = lax.while_loop(
        keep_going, step_once, (jnp.int32(0), list(args)))
    return final_state
def bi_tempered_logistic_loss_fwd(activations,
                                  labels,
                                  t1,
                                  t2,
                                  label_smoothing=0.0,
                                  num_iters=5):
    """Forward pass for the bi-tempered logistic loss.

    Args:
      activations: A multi-dimensional array with last dimension `num_classes`.
      labels: An array with shape and dtype as activations.
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Label smoothing parameter between [0, 1).
      num_iters: Number of iterations to run the method.

    Returns:
      The per-example loss array, and the residuals
      ``(labels, t1, t2, probabilities)``, where `labels` has label smoothing
      applied.
    """
    num_classes = jnp.int32(labels.shape[-1])

    def smooth(u):
        return ((1 - num_classes / (num_classes - 1) * label_smoothing) * u
                + label_smoothing / (num_classes - 1))

    labels = cond(label_smoothing > 0.0, smooth, lambda u: u, labels)
    probabilities = tempered_softmax(activations, t2, num_iters)

    def tempered_cross_entropy(unused_activations):
        per_class = (
            jnp.multiply(labels,
                         log_t(labels + 1e-10, t1) - log_t(probabilities, t1))
            - 1.0 / (2.0 - t1) * (jnp.power(labels, 2.0 - t1)
                                  - jnp.power(probabilities, 2.0 - t1)))
        return jnp.sum(per_class, -1)

    # With both temperatures at 1 the loss reduces to ordinary cross entropy.
    both_near_one = jnp.logical_and(
        jnp.less(jnp.abs(t1 - 1.0), 1e-15),
        jnp.less(jnp.abs(t2 - 1.0), 1e-15))
    loss_values = cond(
        both_near_one,
        functools.partial(_cross_entropy_loss, labels=labels),
        tempered_cross_entropy,
        activations)
    return loss_values, (labels, t1, t2, probabilities)
def rvs(self, nsamps: int = 1) -> np.array:
    """Draw `nsamps` samples using residual resampling over the components.

    Each component (row of `self.inspace_points`) first gets
    ``floor(w_i * nsamps)`` copies, where ``w_i`` are the normalized
    prefactors. The remaining draws are allocated multinomially in
    proportion to the residual weights. The selected points are then offset
    by ``self.k.rvs(nsamps, dim)``.

    Args:
        nsamps: number of samples. If None, one sample per component is
            drawn. NOTE(review): the original referenced an undefined name
            `pop` here; confirm this default.

    Returns:
        The repeated component points plus the ``self.k.rvs`` offsets.
        Presumably shape (nsamps, self.inspace_points.shape[1]).

    Note:
        The residual step uses NumPy's global RNG (``onp.random``).
    """
    assert np.all(self.prefactors >= 0.)
    if nsamps is None:
        nsamps = len(self.inspace_points)
    prop_w = log(self.normalized().prefactors)
    mult = exp(prop_w + log(nsamps))
    count = np.int32(np.floor(mult))
    n_resid = int(nsamps - count.sum())
    # Only draw residuals when some remain. If every w_i * nsamps is integral,
    # all residuals are 0, log(0) = -inf, and normalizing yields NaN
    # probabilities, which `multinomial` rejects.
    if n_resid > 0:
        resid = log(mult - count)
        resid = resid - logsumexp(resid)
        count = count + onp.random.multinomial(n_resid, exp(resid))
    return np.repeat(self.inspace_points, count, 0) + self.k.rvs(
        nsamps, self.inspace_points.shape[1])
def compute_normalization_binary_search(activations, t, num_iters=10):
    """Returns the normalization value for each example (t < 1.0).

    For each example, bisects for the value of log_t(Z) at which
    ``sum(exp_t(activations - mu - log_t(Z), t))`` reaches 1, where
    ``mu = max(activations, -1)``. It runs exactly `num_iters` bisection
    steps.

    Args:
      activations: A multi-dimensional array with last dimension `num_classes`.
      t: Temperature 2 (< 1.0 for finite support).
      num_iters: Number of iterations to run the method.

    Returns:
      An array of same rank as activation with the last dimension being 1.
    """
    mu = jnp.max(activations, -1, keepdims=True)
    normalized_activations = activations - mu
    shape_activations = activations.shape
    # Number of classes above the finite-support cutoff -1 / (1 - t). It sets
    # the upper end of the search bracket.
    effective_dim = jnp.float32(
        jnp.sum(
            jnp.int32(normalized_activations > -1.0 / (1.0 - t)),
            -1,
            keepdims=True))
    shape_partition = list(shape_activations[:-1]) + [1]

    def cond_fun(carry):
        _, _, iters = carry
        return iters < num_iters

    def body_fun(carry):
        lower, upper, iters = carry
        logt_partition = (upper + lower) / 2.0
        sum_probs = jnp.sum(
            exp_t(normalized_activations - logt_partition, t), -1,
            keepdims=True)
        # Probabilities summing below 1 mean the midpoint is too large, so
        # `upper` moves down to it. Otherwise `lower` moves up to it.
        update = jnp.float32(sum_probs < 1.0)
        lower = jnp.reshape(
            lower * update + (1.0 - update) * logt_partition,
            shape_partition)
        upper = jnp.reshape(
            upper * (1.0 - update) + update * logt_partition,
            shape_partition)
        return lower, upper, iters + 1

    # Initial search bracket [0, -log_t(1 / effective_dim, t)], built once.
    lower = jnp.zeros(shape_partition)
    upper = -log_t(1.0 / effective_dim, t) * jnp.ones(shape_partition)
    lower, upper, _ = while_loop(cond_fun, body_fun, (lower, upper, 0))
    logt_partition = (upper + lower) / 2.0
    return logt_partition + mu
def sample(self, key: jnp.ndarray) -> jnp.ndarray:
    """Sample from the distribution.

    Draws, for every row of `self.p` (probabilities over the last axis), a
    vector of category counts from ``int32(n)`` categorical draws, using one
    split of `key` per element of `self.n`. The result has ``self.p.shape``.
    Entries where `self.p` is NaN are set to NaN.

    Args:
        key: JAX random key.
    """
    out_shape = self.p.shape
    subkeys = jax.random.split(key, self.n.size)
    flat_n = self.n.reshape((self.n.size))
    flat_p = self.p.reshape((-1, self.p.shape[-1]))

    def draw_counts(count, probs, subkey):
        draws = jnp.where(
            jnp.isnan(count), jnp.nan,
            jax.random.categorical(
                subkey, jnp.log(probs), shape=(jnp.int32(count), )))
        return jnp.sum(jax.nn.one_hot(draws, probs.shape[-1]), 0)

    rows = [draw_counts(c, pr, sk) for c, pr, sk in zip(flat_n, flat_p, subkeys)]
    counts = jnp.stack(rows).reshape(out_shape)
    return jnp.where(jnp.isnan(self.p), jnp.full(out_shape, jnp.nan), counts)
def sample(self, key: jnp.ndarray) -> jnp.ndarray:
    """Sample from the distribution.

    For each element of `self.n`, counts successes in ``int32(n)`` Bernoulli
    trials with the matching element of `self.p` as success probability.
    `self.p` must have as many elements as `self.n`. The result has
    ``self.n.shape``, and entries where `self.p` is NaN are set to NaN.

    Args:
        key: JAX random key.
    """
    out_shape = self.n.shape
    subkeys = jax.random.split(key, self.n.size)
    flat_n = self.n.reshape((self.n.size))
    flat_p = self.p.reshape((self.p.size))

    def draw_successes(trials, prob, subkey):
        outcomes = jnp.where(
            jnp.isnan(trials), jnp.nan,
            jax.random.bernoulli(subkey, prob, shape=(jnp.int32(trials), )))
        return jnp.sum(outcomes)

    totals = [draw_successes(n, p, k) for n, p, k in zip(flat_n, flat_p, subkeys)]
    successes = jnp.stack(totals).reshape(out_shape)
    return jnp.where(jnp.isnan(self.p), jnp.full(out_shape, jnp.nan), successes)
def compute_shift(lattice, cutoff):
    """
    Return the periodic-image shift vectors needed within `cutoff`.

    Repetitions per direction are ``ceil(cutoff * ||column i of inv(lattice)||)``
    (the reciprocal lattice). An integer shift r is kept when its neighbouring
    cell's nearest corner, ``r - sign(r)``, lies closer than `cutoff` once
    mapped through the lattice.

    Args:
        lattice: lattice matrix as 2D array (rows are lattice vectors),
            e.g. jnp.diag(jnp.ones(3))
        cutoff: cutoff distance as float

    Returns:
        shifts matrix of shape (N, 3); row j is ``r_j @ lattice``.
    """
    n_repeat = jnp.int32(
        jnp.ceil(jnp.linalg.norm(jnp.linalg.inv(lattice), axis=0) * cutoff))
    r0, r1, r2 = (int(v) for v in n_repeat)
    relative_shifts = jnp.array([[a, b, c]
                                 for a in range(-r0, r0 + 1)
                                 for b in range(-r1, r1 + 1)
                                 for c in range(-r2, r2 + 1)])
    # Step each index one unit toward zero: the closest point of a
    # neighbouring cell is its corner nearest the origin cell.
    nearest_corner = relative_shifts - jnp.sign(relative_shifts)
    corner_distance = jnp.linalg.norm(nearest_corner @ lattice, axis=1)
    relative_shifts = relative_shifts[corner_distance < cutoff]
    # (lattice.T @ r) for every r, written as a single row-vector matmul;
    # unlike `.squeeze()`, this keeps shape (N, 3) even when N <= 1.
    return relative_shifts @ lattice
def test_attributes_create_weights_op_fp(
    self,
    weight_range,
    weight_shape,
    fp_quant,
):
    """create_weights_ops with an fp precision sets scale/prec and quantizes as expected."""
    weights = jnp.array(
        fp32(onp.random.uniform(*weight_range, size=weight_shape)))
    axis = None if weight_shape[1] == 1 else 0
    params = QuantOps.WeightParams(prec=fp_quant, axis=axis, half_shift=False)
    op = QuantOps.create_weights_ops(w=weights, weight_params=params)

    # Expected scale is 2**-floor(log2(max |w|)) per column, which maps the
    # column maximum into [1, 2).
    max_weight = onp.max(abs(weights), axis=0)
    onp.testing.assert_array_equal(
        jnp.squeeze(op._scale),
        jnp.exp2(-jnp.floor(jnp.log2(max_weight))))
    self.assertEqual(op._symmetric, True)
    self.assertIs(op._prec, fp_quant)

    scaled = (weights * op._scale).astype(weights.dtype)
    spec = fp_quant.fp_spec
    expected = fp_cast.downcast_sat_ftz(
        scaled,
        spec.exp_min,
        spec.exp_max,
        spec.sig_bits,
    )
    calculated = op.to_quantized(weights, dtype=SCALE_DTYPE)
    onp.testing.assert_array_equal(expected, calculated)

    # Test the lower (23 - fp_quant.fp_spec.sig_bits) bits of the calculated
    # quantized weights are zero.
    sig_mask = jnp.int32((1 << (23 - spec.sig_bits)) - 1)
    onp.testing.assert_array_equal(
        calculated.view(jnp.int32) & sig_mask, jnp.zeros_like(weights))
def prepare_filter_functions(metadata, background_avg):
    """Build jitted batch pipelines for single- and double-exposure frames.

    Returns ``(f_all, f_all_d)``:
      f_all(x): subtracts `background_avg`, cleans each frame, then filters
        and shift-rescales it.
      f_all_d(x, y): cleans x with ``background_avg[0]`` and y with
        ``background_avg[1]``, combines them with
        ``metadata["double_exp_time_ratio"]``, then filters and rescales.
    """
    # Box convolution kernel of side floor(padded / output), at least 1.
    kernel_width = np.max(
        np.array([
            np.int32(
                np.floor(metadata["padded_frame_width"] /
                         metadata["output_frame_width"])), 1
        ]))
    kernel_box = np.ones((kernel_width, kernel_width))

    clean_single = jax.vmap(lambda frame: cleanXraw(frame - background_avg))
    clean_first = jax.vmap(lambda frame: cleanXraw(frame - background_avg[0]))
    clean_second = jax.vmap(lambda frame: cleanXraw(frame - background_avg[1]))
    merge_exposures = jax.vmap(
        lambda a, b: combine_double_exposure(
            a, b, metadata["double_exp_time_ratio"]))

    # single and double exposure cleaning
    clean_frames = jax.jit(lambda x: clean_single(x))
    clean_frames_d = jax.jit(
        lambda x, y: merge_exposures(clean_first(x), clean_second(y)))

    def process_frame(clean_frame):
        filtered = filter_frame(clean_frame, kernel_box)
        return shift_rescale(
            filtered, metadata["center_of_mass"],
            metadata["output_frame_width"], metadata["output_padded_ratio"])

    process_batch = jax.vmap(process_frame)
    f_all = jax.jit(lambda x: process_batch(clean_frames(x)))
    f_all_d = jax.jit(lambda x, y: process_batch(clean_frames_d(x, y)))
    return f_all, f_all_d
def testScanTypeErrors(self): """Test typing error messages for scan.""" a = np.arange(5) # Body output not a tuple with self.assertRaisesRegex(TypeError, re.escape("scan body output must be a pair, got ShapedArray(float32[]).")): lax.scan(lambda c, x: np.float32(0.), 0, a) with self.assertRaisesRegex(TypeError, re.escape("scan carry output and input must have same type structure, " "got PyTreeDef(tuple, [*,*,*]) and PyTreeDef(tuple, [*,PyTreeDef(tuple, [*,*])])")): lax.scan(lambda c, x: ((0, 0, 0), x), (1, (2, 3)), a) with self.assertRaisesRegex(TypeError, re.escape("scan carry output and input must have same type structure, got * and PyTreeDef(None, []).")): lax.scan(lambda c, x: (0, x), None, a) with self.assertRaisesWithLiteralMatch( TypeError, "scan carry output and input must have identical types, got\n" "ShapedArray(int32[])\n" "and\n" "ShapedArray(float32[])."): lax.scan(lambda c, x: (np.int32(0), x), np.float32(1.0), a) with self.assertRaisesRegex(TypeError, re.escape("scan carry output and input must have same type structure, got * and PyTreeDef(tuple, [*,*]).")): lax.scan(lambda c, x: (0, x), (1, 2), np.arange(5))
def testScalarCastInsideJitWorks(self): # jnp.int32(tracer) should work. self.assertEqual(jnp.int32(101), jax.jit(lambda x: jnp.int32(x))(jnp.float32(101.4)))
def testForiLoopErrors(self): """Test typing error messages for while.""" with self.assertRaisesRegex( TypeError, "arguments to fori_loop must have equal types"): lax.fori_loop(onp.int16(0), np.int32(10), (lambda i, c: c), np.float32(7))