def test_compatibility():
    """Check elegy's Huber loss against tf.keras' Huber loss.

    Both implementations are evaluated on the same random inputs with sample
    weighting (default reduction) and with the SUM and NONE reductions.
    """
    # Binary float32 targets and uniform predictions, drawn from one key.
    key = jax.random.PRNGKey(121)
    targets = jax.random.randint(key, shape=(2, 3), minval=0, maxval=2)
    targets = targets.astype(dtype=jnp.float32)
    preds = jax.random.uniform(key, shape=(2, 3))

    # Default reduction with per-sample weights.
    weights = jnp.array([1, 0])
    ours = elegy.losses.Huber(delta=1.0)
    reference = tfk.losses.Huber(delta=1.0)
    assert jnp.isclose(
        ours(targets, preds, sample_weight=weights),
        reference(targets, preds, sample_weight=weights),
        rtol=0.0001,
    )

    # SUM reduction.
    ours = elegy.losses.Huber(delta=1.0, reduction=elegy.losses.Reduction.SUM)
    reference = tfk.losses.Huber(delta=1.0, reduction=tfk.losses.Reduction.SUM)
    assert jnp.isclose(ours(targets, preds), reference(targets, preds), rtol=0.0001)

    # NONE reduction: compare the per-sample losses element-wise.
    ours = elegy.losses.Huber(delta=1.0, reduction=elegy.losses.Reduction.NONE)
    reference = tfk.losses.Huber(delta=1.0, reduction=tfk.losses.Reduction.NONE)
    assert jnp.all(
        jnp.isclose(ours(targets, preds), reference(targets, preds), rtol=0.0001)
    )
def test_ellipsoid_params_update():
    """Check incremental mean / inverse-scatter updates against recomputation.

    Starting from ``N`` random 3-D points, 100 further points are added one at
    a time. After each addition the running mean, the inverse scatter matrix
    ``C`` and ``det(C)`` produced by ``rank_one_update_matrix_inv`` are
    compared with the same quantities recomputed from all points seen so far.
    """
    # (An unused ``import pylab`` used to live here; it made the test depend
    # on matplotlib for no reason.)

    def inv_scatter(pts, centre):
        # Inverse of sum_i (p_i - centre)(p_i - centre)^T.
        diff = pts - centre
        return jnp.linalg.inv(jnp.sum(diff[:, :, None] * diff[:, None, :], axis=0))

    N = 4
    points = random.normal(random.PRNGKey(43532), shape=(N, 3))
    mu = jnp.mean(points, axis=0)
    C = inv_scatter(points, mu)
    detC = jnp.linalg.det(C)
    n = N
    for i in range(100):
        x_n = random.normal(random.PRNGKey(i), shape=(3,))
        mu_next = mu + (x_n - mu) / (n + 1)
        C_next, detC_next = rank_one_update_matrix_inv(
            C, detC, x_n - mu, x_n - mu_next, add=True)
        n += 1

        # Reference values recomputed from scratch on all points so far.
        points = jnp.concatenate([points, x_n[None, :]], axis=0)
        mu_com = jnp.mean(points, axis=0)
        C_com = inv_scatter(points, mu_com)
        detC_com = jnp.linalg.det(C_com)

        assert jnp.isclose(detC_com, detC_next)
        assert jnp.isclose(mu_next, mu_com).all()
        assert jnp.isclose(C_next, C_com).all()
        mu, C, detC = mu_next, C_next, detC_next
    print(detC_next)
def test_basic():
    """Check CategoricalCrossentropy against known reference values.

    Covers the default ('auto' / 'sum_over_batch_size') reduction, sample
    weighting, and the SUM and NONE reductions.
    """
    y_true = jnp.array([[0, 1, 0], [0, 0, 1]])
    y_pred = jnp.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])

    # Default reduction: expected ~1.177.
    loss_fn = elegy.losses.CategoricalCrossentropy()
    assert jnp.isclose(loss_fn(y_true, y_pred), 1.177, rtol=0.01)

    # Same loss object with per-sample weights: expected ~0.814.
    weighted = loss_fn(y_true, y_pred, sample_weight=jnp.array([0.3, 0.7]))
    assert jnp.isclose(weighted, 0.814, rtol=0.01)

    # SUM reduction: expected ~2.354.
    loss_fn = elegy.losses.CategoricalCrossentropy(
        reduction=elegy.losses.Reduction.SUM)
    assert jnp.isclose(loss_fn(y_true, y_pred), 2.354, rtol=0.01)

    # NONE reduction: per-sample losses, expected ~[0.0513, 2.303].
    loss_fn = elegy.losses.CategoricalCrossentropy(
        reduction=elegy.losses.Reduction.NONE)
    per_sample = loss_fn(y_true, y_pred)
    assert jnp.all(jnp.isclose(per_sample, [0.0513, 2.303], rtol=0.01))
