Example no. 1
def test_compatibility():
    # Input:  true (y_true) and predicted (y_pred) tensors
    rng = jax.random.PRNGKey(121)

    y_true = jax.random.randint(rng, shape=(2, 3), minval=0, maxval=2)
    y_true = y_true.astype(dtype=jnp.float32)
    y_pred = jax.random.uniform(rng, shape=(2, 3))

    # huber_loss using sample_weight
    huber_loss = elegy.losses.Huber(delta=1.0)
    huber_loss_tfk = tfk.losses.Huber(delta=1.0)

    assert jnp.isclose(
        huber_loss(y_true, y_pred, sample_weight=jnp.array([1, 0])),
        huber_loss_tfk(y_true, y_pred, sample_weight=jnp.array([1, 0])),
        rtol=0.0001,
    )

    # huber_loss with reduction method: SUM
    huber_loss = elegy.losses.Huber(delta=1.0, reduction=elegy.losses.Reduction.SUM)
    huber_loss_tfk = tfk.losses.Huber(delta=1.0, reduction=tfk.losses.Reduction.SUM)
    assert jnp.isclose(
        huber_loss(y_true, y_pred), huber_loss_tfk(y_true, y_pred), rtol=0.0001
    )

    # huber_loss with reduction method: NONE
    huber_loss = elegy.losses.Huber(delta=1.0, reduction=elegy.losses.Reduction.NONE)
    huber_loss_tfk = tfk.losses.Huber(delta=1.0, reduction=tfk.losses.Reduction.NONE)
    assert jnp.all(
        jnp.isclose(
            huber_loss(y_true, y_pred), huber_loss_tfk(y_true, y_pred), rtol=0.0001
        )
    )
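
These elegy-vs-Keras comparison tests are shown without their import preamble. A plausible shared preamble (an assumption, since the originals elide it) would be:

import jax
import jax.numpy as jnp
import elegy
from tensorflow import keras as tfk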
Example no. 2
def test_ellipsoid_params_update():
    # incrementally update the inverse scatter matrix and its determinant as points arrive
    N = 4
    points = random.normal(random.PRNGKey(43532), shape=(N, 3,))
    mu = jnp.mean(points, axis=0)
    C = jnp.linalg.inv(jnp.sum((points - mu)[:, :, None] * (points - mu)[:, None, :], axis=0))
    detC = jnp.linalg.det(C)
    n = N
    for i in range(100):
        x_n = random.normal(random.PRNGKey(i), shape=(3,))
        mu_next = mu + (x_n - mu) / (n + 1)
        C_next, detC_next = rank_one_update_matrix_inv(C,
                                                       detC,
                                                       x_n - mu,
                                                       x_n - mu_next,
                                                       add=True)
        n += 1
        points = jnp.concatenate([points, x_n[None, :]], axis=0)
        mu_com = jnp.mean(points, axis=0)
        C_com = jnp.linalg.inv(jnp.sum((points - mu_com)[:, :, None] * (points - mu_com)[:, None, :], axis=0))
        detC_com = jnp.linalg.det(C_com)
        assert jnp.isclose(detC_com, detC_next)
        assert jnp.isclose(mu_next, mu_com).all()
        assert jnp.isclose(C_next, C_com).all()
        mu, C, detC = mu_next, C_next, detC_next
        print(detC_next)
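
For context, the quantity under test is a rank-one update of an inverse matrix, which follows the Sherman-Morrison identity. A minimal self-contained sketch (an assumption, using plain jax.numpy rather than the library's rank_one_update_matrix_inv, which also tracks the determinant):

import jax.numpy as jnp

def sherman_morrison(Ainv, u, v):
    # (A + u v^T)^{-1} = A^{-1} - (A^{-1} u)(v^T A^{-1}) / (1 + v^T A^{-1} u)
    # the matching determinant update is det(A + u v^T) = det(A) * (1 + v^T A^{-1} u)
    Au = Ainv @ u
    vA = v @ Ainv
    return Ainv - jnp.outer(Au, vA) / (1.0 + v @ Au)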
Example no. 3
def test_basic():

    y_true = jnp.array([[0, 1, 0], [0, 0, 1]])
    y_pred = jnp.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])

    # Using 'auto'/'sum_over_batch_size' reduction type.
    cce = elegy.losses.CategoricalCrossentropy()
    result = cce(y_true, y_pred)  # 1.177
    assert jnp.isclose(result, 1.177, rtol=0.01)

    # Calling with 'sample_weight'.
    result = cce(y_true, y_pred, sample_weight=jnp.array([0.3, 0.7]))  # 0.814
    assert jnp.isclose(result, 0.814, rtol=0.01)

    # Using 'sum' reduction type.
    cce = elegy.losses.CategoricalCrossentropy(
        reduction=elegy.losses.Reduction.SUM)
    result = cce(y_true, y_pred)  # 2.354
    assert jnp.isclose(result, 2.354, rtol=0.01)

    # Using 'none' reduction type.
    cce = elegy.losses.CategoricalCrossentropy(
        reduction=elegy.losses.Reduction.NONE)
    result = cce(y_true, y_pred)  # [0.0513, 2.303]
    assert jnp.all(jnp.isclose(result, [0.0513, 2.303], rtol=0.01))
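
The reference numbers can be reproduced by hand from the categorical cross-entropy definition; a sketch reusing y_true and y_pred from the test above (assuming Keras-style clipping to avoid log(0)):

eps = 1e-7
per_sample = -jnp.sum(y_true * jnp.log(jnp.clip(y_pred, eps, 1 - eps)), axis=-1)
# per_sample        -> [0.0513, 2.303]  (the NONE reduction)
# per_sample.mean() -> 1.177            (the default 'sum_over_batch_size')
# per_sample.sum()  -> 2.354            (the SUM reduction)
# with sample_weight=[0.3, 0.7]: (0.3 * 0.0513 + 0.7 * 2.303) / 2 ~= 0.814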
Example no. 4
def test_fenics_vjp():
    numpy_output, fenics_output, fenics_inputs, tape = fem_eval(
        solve_fenics, templates, *inputs)
    g = np.ones_like(numpy_output)
    jax_grad_tuple = vjp_fem_eval_impl(g, fenics_output, fenics_inputs, tape)
    check1 = np.isclose(jax_grad_tuple[0], np.asarray(-2.91792642))
    check2 = np.isclose(jax_grad_tuple[1], np.asarray(2.43160535))
    assert check1 and check2
Example no. 5
def test_integral_generation_euler():
    tested_scheme = "euler"

    # Prepare values of means and second moments = E(XY)
    tested_integrals = [[1], [2], [0]]  # [dW1,dW2,dt]
    tested_labels = ["d_w1", "d_w2", "d_t"]
    target_means = jnp.array([
        pychastic.wiener_integral_moments.E(idx)(1) for idx in tested_integrals
    ])
    target_mean_products = jnp.array([
        pychastic.wiener_integral_moments.E2(idx, idy)(1)
        for (idx, idy) in itertools.product(tested_integrals, tested_integrals)
    ])

    samples_exponent = 14
    z_score_cutoff = 5

    seed = 0
    key = jax.random.PRNGKey(seed)
    sample_integrals = pychastic.vectorized_I_generation.get_wiener_integrals(
        key, scheme=tested_scheme, steps=2**samples_exponent, noise_terms=2)

    sample_integrals = jnp.array([
        sample_integrals["d_w"][:, 0],
        sample_integrals["d_w"][:, 1],
        jnp.ones_like(sample_integrals["d_w"][:, 0]),
    ]).T

    sample_means = jnp.mean(sample_integrals, axis=0)
    sample_mean_products = jnp.array([
        jnp.mean(x * y)
        for (x, y) in itertools.product(sample_integrals.T, sample_integrals.T)
    ])

    means_close = jnp.isclose(sample_means,
                              target_means,
                              atol=z_score_cutoff * 2**(-samples_exponent / 2))
    means_error = sample_means - target_means
    assert means_close.all(), "Expected values incorrect \n" + str({
        label: (bool(flag), float(error))
        for (flag, error,
             label) in zip(means_close, means_error, tested_labels)
    })

    products_close = jnp.isclose(
        sample_mean_products,
        target_mean_products,
        atol=z_score_cutoff * 2**(-samples_exponent / 2),
    )
    products_error = sample_mean_products - target_mean_products
    assert products_close.all(), "Expected products incorrect \n" + str({
        label: (bool(flag), float(error))
        for (flag, error, label) in zip(
            products_close,
            products_error,
            itertools.product(tested_labels, tested_labels),
        )
    })
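
For reference, the target moments being checked are the standard Wiener-increment moments for the Euler scheme (here with unit step, so \Delta t = 1):

\mathbb{E}[\Delta W_i] = 0, \qquad
\mathbb{E}[\Delta t] = \Delta t, \qquad
\mathbb{E}[\Delta W_i \, \Delta W_j] = \delta_{ij} \, \Delta t, \qquad
\mathbb{E}[\Delta W_i \, \Delta t] = 0 .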
Example no. 6
def stop(thresh, initial_state, state):
    clusters = state[-2]
    last_cluster = clusters[-2]

    return ~jnp.isclose(
        last_cluster, clusters, rtol=thresh
    ).any() & jnp.isclose(
        state[-1], initial_state[-1], rtol=thresh
    )
def test_energy(wavelet_tensors):
    h, iso, dis = wavelet_tensors
    s = np.reshape(np.eye(2**3) / 2**3, [2] * 6)
    for _ in range(20):
        s = simple_mera.descend(h, s, iso, dis)
    en = np.trace(np.reshape(s, [2**3, -1]) @ np.reshape(h, [2**3, -1]))
    assert np.isclose(en, -1.242, rtol=1e-3, atol=1e-3)
    en = simple_mera.binary_mera_energy(h, s, iso, dis)
    assert np.isclose(en, -1.242, rtol=1e-3, atol=1e-3)
def test_descend(random_tensors):
    h, s, iso, dis = random_tensors
    s = simple_mera.descend(h, s, iso, dis)
    assert len(s.shape) == 6
    D = s.shape[0]
    smat = np.reshape(s, [D**3] * 2)
    assert np.isclose(np.trace(smat), 1.0)
    assert np.isclose(np.linalg.norm(smat - np.conj(np.transpose(smat))), 0.0)
    spec, _ = np.linalg.eigh(smat)
    assert np.all(spec >= 0.0)
Example no. 9
def test_squared_norm():
    x = jnp.linspace(0., 1., 100)[:, None]
    y = jnp.linspace(1., 2., 50)[:, None]
    assert jnp.all(
        jnp.isclose(
            squared_norm(x, x),
            jnp.sum(jnp.square(x[:, None, :] - x[None, :, :]), axis=-1)))
    assert jnp.all(
        jnp.isclose(
            squared_norm(x, y),
            jnp.sum(jnp.square(x[:, None, :] - y[None, :, :]), axis=-1)))
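
A minimal sketch of a squared_norm consistent with this test (an assumption; real implementations typically use the expanded form below because it avoids materializing the full difference tensor):

def squared_norm_ref(x, y):
    # pairwise squared distances: ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 x_i . y_j
    x2 = jnp.sum(x ** 2, axis=-1)[:, None]
    y2 = jnp.sum(y ** 2, axis=-1)[None, :]
    return x2 + y2 - 2.0 * x @ y.T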
Example no. 10
def apply_bond_charge_corrections(initial_charges, bond_idxs, deltas):
    """For an arbitrary collection of ordered bonds and associated increments `(a, b, delta)`,
    update `charges` by `charges[a] += delta`, `charges[b] -= delta`

    Notes
    -----
    * preserves sum(initial_charges) for arbitrary values of bond_idxs or deltas
    * order within each row of bond_idxs is meaningful
        `(..., bond_idxs, deltas)`
        means the opposite of
        `(..., bond_idxs[:, ::-1], deltas)`
    * order within the first axis of bond_idxs, deltas is not meaningful
        `(..., bond_idxs[perm], deltas[perm])`
        means the same thing for any permutation `perm`
    """

    # apply bond charge corrections
    # (ops.index_add is the legacy jax.ops API; the modern equivalent is
    # initial_charges.at[bond_idxs[:, 0]].add(deltas))
    incremented = ops.index_add(initial_charges, bond_idxs[:, 0], +deltas)
    decremented = ops.index_add(incremented, bond_idxs[:, 1], -deltas)
    final_charges = decremented

    # make some safety assertions
    assert bond_idxs.shape[1] == 2
    assert len(deltas) == len(bond_idxs)

    net_charge = jnp.sum(initial_charges)
    net_charge_is_integral = jnp.isclose(net_charge,
                                         jnp.round(net_charge),
                                         atol=1e-5)

    final_net_charge = jnp.sum(final_charges)
    net_charge_is_unchanged = jnp.isclose(final_net_charge,
                                          net_charge,
                                          atol=1e-5)

    assert net_charge_is_integral
    assert net_charge_is_unchanged

    # print some safety warnings
    directed_bonds = Counter([tuple(b) for b in bond_idxs])
    undirected_bonds = Counter([tuple(sorted(b)) for b in bond_idxs])

    if max(directed_bonds.values()) > 1:
        duplicates = [
            bond for (bond, count) in directed_bonds.items() if count > 1
        ]
        print(UserWarning(f"Duplicate directed bonds! {duplicates}"))
    elif max(undirected_bonds.values()) > 1:
        duplicates = [
            bond for (bond, count) in undirected_bonds.items() if count > 1
        ]
        print(UserWarning(f"Duplicate undirected bonds! {duplicates}"))

    return final_charges
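
A hypothetical toy call (values invented for illustration):

import numpy as onp

initial_charges = jnp.array([0.1, -0.3, 0.2, 0.0])  # net charge 0.0, i.e. integral
bond_idxs = onp.array([[0, 1], [2, 3]])             # plain numpy rows hash cleanly in Counter
deltas = jnp.array([0.05, -0.02])
final_charges = apply_bond_charge_corrections(initial_charges, bond_idxs, deltas)
# jnp.sum(final_charges) equals jnp.sum(initial_charges) by construction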
Example no. 11
    def __assert_rotation(R):
        if R.ndim != 2:
            print("R must be a matrix")
        a, b = R.shape
        if a != b:
            print("R must be square")
        if (not jnp.isclose(
                jnp.abs(jnp.eye(a) - jnp.dot(R, R.T)).max(), 0.0, rtol=0.5)
            ) or (not jnp.isclose(
                jnp.abs(jnp.eye(a) - jnp.dot(R.T, R)).max(), 0.0, rtol=0.5)):
            print("R is not orthogonal")
Example no. 12
    def test_function(self):

        y_true = jnp.array([[1.0, 1.0], [0.9, 0.0]])
        y_pred = jnp.array([[1.0, 1.0], [1.0, 0.0]])

        ## Standard MAPE
        mape_elegy = elegy.losses.MeanAbsolutePercentageError()
        mape_tfk = tfk.losses.MeanAbsolutePercentageError()
        assert jnp.isclose(mape_elegy(y_true, y_pred),
                           mape_tfk(y_true, y_pred),
                           rtol=0.0001)

        ## MAPE using sample_weight
        assert jnp.isclose(
            mape_elegy(y_true, y_pred, sample_weight=jnp.array([1, 0])),
            mape_tfk(y_true, y_pred, sample_weight=jnp.array([1, 0])),
            rtol=0.0001,
        )

        ## MAPE with reduction method: SUM
        mape_elegy = elegy.losses.MeanAbsolutePercentageError(
            reduction=elegy.losses.Reduction.SUM)
        mape_tfk = tfk.losses.MeanAbsolutePercentageError(
            reduction=tfk.losses.Reduction.SUM)
        assert jnp.isclose(mape_elegy(y_true, y_pred),
                           mape_tfk(y_true, y_pred),
                           rtol=0.0001)

        ## MAPE with reduction method: NONE
        mape_elegy = elegy.losses.MeanAbsolutePercentageError(
            reduction=elegy.losses.Reduction.NONE)
        mape_tfk = tfk.losses.MeanAbsolutePercentageError(
            reduction=tfk.losses.Reduction.NONE)
        assert jnp.all(
            jnp.isclose(mape_elegy(y_true, y_pred),
                        mape_tfk(y_true, y_pred),
                        rtol=0.0001))

        ## Test the loss function directly
        rng = jax.random.PRNGKey(42)
        y_true = jax.random.randint(rng, shape=(2, 3), minval=0, maxval=2)
        y_pred = jax.random.uniform(rng, shape=(2, 3))
        y_true = y_true.astype(y_pred.dtype)
        loss = elegy.losses.mean_percentage_absolute_error(y_true, y_pred)
        assert loss.shape == (2, )
        assert jnp.array_equal(
            loss,
            100 * jnp.mean(
                jnp.abs((y_pred - y_true) /
                        jnp.maximum(jnp.abs(y_true), utils.EPSILON)),
                axis=-1,
            ),
        )
Example no. 13
def get_rotation_pytree(src: Any, dst: Any) -> Any:
    """
    Takes two n-dimensional vectors/pytrees and returns an
    nxn rotation matrix mapping src to dst.
    Raises ValueError when unsuccessful.
    """
    def __assert_rotation(R):
        if R.ndim != 2:
            print("R must be a matrix")
        a, b = R.shape
        if a != b:
            print("R must be square")
        if (not jnp.isclose(
                jnp.abs(jnp.eye(a) - jnp.dot(R, R.T)).max(), 0.0, rtol=0.5)
            ) or (not jnp.isclose(
                jnp.abs(jnp.eye(a) - jnp.dot(R.T, R)).max(), 0.0, rtol=0.5)):
            print("R is not diagonal")

    if not pytree_shape_array_equal(src, dst):
        print("cjax and dst must be 1-dimensional arrays with the same shape.")

    x = pytree_normalized(src)
    y = pytree_normalized(dst)
    n = len(dst)

    # compute angle between x and y in their spanning space
    theta = jnp.arccos(jnp.dot(
        x, y))  # they are normalized so there is no denominator
    if jnp.isclose(theta, 0):
        print("x and y are co-linear")
    # construct the 2d rotation matrix connecting x to y in their spanning space
    R = jnp.array([[jnp.cos(theta), -jnp.sin(theta)],
                   [jnp.sin(theta), jnp.cos(theta)]])
    __assert_rotation(R)
    # get projections onto Span<x,y> and its orthogonal complement
    u = x
    v = pytree_normalized(pytree_sub(y, (jnp.dot(u, y) * u)))
    P = jnp.outer(u, u.T) + jnp.outer(
        v, v.T)  # projection onto 2d space spanned by x and y
    Q = jnp.eye(
        n) - P  # projection onto the orthogonal complement of Span<x,y>
    # lift the rotation matrix into the n-dimensional space
    uv = jnp.hstack((u[:, None], v[:, None]))

    R = Q + jnp.dot(uv, jnp.dot(R, uv.T))
    __assert_rotation(R)
    if jnp.any(jnp.logical_not(jnp.isclose(jnp.dot(R, x), y, rtol=0.25))):
        print("Rotation matrix did not work")
    return R
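
The same construction in plain-array form may make the lifting step clearer; a minimal sketch assuming unit-norm, non-collinear 1-D inputs:

import jax.numpy as jnp

def rotation_between(x, y):
    theta = jnp.arccos(jnp.clip(jnp.dot(x, y), -1.0, 1.0))
    R2 = jnp.array([[jnp.cos(theta), -jnp.sin(theta)],
                    [jnp.sin(theta),  jnp.cos(theta)]])
    u = x
    v = y - jnp.dot(u, y) * u
    v = v / jnp.linalg.norm(v)                 # orthonormal basis of span{x, y}
    uv = jnp.stack([u, v], axis=1)             # n x 2
    Q = jnp.eye(x.size) - uv @ uv.T            # projector onto the complement
    return Q + uv @ R2 @ uv.T                  # rotate in-plane, identity elsewhere

R = rotation_between(jnp.array([1.0, 0.0, 0.0]), jnp.array([0.0, 1.0, 0.0]))
assert jnp.allclose(R @ jnp.array([1.0, 0.0, 0.0]), jnp.array([0.0, 1.0, 0.0]), atol=1e-6)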
Example no. 14
def test_cumulative_logsumexp():
    a = jnp.linspace(-1.,1.,100)
    v1 = jnp.log(jnp.cumsum(jnp.exp(a)))
    v2 = cumulative_logsumexp(a)
    print(v1)
    print(v2)
    assert jnp.isclose(v1,v2).all()
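
A numerically stable reference for comparison (a sketch; the assumption is that cumulative_logsumexp applies the usual max-shift trick):

def cumulative_logsumexp_ref(a):
    a_max = jnp.max(a)  # shift by the global max so exp() cannot overflow
    return jnp.log(jnp.cumsum(jnp.exp(a - a_max))) + a_max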
Example no. 15
    def from_pairs(cls,
                   pairs,
                   scale: Scale,
                   normalized=False,
                   interpolate=True):
        sorted_pairs = sorted([(v["x"], v["density"]) for v in pairs])
        xs = np.array([x for (x, density) in sorted_pairs])
        densities = np.array([density for (x, density) in sorted_pairs])

        if not normalized:
            xs = scale.normalize_points(xs)
            densities = scale.normalize_densities(xs, densities)

        if interpolate:
            # interpolate ps at target_xs
            if not (len(xs) == len(constants.target_xs)
                    and np.isclose(xs, constants.target_xs, rtol=1e-04).all()):
                f = interp1d(xs, densities)
                densities = f(constants.target_xs)

        # Make sure AUC is 1
        auc = np.sum(densities) / densities.size
        densities /= auc

        return cls(constants.target_xs,
                   densities,
                   scale=scale,
                   normalized=True)
Example no. 16
def isclose(x, y, atol=1e-8):
    if isinstance(x, np.ndarray):
        return np.isclose(x, y, atol=atol).all()
    elif isinstance(x, list):
        # return all(np.isclose(x[0], y[0], atol=1e-03).all() for i in range(len(x)))
        return np.isclose(search_spaces.build(x),
                          search_spaces.build(y),
                          atol=atol).all()
    elif isinstance(x, tuple) and isinstance(x[0], np.ndarray):
        return np.isclose(x[0], y[0], atol=atol).all()
    elif isinstance(x, tuple) and isinstance(x[0], list):
        return np.isclose(search_spaces.build(x[0]),
                          search_spaces.build(y[0]),
                          atol=atol).all()
    else:
        raise ValueError('wrong format')
    def conditional_mean(self, x, n):
        x = x.reshape(-1, self.idx)
        pi, mu, var = self.condition(x, self.idx)
        if not np.isclose(np.sum(pi), 1.):
            pi = self._pi
        weighted_mean = np.sum(pi * mu.reshape(x.shape[0], -1), 1)
        return weighted_mean
Example no. 18
    def test_shuffled_mask_sparsity_empty_twolayer(self):
        """Tests shuffled mask generation for two layers, for 0% sparsity."""
        mask = masked.shuffled_mask(self._masked_model_twolayer, self._rng,
                                    0.0)

        with self.subTest(name='shuffled_empty_mask_layer1'):
            self.assertIn('MaskedModule_0', mask)

        with self.subTest(name='shuffled_empty_mask_values_layer1'):
            self.assertTrue((mask['MaskedModule_0']['kernel'] == 1).all())

        with self.subTest(name='shuffled_empty_mask_layer2'):
            self.assertIn('MaskedModule_1', mask)

        with self.subTest(name='shuffled_empty_mask_values_layer2'):
            self.assertTrue((mask['MaskedModule_1']['kernel'] == 1).all())

        masked_output = self._masked_model_twolayer(self._input, mask=mask)

        with self.subTest(name='shuffled_empty_dense_values'):
            self.assertTrue(
                jnp.isclose(masked_output,
                            self._unmasked_output_twolayer).all())

        with self.subTest(name='shuffled_empty_mask_dense_shape'):
            self.assertSequenceEqual(masked_output.shape,
                                     self._unmasked_output_twolayer.shape)
def test_ascend(random_tensors):
    h, s, iso, dis = random_tensors
    h = simple_mera.ascend(h, s, iso, dis)
    assert len(h.shape) == 6
    D = h.shape[0]
    hmat = np.reshape(h, [D**3] * 2)
    assert np.isclose(np.linalg.norm(hmat - np.conj(np.transpose(hmat))), 0.0)
Example no. 20
    def __init__(self, n: jnp.ndarray, p: jnp.ndarray):
        """Initializes a multinomial distribution with n trials and probabilities p.

        n may be multidimensional, in which case it represents
        multiple multinomial distributions.

        p has to have the shape of n plus 1 dimension representing the
        probabilities of each event. The probabilities in the last
        dimension have to sum to 1.

        Args:
            n: Number of trials. Has to be an integer and non-negative.
            p: Probabilities of trial successes. Must have same shape
                as n + 1 additional dimension representing the probabilities.
                Probabilities have to sum to 1.
        """
        super().__init__()

        if n.shape != p.shape[:len(n.shape)] or \
                len(n.shape) + 1 != len(p.shape):
            raise ValueError('Shapes of n and p not compatible')

        # we cannot raise a ValueError here since we get problems with
        # ConcretizationError during Metropolis-Hastings
        nans = jnp.full(p.shape, jnp.nan)
        self.p = jnp.where(jnp.logical_or(p < 0, p > 1), nans, p)
        self.p = jnp.where(jnp.isclose(jnp.sum(self.p, -1, keepdims=True), 1),
                           self.p, nans)

        nans = jnp.full(n.shape, jnp.nan)
        self.n = jnp.where(n <= 0, nans, n)
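
A hypothetical instantiation satisfying the shape contract (the class name Multinomial is assumed; the snippet only shows __init__):

n = jnp.array([10, 20])        # two multinomial distributions
p = jnp.array([[0.2, 0.8],     # probabilities sum to 1 along the last axis
               [0.5, 0.5]])
dist = Multinomial(n=n, p=p)   # p.shape == n.shape + (num_events,)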
Example no. 21
def test_simple():
    p = minjax.Problem()
    x = onp.array(0., dtype=np.float32)
    p.add_residual(l2_loss(simple_cost), x)
    p.solve(verbose=True)
    print(x)
    assert np.isclose(x, 5., atol=1e-2)
Example no. 22
def test_compatibility():

    # Input:  true (y_true) and predicted (y_pred) tensors
    y_true = jnp.array([[0.0, 1.0], [0.0, 0.0]])
    y_pred = jnp.array([[0.6, 0.4], [0.4, 0.6]])

    # Standard BCE, considering prediction tensor as probabilities
    bce_elegy = elegy.losses.BinaryCrossentropy()
    bce_tfk = tfk.losses.BinaryCrossentropy()
    assert jnp.isclose(bce_elegy(y_true, y_pred),
                       bce_tfk(y_true, y_pred),
                       rtol=0.0001)

    # Standard BCE, considering prediction tensor as logits
    y_logits = jnp.log(y_pred) - jnp.log(1 - y_pred)
    bce_elegy = elegy.losses.BinaryCrossentropy(from_logits=True)
    bce_tfk = tfk.losses.BinaryCrossentropy(from_logits=True)
    assert jnp.isclose(bce_elegy(y_true, y_logits),
                       bce_tfk(y_true, y_logits),
                       rtol=0.0001)

    # BCE using sample_weight
    bce_elegy = elegy.losses.BinaryCrossentropy()
    bce_tfk = tfk.losses.BinaryCrossentropy()
    assert jnp.isclose(
        bce_elegy(y_true, y_pred, sample_weight=jnp.array([1, 0])),
        bce_tfk(y_true, y_pred, sample_weight=jnp.array([1, 0])),
        rtol=0.0001,
    )

    # BCE with reduction method: SUM
    bce_elegy = elegy.losses.BinaryCrossentropy(
        reduction=elegy.losses.Reduction.SUM)
    bce_tfk = tfk.losses.BinaryCrossentropy(reduction=tfk.losses.Reduction.SUM)
    assert jnp.isclose(bce_elegy(y_true, y_pred),
                       bce_tfk(y_true, y_pred),
                       rtol=0.0001)

    # BCE with reduction method: NONE
    bce_elegy = elegy.losses.BinaryCrossentropy(
        reduction=elegy.losses.Reduction.NONE)
    bce_tfk = tfk.losses.BinaryCrossentropy(
        reduction=tfk.losses.Reduction.NONE)
    assert jnp.all(
        jnp.isclose(bce_elegy(y_true, y_pred),
                    bce_tfk(y_true, y_pred),
                    rtol=0.0001))
Example no. 23
    def test_sympy2jax(self):
        x, y, z = sympy.symbols("x y z")
        cosx = 1.0 * sympy.cos(x) + y
        key = random.PRNGKey(0)
        X = random.normal(key, (1000, 2))
        true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
        f, params = sympy2jax(cosx, [x, y, z])
        self.assertTrue(jnp.all(jnp.isclose(f(X, params), true)).item())
Example no. 24
def test_logaddexp():
    a = jnp.log(1.)
    b = jnp.log(1.)
    assert logaddexp(a,b) == jnp.log(2.)
    a = jnp.log(1.)
    b = jnp.log(-2.+0j)
    assert jnp.isclose(jnp.exp(logaddexp(a, b)).real, -1.)

    a = jnp.log(-1.+0j)
    b = jnp.log(2. + 0j)
    assert jnp.isclose(jnp.exp(logaddexp(a, b)).real, 1.)

    for i in range(100):
        u = random.uniform(random.PRNGKey(i),shape=(2,))*20. - 10.
        a = jnp.log(u[0] + 0j)
        b = jnp.log(u[1] + 0j)
        assert jnp.isclose(jnp.exp(logaddexp(a,b)).real, u[0] + u[1])
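
A sketch of a complex-capable logaddexp consistent with these checks (an assumption, not the library's implementation; pivoting on the argument with the larger real part keeps exp() from overflowing):

def logaddexp_ref(a, b):
    a, b = jnp.asarray(a), jnp.asarray(b)
    hi = jnp.where(a.real >= b.real, a, b)
    lo = jnp.where(a.real >= b.real, b, a)
    return hi + jnp.log(1.0 + jnp.exp(lo - hi))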
Example no. 25
def check_saddle_point(step, y, prev_y, energy, prev_energy):
    if np.all(np.max(np.abs(y - prev_y), axis=1) < 0.01):
        logger.debug(f'achieved tolerance on y_hat at step {step}')
        return True
    if np.all(np.isclose(energy, prev_energy)):
        logger.debug(f'achieved tolerance on energy at step {step}')
        return True
    return False
Example no. 26
def safe_norm(a, ord=2, axis=None):
    # add a tiny epsilon wherever the input (sum along `axis`) is close to
    # zero, so that the norm and its gradient stay finite at the origin
    if axis is not None:
        is_zero = jnp.expand_dims(jnp.isclose(jnp.sum(a, axis=axis), 0.), axis=axis)
    else:
        is_zero = jnp.ones_like(a, dtype='bool')  # no axis given: always apply the shift
    norm = jnp.linalg.norm(a + jnp.where(is_zero, jnp.ones_like(a) * 1e-5 ** ord, jnp.zeros_like(a)),
                           ord=ord, axis=axis)
    return norm
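
A quick illustration of what the epsilon shift buys (a sketch; assumes jax is imported and the safe_norm above is in scope):

grad_plain = jax.grad(lambda v: jnp.linalg.norm(v))(jnp.zeros(3))   # nan: d||v||/dv = v/||v|| is 0/0 at zero
grad_safe = jax.grad(lambda v: safe_norm(v, axis=0))(jnp.zeros(3))  # finite, thanks to the shift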
Example no. 27
def test_fourier_approx():
    grid = Grid(lower=(-np.pi, ),
                upper=(np.pi, ),
                shape=(512, ),
                periodic=True)

    x_scaled = compute_mesh(grid)
    x = np.pi * x_scaled

    # Periodic function and its gradient
    y = jax.vmap(f)(x.flatten()).reshape(x.shape)
    dy = jax.vmap(jax.grad(f))(x.flatten()).reshape(x.shape)

    model = SpectralGradientFit(grid)
    fit = build_fitter(model)
    evaluate = build_evaluator(model)
    get_grad = build_grad_evaluator(model)

    fun = fit(dy)

    assert np.all(np.isclose(y, evaluate(fun, x_scaled))).item()
    assert np.all(np.isclose(dy, get_grad(fun, x_scaled))).item()

    # fig, ax = plt.subplots()
    # ax.plot(x, dy)
    # ax.plot(x, get_grad(fun, x_scaled))
    # plt.show()

    # fig, ax = plt.subplots()
    # ax.plot(x, y)
    # ax.plot(x, evaluate(fun, x_scaled))
    # plt.show()

    model = SpectralSobolev1Fit(grid)
    fit = build_fitter(model)
    evaluate = build_evaluator(model)
    get_grad = build_grad_evaluator(model)

    sfun = fit(y, dy)

    assert np.all(np.isclose(y, evaluate(sfun, x_scaled))).item()
    assert np.all(np.isclose(dy, get_grad(sfun, x_scaled))).item()

    assert np.linalg.norm(fun.coefficients - sfun.coefficients) < 1e-8
Example no. 28
    def _predict_transition_probabilities_jax(
        X: np.ndarray,
        W: np.ndarray,
        softmax_scale: float = 1.0,
        center_mean: bool = True,
        scale_by_norm: bool = True,
    ):
        if center_mean:
            # pearson correlation, otherwise cosine
            W -= W.mean(axis=1)[:, None]
            X -= X.mean()

        if scale_by_norm:
            denom = jnp.linalg.norm(X) * jnp.linalg.norm(W, axis=1)
            mask = jnp.isclose(denom, 0)
            denom = jnp.where(jnp.isclose(denom, 0), 1, denom)  # essential: avoids division by zero
            return _softmax_masked_jax(W.dot(X) / denom, mask, softmax_scale)

        return _softmax_jax(W.dot(X), softmax_scale)
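
A plausible reconstruction of the masked-softmax helper (hypothetical; the real _softmax_masked_jax lives elsewhere in the codebase):

def _softmax_masked_jax(x, mask, scale=1.0):
    x = jnp.where(mask, -jnp.inf, x * scale)  # masked entries get zero weight
    x = x - jnp.max(x)                        # max-shift for numerical stability
    e = jnp.exp(x)
    return e / jnp.sum(e)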
Example no. 29
def test_max(data):
    model = data.draw(sampled_from([CKY_CRF, DepTree]))
    struct = model(MaxSemiring)
    vals, (batch, N) = test_lookup[model]._rand()
    vals_jax = struct.resize(np.array(vals.numpy()))
    Ns = np.array([N] * vals_jax.shape[0])
    score = struct.sum(vals_jax, Ns)
    marginals = struct.marginals(vals_jax, Ns)
    print(marginals)
    assert np.isclose(score, struct.score(vals_jax, marginals)).all()
Example no. 30
def test_api(opt_id):
    """ Description: verify that all default methods work for this optimizer """
    optimizer_class = tigercontrol.optimizer(opt_id)
    opt = optimizer_class()
    n, m = opt.get_state_dim(), opt.get_action_dim()  # get dimensions of system
    x_0 = opt.reset()  # first state
    x_1 = opt.step(np.zeros(m))  # try step with 0 action
    assert np.isclose(
        x_0, opt.reset())  # assert that reset returns to the original state