Exemplo n.º 1
0
 def f1(x):
     """Apply ``np.sin`` three times in succession to ``x``."""
     result = x
     for _ in range(3):
         result = np.sin(result)
     return result
Exemplo n.º 2
0
 def f(x):
     """Run a ``lax.while_loop`` on ``(x, x)`` and return its second slot plus one.

     While the first slot is negative, each iteration maps ``(a, b)`` to
     ``(sin(a), cos(b))``. For ``x >= 0`` the body never runs and the
     result is ``x + 1.``.
     NOTE(review): for ``x`` in ``(-pi, 0)`` repeated ``sin`` keeps the first
     slot negative, so the loop may never terminate -- confirm callers only
     pass non-negative values.
     """
     def cond_fun(state):
         return state[0] < 0.

     def body_fun(state):
         first, second = state
         return jnp.sin(first), jnp.cos(second)

     _, y = lax.while_loop(cond_fun, body_fun, (x, x))
     return y + 1.
Exemplo n.º 3
0
def simple_fun_fanout(x, y):
    """Return ``sin(x * y) * x``; ``x`` fans out to two uses."""
    product = x * y
    return jnp.sin(product) * x
Exemplo n.º 4
0
 def test_check_jaxpr_correct(self):
     """``core.check_jaxpr`` accepts the jaxpr traced from ``sin(x) + cos(x)``."""
     closed = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.)
     core.check_jaxpr(closed.jaxpr)
Exemplo n.º 5
0
 def f(x):
     """Return ``sin(x) + cos(x)`` elementwise."""
     sine = jnp.sin(x)
     cosine = jnp.cos(x)
     return sine + cosine
Exemplo n.º 6
0
 def test_nested_jit(self):
   """A jitted ``cos`` nested inside a jitted ``sin`` converts to TF unchanged."""
   inner = jax.jit(jnp.cos)
   f_jax = jax.jit(lambda x: jnp.sin(inner(x)))
   f_tf = jax2tf.convert(f_jax)
   np.testing.assert_allclose(f_jax(0.7), f_tf(0.7))
Exemplo n.º 7
0
def product_io_fun(x, y):
    """Map nested pytree inputs to a nested pytree output.

    Args:
      x: mapping providing keys ``'a'`` and ``'b'``.
      y: nested pair ``(y1, (y2, y3))``.

    Returns:
      ``(sin(x['a'] + y2), [x['b'], (y1, y3)])``.
    """
    y1, (y2, y3) = y
    return jnp.sin(x['a'] + y2), [x['b'], (y1, y3)]
Exemplo n.º 8
0
 def foo(x):
     """Return the elementwise sine of ``x``."""
     result = np.sin(x)
     return result
Exemplo n.º 9
0
 def foo(x):
     """Return ``sin(2 * x)`` elementwise."""
     doubled = 2. * x
     return np.sin(doubled)
Exemplo n.º 10
0
 def test_complex_output_jacrev_raises_error(self):
     """``jacrev`` of ``sin`` evaluated at a complex point raises TypeError."""
     def call_jacrev():
         return jacrev(lambda x: np.sin(x))(1 + 2j)

     self.assertRaises(TypeError, call_jacrev)
Exemplo n.º 11
0
 def test_complex_input_jacfwd_raises_error(self):
     """``jacfwd`` of ``sin`` evaluated at a complex point raises TypeError."""
     def call_jacfwd():
         return jacfwd(lambda x: np.sin(x))(1 + 2j)

     self.assertRaises(TypeError, call_jacfwd)
Exemplo n.º 12
0
 def test_holomorphic_grad(self):
     """Holomorphic ``grad`` of ``sin`` at ``1 + 2j`` matches the constant below."""
     grad_sin = grad(lambda x: np.sin(x), holomorphic=True)
     # Numerically equal to cos(1 + 2j), the derivative of sin at that point.
     expected = 2.0327230070196656 - 3.0518977991518j
     self.assertAllClose(grad_sin(1 + 2j), expected, check_dtypes=False)
Exemplo n.º 13
0
 def test_complex_grad_raises_error(self):
     """Non-holomorphic ``grad`` rejects a complex input with TypeError."""
     def differentiate_at_complex_point():
         return grad(lambda x: np.sin(x))(1 + 2j)

     self.assertRaises(TypeError, differentiate_at_complex_point)