def test_fenics_vjp():
    """Check the FEniCS vector-Jacobian product against reference gradients.

    ``solve_fenics`` is evaluated through ``fem_eval``. An all-ones cotangent
    is then pulled back with ``vjp_fem_eval_impl``, and the two resulting
    gradients are compared with precomputed values.
    """
    numpy_output, fenics_output, fenics_inputs, tape = fem_eval(
        solve_fenics, templates, *inputs)
    g = np.ones_like(numpy_output)
    jax_grad_tuple = vjp_fem_eval_impl(g, fenics_output, fenics_inputs, tape)
    # np.allclose reduces to one bool whatever the gradient's shape, whereas
    # ``check1 and check2`` on isclose arrays raises for multi-element
    # gradients. Separate asserts also report which gradient is off.
    assert np.allclose(jax_grad_tuple[0], np.asarray(-2.91792642)), jax_grad_tuple[0]
    assert np.allclose(jax_grad_tuple[1], np.asarray(2.43160535)), jax_grad_tuple[1]
def test_integral_generation_euler():
    """Statistically validate Euler-scheme Wiener increments from pychastic.

    Draws ``2**14`` samples of (dW1, dW2, dt). Their sample means and pairwise
    mean products must match the exact moments from
    ``pychastic.wiener_integral_moments`` within ``5 * 2**-7``.
    """
    moments = pychastic.wiener_integral_moments
    # Multi-indices of the tested integrals: dW1, dW2 and dt respectively.
    integrals = [[1], [2], [0]]
    labels = ["d_w1", "d_w2", "d_t"]

    # Exact first moments E(X) and second moments E(XY) at t = 1.
    target_means = jnp.array([moments.E(idx)(1) for idx in integrals])
    target_mean_products = jnp.array([
        moments.E2(idx, idy)(1)
        for idx, idy in itertools.product(integrals, integrals)
    ])

    exponent = 14
    cutoff = 5
    tolerance = cutoff * 2**(-exponent / 2)

    generated = pychastic.vectorized_I_generation.get_wiener_integrals(
        jax.random.PRNGKey(0), scheme="euler", steps=2**exponent, noise_terms=2)
    d_w = generated["d_w"]
    # One column per tested integral; the dt column is all ones.
    samples = jnp.array([d_w[:, 0], d_w[:, 1], jnp.ones_like(d_w[:, 0])]).T

    def describe(flags, errors, keys):
        # Per-quantity (passed?, signed error) summary for failure messages.
        return str({
            key: (bool(flag), float(err))
            for flag, err, key in zip(flags, errors, keys)
        })

    sample_means = jnp.mean(samples, axis=0)
    means_close = jnp.isclose(sample_means, target_means, atol=tolerance)
    assert means_close.all(), "Expected values incorrect \n" + describe(
        means_close, sample_means - target_means, labels)

    columns = samples.T
    sample_mean_products = jnp.array(
        [jnp.mean(a * b) for a, b in itertools.product(columns, columns)])
    products_close = jnp.isclose(
        sample_mean_products, target_mean_products, atol=tolerance)
    assert products_close.all(), "Expected products incorrect \n" + describe(
        products_close,
        sample_mean_products - target_mean_products,
        itertools.product(labels, labels),
    )
def stop(thresh, initial_state, state):
    """Stopping predicate built from two ``jnp.isclose`` checks (rtol=thresh).

    Reads ``state[-2]`` as ``clusters`` and compares ``state[-1]`` with
    ``initial_state[-1]``. Returns the ``&`` of "no entry of ``clusters`` is
    close to ``clusters[-2]``" and "``state[-1]`` is close to
    ``initial_state[-1]``".
    """
    clusters = state[-2]
    # NOTE(review): ``last_cluster`` is the second-to-last entry and is then
    # compared with *all* of ``clusters``, itself included. That self-comparison
    # is close unless it is NaN, so ``.any()`` is True and ``~`` makes the
    # whole predicate False. Confirm whether e.g. ``clusters[-1]`` vs
    # ``clusters[-2]`` was intended.
    last_cluster = clusters[-2]
    # ``.any()`` binds before ``~``, so the first operand is ``~(...any())``.
    return ~jnp.isclose(
        last_cluster, clusters, rtol=thresh
    ).any() & jnp.isclose(
        state[-1], initial_state[-1], rtol=thresh
    )
def test_energy(wavelet_tensors):
    """Check the binary-MERA energy obtained after 20 descending steps.

    Starts from ``eye(8) / 8`` reshaped to a rank-6 tensor and applies
    ``simple_mera.descend`` 20 times. The energy is then checked both through
    an explicit trace and through ``simple_mera.binary_mera_energy``.
    """
    h, iso, dis = wavelet_tensors
    dim = 2**3
    rho = np.reshape(np.eye(dim) / dim, [2] * 6)
    for _ in range(20):
        rho = simple_mera.descend(h, rho, iso, dis)

    expected = -1.242
    traced = np.trace(np.reshape(rho, [dim, -1]) @ np.reshape(h, [dim, -1]))
    assert np.isclose(traced, expected, rtol=1e-3, atol=1e-3)

    energy = simple_mera.binary_mera_energy(h, rho, iso, dis)
    assert np.isclose(energy, expected, rtol=1e-3, atol=1e-3)
def test_descend(random_tensors):
    """Check that ``simple_mera.descend`` returns a valid density operator.

    The result must be a rank-6 tensor. Its ``D**3 x D**3`` matricisation must
    have unit trace, be Hermitian, and have non-negative eigenvalues.
    """
    h, s, iso, dis = random_tensors
    s = simple_mera.descend(h, s, iso, dis)
    assert len(s.shape) == 6
    D = s.shape[0]
    smat = np.reshape(s, [D**3] * 2)
    assert np.isclose(np.trace(smat), 1.0)
    assert np.isclose(np.linalg.norm(smat - np.conj(np.transpose(smat))), 0.0)
    spec, _ = np.linalg.eigh(smat)
    # ``alltrue`` was removed in NumPy 2.0 (and from recent jax.numpy);
    # ``all`` is the supported equivalent.
    assert np.all(spec >= 0.0)
def test_squared_norm():
    """Compare ``squared_norm`` with brute-force pairwise squared distances."""
    def brute_force(a, b):
        # (len(a), len(b), d) differences, squared and summed over d.
        return jnp.sum(jnp.square(a[:, None, :] - b[None, :, :]), axis=-1)

    x = jnp.linspace(0., 1., 100)[:, None]
    y = jnp.linspace(1., 2., 50)[:, None]
    for lhs, rhs in ((x, x), (x, y)):
        assert jnp.all(jnp.isclose(squared_norm(lhs, rhs), brute_force(lhs, rhs)))
def apply_bond_charge_corrections(initial_charges, bond_idxs, deltas):
    """For an arbitrary collection of ordered bonds and associated increments `(a, b, delta)`,
    update `charges` by `charges[a] += delta`, `charges[b] -= delta`

    Notes
    -----
    * preserves sum(initial_charges) for arbitrary values of bond_idxs or deltas
    * order within each row of bond_idxs is meaningful
        `(..., bond_idxs, deltas)` means the opposite of `(..., bond_idxs[:, ::-1], deltas)`
    * order within the first axis of bond_idxs, deltas is not meaningful
        `(..., bond_idxs[perm], deltas[perm])` means the same thing for any permutation `perm`
    * repeated atom indices accumulate: every bond contributes its own delta
    * duplicate directed/undirected bonds are still applied, but reported via a
      printed ``UserWarning``

    Returns
    -------
    final_charges
        A new JAX array; ``initial_charges`` is not modified.

    Raises
    ------
    AssertionError
        If ``bond_idxs`` does not have exactly two columns, if ``deltas`` does
        not hold one value per bond, if ``sum(initial_charges)`` is not within
        1e-5 of an integer, or if the corrections change the net charge.
    """
    # Validate the inputs *before* using them: checking afterwards let, e.g.,
    # a 3-column bond_idxs be silently truncated to its first two columns.
    assert bond_idxs.shape[1] == 2
    assert len(deltas) == len(bond_idxs)

    # apply bond charge corrections.
    # ``jax.ops.index_add`` was removed from JAX; ``.at[idx].add`` is the
    # documented replacement and likewise accumulates repeated indices.
    charges = jnp.asarray(initial_charges)
    incremented = charges.at[bond_idxs[:, 0]].add(+deltas)
    final_charges = incremented.at[bond_idxs[:, 1]].add(-deltas)

    # make some safety assertions on the result
    net_charge = jnp.sum(charges)
    net_charge_is_integral = jnp.isclose(net_charge, jnp.round(net_charge), atol=1e-5)
    final_net_charge = jnp.sum(final_charges)
    net_charge_is_unchanged = jnp.isclose(final_net_charge, net_charge, atol=1e-5)
    assert net_charge_is_integral
    assert net_charge_is_unchanged

    # print some safety warnings.
    # ``tolist()`` yields plain Python ints, so rows are hashable Counter keys
    # even when bond_idxs is a JAX array.
    bonds = [tuple(b) for b in bond_idxs.tolist()]
    directed_bonds = Counter(bonds)
    undirected_bonds = Counter(tuple(sorted(b)) for b in bonds)
    # The counters are empty when there are no bonds; max() would then raise.
    if directed_bonds and max(directed_bonds.values()) > 1:
        duplicates = [
            bond for (bond, count) in directed_bonds.items() if count > 1
        ]
        print(UserWarning(f"Duplicate directed bonds! {duplicates}"))
    elif undirected_bonds and max(undirected_bonds.values()) > 1:
        duplicates = [
            bond for (bond, count) in undirected_bonds.items() if count > 1
        ]
        print(UserWarning(f"Duplicate undirected bonds! {duplicates}"))

    return final_charges