Exemplo n.º 14
0
 def f1(x, y):
     """Return ``sin(x) * cos(y) * sin(x) * cos(y)``.

     Each factor is recomputed and multiplied left-to-right in the same order
     as the one-line product, so a traced version emits the same operations.
     """
     result = np.sin(x) * np.cos(y)
     result = result * np.sin(x)
     return result * np.cos(y)
Exemplo n.º 15
0
 def test_variable_input(self):
   """The converted function accepts a ``tf.Variable`` and returns a Tensor."""
   f_jax = lambda x: jnp.sin(jnp.cos(x))
   f_tf = jax2tf.convert(f_jax)
   var_dtype = dtypes.canonicalize_dtype(jnp.float_)
   var = tf.Variable(0.7, dtype=var_dtype)
   self.assertIsInstance(f_tf(var), tf.Tensor)
   self.assertAllClose(f_jax(0.7), f_tf(var))
Exemplo n.º 16
0
 def foo(x, y):
     """Return ``sin(x * y)`` elementwise."""
     product = x * y
     return np.sin(product)
Exemplo n.º 17
0
 def test_jit(self):
   """A jitted ``sin(cos(x))`` converts and compares equal at 0.7."""
   arg = jnp.float_(0.7)
   f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
   self.ConvertAndCompare(f_jax, arg)
Exemplo n.º 18
0
 def h(x):
     """Return ``sin(cos(x))`` elementwise."""
     inner = np.cos(x)
     return np.sin(inner)
Exemplo n.º 19
0
 def foo(y, z):
     """Return ``(x * x) * sin(y) * z``.

     ``x`` is a free variable, presumably bound in an enclosing scope that is
     not shown here.
     """
     x_squared = x * x
     return x_squared * jnp.sin(y) * z
Exemplo n.º 20
0
 def pol2cart(theta, rho):
     """Convert polar coordinates ``(theta, rho)`` to Cartesian points.

     Each product is flattened into a column, and the result stacks the
     columns ``[rho * cos(theta), rho * sin(theta)]`` side by side.
     """
     columns = [
         (rho * trig(theta)).reshape(-1, 1) for trig in (np.cos, np.sin)
     ]
     return np.concatenate(columns, axis=1)
Exemplo n.º 21
0
 def bar(y):
     """Return ``sin(x * y)``.

     ``x`` is a free variable, presumably bound in an enclosing scope that is
     not shown here.
     """
     scaled = x * y
     return jnp.sin(scaled)
Exemplo n.º 22
0
def test_Cdo_timeseries(plot=False):
    """Fit conditional embeddings to a noisy sine series and return the estimates.

    A noisy sine ``y`` is sampled at 400 points on ``[0, 40]``. Two ``Cmo``
    estimators are built:

    * window-based: maps each window of 10 consecutive values to the value
      that follows it, and is queried for every end index in ``range(200, 400)``
      (result ``sol2``);
    * combined-input: pairs each of the first 200 time points (plus its
      subtrajectory) with the next sample, and is queried once on the full
      400-point series (result ``sol``).

    Args:
        plot: when true, plot the data and the window-based estimate via pylab.

    Returns:
        ``(sol2.T[0], sol.T[0][200:], y[200:])``. The arrays come from
        ``get_mean_var()``; presumably row 0 after transposing is the mean --
        TODO confirm.

    NOTE(review): this function asserts nothing, so it only checks that the
    pipeline runs. It also returns a value, which recent pytest versions flag
    with a warning when collected as a test; the return is kept for callers.
    """
    if plot:
        import pylab as pl
    # Noisy sine: 400 samples on [0, 40] with additive noise scaled by 0.2.
    # NOTE(review): `randn` is not seeded here, so results vary between runs.
    x = np.linspace(0, 40, 400).reshape((-1, 1))
    y = np.sin(x) + randn(len(x)).reshape((-1, 1)) * 0.2
    proc_data = np.hstack([x, y])
    if plot:
        pl.plot(x.flatten(), y.flatten())

    # Window-based estimator: input window y[i:i+10] -> output y[i+10],
    # for i in 0..189.
    invec = FiniteVec(GaussianKernel(0.5),
                      np.array([y.squeeze()[i:i + 10] for i in range(190)]))
    outvec = FiniteVec(GaussianKernel(0.5), y[10:200])
    refervec = FiniteVec(
        outvec.k,
        np.linspace(y[:-201].min() - 2, y[:-201].max() + 2, 5000)[:, None])
    # The Cdo is constructed but immediately replaced by the Cmo, so only its
    # construction is exercised.
    cd = Cdo(invec, outvec, refervec, 0.1)
    cd = Cmo(invec, outvec, 0.1)
    # Query with the 10 samples preceding each end index.
    sol2 = np.array([
        multiply(cd, FiniteVec(invec.k,
                               y[end - 10:end].T)).normalized().get_mean_var()
        for end in range(200, 400)
    ])
    if plot:
        pl.plot(x[200:].flatten(), sol2.T[0].flatten())

    # Combined-input estimator: (periodic time feature) * (subtrajectory
    # feature) for the first 200 points, mapped to the following sample
    # y[1:201].
    invec = CombVec(
        FiniteVec(PeriodicKernel(np.pi, 5), x[:200, :]),
        SpVec(SplitDimsKernel([0, 1, 2],
                              [PeriodicKernel(np.pi, 5),
                               GaussianKernel(0.1)]),
              proc_data[:200, :],
              np.array([200]),
              use_subtrajectories=True), np.multiply)
    outvec = FiniteVec(GaussianKernel(0.5), y[1:-199])
    cd = Cmo(invec, outvec, 0.1)
    # One query over the whole 400-point series, using the same kernels.
    sol = multiply(
        cd,
        CombVec(
            FiniteVec(invec.v1.k, x),
            SpVec(invec.v2.k,
                  proc_data[:400],
                  np.array([400]),
                  use_subtrajectories=True),
            np.multiply)).normalized().get_mean_var()

    print(sol)
    return sol2.T[0], sol.T[0][200:], y[200:]