def __assert_rotation(R):
    """Print a diagnostic if ``R`` is not a square orthogonal matrix.

    Despite the name this never raises. Each failed check prints a message,
    and the function returns ``None``.
    """
    if R.ndim != 2:
        print("R must be a matrix")
        # Unpacking R.shape below would raise for non-2-D input.
        return
    a, b = R.shape
    if a != b:
        print("R must be square")
        # R.T @ R is (b, b) and cannot be compared with eye(a).
        return
    # NOTE(review): against a target of 0.0 the ``rtol`` term vanishes, so these
    # comparisons effectively use jnp.isclose's default atol (1e-8). If a
    # loose tolerance was intended, ``atol`` is the parameter to set.
    if (not jnp.isclose(
            jnp.abs(jnp.eye(a) - jnp.dot(R, R.T)).max(), 0.0, rtol=0.5)
        ) or (not jnp.isclose(
            jnp.abs(jnp.eye(a) - jnp.dot(R.T, R)).max(), 0.0, rtol=0.5)):
        # R R^T == R^T R == I is orthogonality, not diagonality.
        print("R is not orthogonal")
def test_function(self):
    """Check elegy's MAPE loss against tf.keras and against its closed form.

    ``MeanAbsolutePercentageError`` is compared with the tf.keras version under
    the default reduction, sample weighting, and the SUM and NONE reductions.
    ``mean_percentage_absolute_error`` is then checked against an explicit
    formula.
    """
    y_true = jnp.array([[1.0, 1.0], [0.9, 0.0]])
    y_pred = jnp.array([[1.0, 1.0], [1.0, 0.0]])

    # Standard MAPE, then the same loss objects with per-sample weights.
    ours = elegy.losses.MeanAbsolutePercentageError()
    reference = tfk.losses.MeanAbsolutePercentageError()
    assert jnp.isclose(ours(y_true, y_pred), reference(y_true, y_pred),
                       rtol=0.0001)
    weights = jnp.array([1, 0])
    assert jnp.isclose(
        ours(y_true, y_pred, sample_weight=weights),
        reference(y_true, y_pred, sample_weight=weights),
        rtol=0.0001,
    )

    # SUM reduction.
    ours = elegy.losses.MeanAbsolutePercentageError(
        reduction=elegy.losses.Reduction.SUM)
    reference = tfk.losses.MeanAbsolutePercentageError(
        reduction=tfk.losses.Reduction.SUM)
    assert jnp.isclose(ours(y_true, y_pred), reference(y_true, y_pred),
                       rtol=0.0001)

    # NONE reduction: compare per-sample losses element-wise.
    ours = elegy.losses.MeanAbsolutePercentageError(
        reduction=elegy.losses.Reduction.NONE)
    reference = tfk.losses.MeanAbsolutePercentageError(
        reduction=tfk.losses.Reduction.NONE)
    assert jnp.all(
        jnp.isclose(ours(y_true, y_pred), reference(y_true, y_pred),
                    rtol=0.0001))

    # Closed-form check on random 0/1 targets and uniform predictions.
    key = jax.random.PRNGKey(42)
    y_true = jax.random.randint(key, shape=(2, 3), minval=0, maxval=2)
    y_pred = jax.random.uniform(key, shape=(2, 3))
    y_true = y_true.astype(y_pred.dtype)

    loss = elegy.losses.mean_percentage_absolute_error(y_true, y_pred)
    assert loss.shape == (2, )
    denom = jnp.maximum(jnp.abs(y_true), utils.EPSILON)
    expected = 100 * jnp.mean(jnp.abs((y_pred - y_true) / denom), axis=-1)
    assert jnp.array_equal(loss, expected)
def get_rotation_pytree(src: Any, dst: Any) -> Any:
    """
    Takes two n-dimensional vectors/Pytree and returns an nxn rotation matrix
    mapping src to dst.

    The rotation acts in the plane spanned by the normalised inputs and is the
    identity on that plane's orthogonal complement. If the inputs are already
    aligned (angle close to 0), the identity matrix is returned.

    NOTE(review): every sanity check below only *prints* when it fails and
    nothing is raised. Callers needing an error must check the result
    themselves, or this should be changed to raise ValueError.
    """

    def __assert_rotation(R):
        # Print (never raise) if R is not a square orthogonal matrix.
        if R.ndim != 2:
            print("R must be a matrix")
            return
        a, b = R.shape
        if a != b:
            print("R must be square")
            return
        # Against a target of 0.0, ``rtol`` has no effect; the default atol
        # (1e-8) is what actually applies here.
        if (not jnp.isclose(
                jnp.abs(jnp.eye(a) - jnp.dot(R, R.T)).max(), 0.0, rtol=0.5)
            ) or (not jnp.isclose(
                jnp.abs(jnp.eye(a) - jnp.dot(R.T, R)).max(), 0.0, rtol=0.5)):
            print("R is not orthogonal")

    if not pytree_shape_array_equal(src, dst):
        print("src and dst must be 1-dimensional arrays with the same shape.")
    x = pytree_normalized(src)
    y = pytree_normalized(dst)
    n = len(dst)

    # Angle between x and y in their spanning space. They are normalized, so
    # no denominator is needed. Clipping keeps rounding from pushing the dot
    # product past +/-1, where arccos returns NaN.
    theta = jnp.arccos(jnp.clip(jnp.dot(x, y), -1.0, 1.0))
    if jnp.isclose(theta, 0):
        print("x and y are co-linear")
        # x already equals y. The plane construction below would normalise
        # the zero vector y - (x.y) x, so return the identity directly.
        return jnp.eye(n)

    # construct the 2d rotation matrix connecting x to y in their spanning space
    R = jnp.array([[jnp.cos(theta), -jnp.sin(theta)],
                   [jnp.sin(theta), jnp.cos(theta)]])
    __assert_rotation(R)

    # Orthonormal basis (u, v) of Span<x, y>, with v the part of y orthogonal to x.
    u = x
    v = pytree_normalized(pytree_sub(y, (jnp.dot(u, y) * u)))
    P = jnp.outer(u, u) + jnp.outer(v, v)  # projection onto Span<x,y>
    Q = jnp.eye(n) - P  # projection onto the orthogonal complement of Span<x,y>

    # lift the rotation matrix into the n-dimensional space
    uv = jnp.hstack((u[:, None], v[:, None]))
    R = Q + jnp.dot(uv, jnp.dot(R, uv.T))
    __assert_rotation(R)
    if jnp.any(jnp.logical_not(jnp.isclose(jnp.dot(R, x), y, rtol=0.25))):
        print("Rotation matrix did not work")
    return R
def test_cumulative_logsumexp():
    """Compare ``cumulative_logsumexp`` with the naive log(cumsum(exp(a)))."""
    values = jnp.linspace(-1., 1., 100)
    expected = jnp.log(jnp.cumsum(jnp.exp(values)))
    actual = cumulative_logsumexp(values)
    print(expected)
    print(actual)
    assert jnp.all(jnp.isclose(expected, actual))
def from_pairs(cls, pairs, scale: Scale, normalized=False, interpolate=True):
    """Build an instance from records with ``"x"`` and ``"density"`` keys.

    Records are sorted by (x, density). Unless ``normalized`` is set, xs and
    densities are first normalised through ``scale``. When ``interpolate`` is
    set and the xs do not already match ``constants.target_xs`` (same length,
    rtol 1e-4), the densities are resampled onto ``constants.target_xs`` with
    ``interp1d``. Densities are then divided by their mean, so they average to
    1, and ``cls`` is built on ``constants.target_xs`` with ``normalized=True``.
    """
    ordered = sorted((record["x"], record["density"]) for record in pairs)
    xs = np.array([x for x, _ in ordered])
    densities = np.array([density for _, density in ordered])

    if not normalized:
        xs = scale.normalize_points(xs)
        densities = scale.normalize_densities(xs, densities)

    target_xs = constants.target_xs
    if interpolate:
        on_target_grid = (len(xs) == len(target_xs)
                          and np.isclose(xs, target_xs, rtol=1e-04).all())
        if not on_target_grid:
            densities = interp1d(xs, densities)(target_xs)

    # Make sure AUC is 1 (sum / size is the mean density over the grid).
    densities /= np.sum(densities) / densities.size
    return cls(target_xs, densities, scale=scale, normalized=True)
def isclose(x, y, atol=1e-8):
    """Element-wise closeness of two configurations, reduced with ``all()``.

    Accepted forms for ``x`` (``y`` must have the matching form):
      * ``np.ndarray``: compared directly;
      * ``list``: both sides are first passed through ``search_spaces.build``;
      * ``tuple`` whose first item is an ndarray or a list: only the first
        items are compared, using the two rules above.

    Raises:
        ValueError: for any other form of ``x``.
    """
    def all_close(a, b):
        return np.isclose(a, b, atol=atol).all()

    if isinstance(x, np.ndarray):
        return all_close(x, y)
    if isinstance(x, list):
        return all_close(search_spaces.build(x), search_spaces.build(y))
    if isinstance(x, tuple):
        head = x[0]
        if isinstance(head, np.ndarray):
            return all_close(head, y[0])
        if isinstance(head, list):
            return all_close(search_spaces.build(head), search_spaces.build(y[0]))
    raise ValueError('wrong format')
def conditional_mean(self, x, n):
    """Mixture-weighted mean of the component means, conditioned on ``x``.

    ``x`` is reshaped to ``(-1, self.idx)`` and passed to ``self.condition``,
    which yields weights ``pi``, means ``mu`` and a third value that is unused
    here. Returns ``sum_k pi_k * mu_k`` for each row. If ``pi`` does not sum to
    (approximately) 1, ``self._pi`` is used instead (presumably the
    unconditional weights -- confirm).

    ``n`` is accepted for interface compatibility but not used.
    """
    x = x.reshape(-1, self.idx)
    pi, mu, _var = self.condition(x, self.idx)
    weights = pi if np.isclose(np.sum(pi), 1.) else self._pi
    per_row_means = mu.reshape(x.shape[0], -1)
    return np.sum(weights * per_row_means, 1)