Exemplo n.º 23
0
 def f(c, x):
     """Compute one ``(carry, x) -> (new_carry, output)`` step.

     The ``(c, x) -> (c, b)`` signature suggests use as a ``lax.scan`` body --
     presumably; confirm at the call site.

     Returns:
       ``(sin(c * b), b)`` where ``b = cos(sum(sin(x)) + sum(cos(c)))``.
     """
     sum_sin_x = jnp.sum(jnp.sin(x))
     sum_cos_c = jnp.sum(jnp.cos(c))
     b = jnp.cos(sum_sin_x + sum_cos_c)
     new_c = jnp.sin(c * b)
     return new_c, b
Exemplo n.º 24
0
 def test_basics(self):
   """``sin(cos(x))`` converts and compares equal at 0.7."""
   arg = jnp.float_(0.7)
   f_jax = lambda x: jnp.sin(jnp.cos(x))
   _, res_tf = self.ConvertAndCompare(f_jax, arg)
Exemplo n.º 25
0
def simple_fun(x, y):
    """Return ``sin(x * y)`` elementwise."""
    product = x * y
    return jnp.sin(product)
Exemplo n.º 26
0
 def f(x1):
   """Apply ``jnp.sin`` three times, then sum every element."""
   value = x1
   for _ in range(3):
     value = jnp.sin(value)
   return jnp.sum(value)
Exemplo n.º 27
0
 def f(x):
     """Return ``x + 1.`` when ``x < 0.`` and ``x + 2.`` otherwise, via ``lax.cond``.

     Each branch also produces a first output (``sin(x)`` or ``cos(x)``),
     which is discarded.
     """
     def negative_branch(v):
         return jnp.sin(v), v + 1.

     def non_negative_branch(v):
         return jnp.cos(v), v + 2.

     _, y = lax.cond(x < 0., negative_branch, non_negative_branch, x)
     return y
Exemplo n.º 28
0
 def f_jax():
   """Return ``jnp.sin(1.)``; takes no arguments."""
   angle = 1.
   return jnp.sin(angle)
Exemplo n.º 29
0
 def f(x):
   """Return the elementwise sine of ``x``."""
   result = np.sin(x)
   return result
Exemplo n.º 30
0
 def f(x):
     """Return ``x**3 * sin(x)``, obtaining ``x**3`` as ``grad``'s auxiliary output.

     The gradient itself is discarded; only the auxiliary list is used.
     """
     def cube_with_aux(v):
         return v**3, [v**3]

     _, aux = grad(cube_with_aux, has_aux=True)(x)
     return aux[0] * np.sin(x)