def test_shuffled_mask_sparsity_empty_twolayer(self):
    """Tests shuffled mask generation for two layers, for 0% sparsity."""
    mask = masked.shuffled_mask(self._masked_model_twolayer, self._rng, 0.0)

    # At 0% sparsity each layer has a mask entry whose kernel is all ones.
    layer_checks = (
        ('MaskedModule_0', 'shuffled_empty_mask_layer1',
         'shuffled_empty_mask_values_layer1'),
        ('MaskedModule_1', 'shuffled_empty_mask_layer2',
         'shuffled_empty_mask_values_layer2'),
    )
    for layer, presence_name, values_name in layer_checks:
        with self.subTest(name=presence_name):
            self.assertIn(layer, mask)
        with self.subTest(name=values_name):
            self.assertTrue((mask[layer]['kernel'] == 1).all())

    # An all-ones mask must reproduce the unmasked model's output.
    masked_output = self._masked_model_twolayer(self._input, mask=mask)
    expected = self._unmasked_output_twolayer
    with self.subTest(name='shuffled_empty_dense_values'):
        self.assertTrue(jnp.isclose(masked_output, expected).all())
    with self.subTest(name='shuffled_empty_mask_dense_shape'):
        self.assertSequenceEqual(masked_output.shape, expected.shape)
def test_ascend(random_tensors):
    """Check that ``simple_mera.ascend`` returns a Hermitian rank-6 tensor."""
    h, s, iso, dis = random_tensors
    h_up = simple_mera.ascend(h, s, iso, dis)
    assert len(h_up.shape) == 6
    side = h_up.shape[0]**3
    hmat = np.reshape(h_up, [side, side])
    deviation = hmat - np.conj(np.transpose(hmat))
    assert np.isclose(np.linalg.norm(deviation), 0.0)
def __init__(self, n: jnp.ndarray, p: jnp.ndarray):
    """Initializes a multinomial distribution with n trials and probabilities p.

    n may be multidimensional, in which case it represents multiple
    multinomial distributions. p has to have the shape of n plus 1 dimension
    representing the probabilities of each event. The probabilities in the
    last dimension have to sum to 1.

    Invalid values are not rejected with an exception. They are replaced by
    NaN instead:
      * every entry of ``p`` outside [0, 1] becomes NaN, and any probability
        vector (last axis) whose sum is not close to 1 becomes all-NaN;
      * every entry of ``n`` that is <= 0 becomes NaN.

    Args:
      n: Number of trials. Has to be an integer and positive; ``n <= 0`` is
        mapped to NaN. NOTE(review): this was previously documented as
        "non-negative" -- confirm whether ``n == 0`` should be accepted.
      p: Probabilities of trial successes. Must have same shape as n + 1
        additional dimension representing the probabilities. Probabilities
        have to sum to 1.

    Raises:
      ValueError: if the shapes of ``n`` and ``p`` are not compatible.
    """
    super().__init__()
    shapes_compatible = (len(n.shape) + 1 == len(p.shape)
                         and n.shape == p.shape[:len(n.shape)])
    if not shapes_compatible:
        raise ValueError('Shapes of n and p not compatible')

    # we cannot raise a ValueError here since we get problems with
    # ConcretizationError during Metropolis-Hastings
    p_nans = jnp.full(p.shape, jnp.nan)
    p_in_range = jnp.where(jnp.logical_or(p < 0, p > 1), p_nans, p)
    sums_to_one = jnp.isclose(jnp.sum(p_in_range, -1, keepdims=True), 1)
    self.p = jnp.where(sums_to_one, p_in_range, p_nans)

    self.n = jnp.where(n <= 0, jnp.full(n.shape, jnp.nan), n)
def test_simple():
    """Solve a one-parameter least-squares problem and check x reaches 5."""
    x = onp.array(0., dtype=np.float32)
    problem = minjax.Problem()
    problem.add_residual(l2_loss(simple_cost), x)
    problem.solve(verbose=True)
    print(x)
    # The assertion relies on the solver writing its solution back into ``x``.
    assert np.isclose(x, 5., atol=1e-2)
def test_compatibility():
    """Check elegy's BinaryCrossentropy against tf.keras' implementation.

    Covers probability and logit inputs, sample weighting, and the SUM and
    NONE reductions.

    NOTE(review): the Huber compatibility test in this file has the same
    name. If both live in one module, pytest only collects the later one.
    """
    y_true = jnp.array([[0.0, 1.0], [0.0, 0.0]])
    y_pred = jnp.array([[0.6, 0.4], [0.4, 0.6]])

    # Default reduction, predictions treated as probabilities.
    ours = elegy.losses.BinaryCrossentropy()
    reference = tfk.losses.BinaryCrossentropy()
    assert jnp.isclose(ours(y_true, y_pred), reference(y_true, y_pred),
                       rtol=0.0001)

    # Same predictions expressed as logits: log(p) - log(1 - p).
    y_logits = jnp.log(y_pred) - jnp.log(1 - y_pred)
    ours = elegy.losses.BinaryCrossentropy(from_logits=True)
    reference = tfk.losses.BinaryCrossentropy(from_logits=True)
    assert jnp.isclose(ours(y_true, y_logits), reference(y_true, y_logits),
                       rtol=0.0001)

    # Default reduction with per-sample weights.
    weights = jnp.array([1, 0])
    ours = elegy.losses.BinaryCrossentropy()
    reference = tfk.losses.BinaryCrossentropy()
    assert jnp.isclose(
        ours(y_true, y_pred, sample_weight=weights),
        reference(y_true, y_pred, sample_weight=weights),
        rtol=0.0001,
    )

    # SUM reduction.
    ours = elegy.losses.BinaryCrossentropy(reduction=elegy.losses.Reduction.SUM)
    reference = tfk.losses.BinaryCrossentropy(reduction=tfk.losses.Reduction.SUM)
    assert jnp.isclose(ours(y_true, y_pred), reference(y_true, y_pred),
                       rtol=0.0001)

    # NONE reduction: compare per-sample losses element-wise.
    ours = elegy.losses.BinaryCrossentropy(reduction=elegy.losses.Reduction.NONE)
    reference = tfk.losses.BinaryCrossentropy(reduction=tfk.losses.Reduction.NONE)
    assert jnp.all(
        jnp.isclose(ours(y_true, y_pred), reference(y_true, y_pred),
                    rtol=0.0001))
def test_sympy2jax(self):
    """Check that ``sympy2jax`` turns ``1.0*cos(x) + y`` into a matching JAX function."""
    x, y, z = sympy.symbols("x y z")
    expression = 1.0 * sympy.cos(x) + y
    X = random.normal(random.PRNGKey(0), (1000, 2))
    expected = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
    f, params = sympy2jax(expression, [x, y, z])
    matches = jnp.isclose(f(X, params), expected)
    self.assertTrue(jnp.all(matches).item())
def test_logaddexp():
    """Check ``logaddexp`` on real, sign-carrying (complex-log) and random inputs."""
    # log(1) (+) log(1) == log(2). Use isclose rather than ``==``: the two
    # sides are computed along different floating-point paths.
    a = jnp.log(1.)
    b = jnp.log(1.)
    assert jnp.isclose(logaddexp(a, b), jnp.log(2.))

    # Complex logs carry the sign: exp(result) should be 1 + (-2) = -1.
    a = jnp.log(1.)
    b = jnp.log(-2. + 0j)
    assert jnp.isclose(jnp.exp(logaddexp(a, b)).real, -1.)

    # -1 + 2 = 1.
    a = jnp.log(-1. + 0j)
    b = jnp.log(2. + 0j)
    assert jnp.isclose(jnp.exp(logaddexp(a, b)).real, 1.)

    # Random mixed-sign pairs drawn uniformly from [-10, 10).
    for i in range(100):
        u = random.uniform(random.PRNGKey(i), shape=(2,)) * 20. - 10.
        a = jnp.log(u[0] + 0j)
        b = jnp.log(u[1] + 0j)
        assert jnp.isclose(jnp.exp(logaddexp(a, b)).real, u[0] + u[1])
def check_saddle_point(step, y, prev_y, energy, prev_energy):
    """Return True once either convergence criterion holds, else False.

    Convergence is reached when one of these holds:
      * every row of ``y`` moved by less than 0.01 in its largest-magnitude
        component since ``prev_y`` (``np.max`` over axis 1, so ``y`` must be
        at least 2-D); or
      * every entry of ``energy`` is ``np.isclose`` to ``prev_energy``.

    The criterion that fired is logged at debug level together with ``step``.
    """
    largest_row_shift = np.max(np.abs(y - prev_y), axis=1)
    if np.all(largest_row_shift < 0.01):
        logger.debug(f'achieve tolerance with y_hat at {step}th step')
        return True
    if not np.all(np.isclose(energy, prev_energy)):
        return False
    logger.debug(f'achieve tolerance with energy at {step}th step')
    return True
def safe_norm(a, ord=2, axis=None):
    """``jnp.linalg.norm`` that stays differentiable on all-zero inputs.

    Where a slice of ``a`` is numerically all zeros, ``1e-5 ** ord`` is added
    to that slice's entries before taking the norm. A slice is taken along
    ``axis``; with ``axis=None`` it is the whole array. This keeps the result
    and its gradient finite: the gradient of e.g. the 2-norm at 0 is NaN.
    Slices with any non-zero entry are left untouched.

    Args:
        a: input array.
        ord: norm order, forwarded to ``jnp.linalg.norm``. It also sets the
            size of the perturbation (``1e-5 ** ord``).
        axis: axis (or axes) along which to take the norm; None for all.

    Returns:
        ``jnp.linalg.norm`` of the (possibly perturbed) input.
    """
    # A slice counts as zero only if *every* entry is close to 0. The old
    # check looked at the slice's sum instead. That misfired on cancelling
    # entries such as [1, -1], and with axis=None it flagged every array, so
    # even non-zero inputs were perturbed.
    is_zero = jnp.all(jnp.isclose(a, 0.), axis=axis, keepdims=True)
    eps = jnp.where(is_zero, jnp.full_like(a, 1e-5 ** ord), jnp.zeros_like(a))
    return jnp.linalg.norm(a + eps, ord=ord, axis=axis)
def test_fourier_approx():
    """Fit a periodic function on [-pi, pi] with two spectral models.

    ``SpectralGradientFit`` is fitted from the gradient alone, and
    ``SpectralSobolev1Fit`` from both values and gradient. Each must reproduce
    ``f`` and its derivative on the grid, and their coefficients must agree.
    """
    grid = Grid(lower=(-np.pi, ), upper=(np.pi, ), shape=(512, ), periodic=True)
    x_scaled = compute_mesh(grid)
    x = np.pi * x_scaled

    # Periodic function and its gradient, evaluated pointwise on the mesh.
    flat = x.flatten()
    y = jax.vmap(f)(flat).reshape(x.shape)
    dy = jax.vmap(jax.grad(f))(flat).reshape(x.shape)

    def fit_and_check(model_cls, *targets):
        # Fit ``model_cls`` to ``targets``; its values and gradients must match.
        model = model_cls(grid)
        fit = build_fitter(model)
        evaluate = build_evaluator(model)
        get_grad = build_grad_evaluator(model)
        fitted = fit(*targets)
        assert np.all(np.isclose(y, evaluate(fitted, x_scaled))).item()
        assert np.all(np.isclose(dy, get_grad(fitted, x_scaled))).item()
        return fitted

    fun = fit_and_check(SpectralGradientFit, dy)
    sfun = fit_and_check(SpectralSobolev1Fit, y, dy)
    assert np.linalg.norm(fun.coefficients - sfun.coefficients) < 1e-8
def _predict_transition_probabilities_jax(
    X: np.ndarray,
    W: np.ndarray,
    softmax_scale: float = 1.0,
    center_mean: bool = True,
    scale_by_norm: bool = True,
):
    """Softmax over similarity scores between ``X`` and each row of ``W``.

    Args:
        X: array scored against every row of ``W`` via ``W.dot(X)``
            (presumably 1-D -- confirm against callers).
        W: 2-D array whose rows are scored.
        softmax_scale: forwarded to the softmax helper.
        center_mean: subtract each row mean from ``W`` and the overall mean
            from ``X`` first, so the normalised score is a Pearson correlation
            rather than a cosine similarity. The inputs are not modified.
        scale_by_norm: divide the scores by ``norm(X) * norm(W[i])``. Rows
            whose denominator is close to 0 get denominator 1 and are passed
            as the mask to ``_softmax_masked_jax``.

    Returns:
        ``_softmax_masked_jax(...)`` if ``scale_by_norm``, else ``_softmax_jax(...)``.
    """
    if center_mean:
        # pearson correlation, otherwise cosine.
        # Out-of-place on purpose: ``-=`` would overwrite the caller's NumPy
        # arrays (and fail outright for integer dtypes).
        W = W - W.mean(axis=1)[:, None]
        X = X - X.mean()

    if scale_by_norm:
        denom = jnp.linalg.norm(X) * jnp.linalg.norm(W, axis=1)
        mask = jnp.isclose(denom, 0)
        denom = jnp.where(mask, 1, denom)  # essential: avoids division by zero
        return _softmax_masked_jax(W.dot(X) / denom, mask, softmax_scale)

    return _softmax_jax(W.dot(X), softmax_scale)
def test_max(data):
    """Max-semiring score must equal the score recomputed from its marginals."""
    model = data.draw(sampled_from([CKY_CRF, DepTree]))
    struct = model(MaxSemiring)
    vals, (_batch, N) = test_lookup[model]._rand()
    vals_jax = struct.resize(np.array(vals.numpy()))
    lengths = np.array([N] * vals_jax.shape[0])
    score = struct.sum(vals_jax, lengths)
    marginals = struct.marginals(vals_jax, lengths)
    print(marginals)
    rescored = struct.score(vals_jax, marginals)
    assert np.all(np.isclose(score, rescored))
def test_api(opt_id):
    """
    Description: verify that all default methods work for this optimizer
    """
    optimizer_class = tigercontrol.optimizer(opt_id)
    opt = optimizer_class()
    # Both dimension getters are exercised; only the action dim is used.
    _, m = opt.get_state_dim(), opt.get_action_dim()  # get dimensions of system
    x_0 = opt.reset()  # first state
    opt.step(np.zeros(m))  # try step with 0 action
    # reset must return back to the original state. np.allclose reduces to a
    # single bool; ``assert np.isclose(...)`` raises "truth value of an array
    # is ambiguous" for any multi-element state.
    assert np.allclose(x_0, opt.reset())