Example 1
    def __fitComplete2pMLE(self):
        # initial guess:
        shape = 1.2
        scale = self.failures.mean()
        parameters = jnp.array([shape, scale])

        J = jacfwd(self.__logLikelihood2pComp)
        H = jacfwd(jacrev(self.__logLikelihood2pComp))

        epoch = 0
        total = 1
        while not (total < 0.01 or epoch > 200):
            epoch += 1
            grads = J(parameters)
            hess = linalg.inv(H(parameters))  # inverse Hessian
            # q damps the Newton step when the gradient is large
            q = 1 / (1 + jnp.sqrt(abs(grads / 800)))
            # Newton-Raphson maximisation
            parameters -= q * hess @ grads
            total = abs(grads[0]) + abs(grads[1])

        if epoch <= 200:
            self.converged = True
            self.shape = parameters[0]
            self.scale = parameters[1]
            self.method = Method.MLEComplete2p

            # Fisher information confidence bounds (hess already holds the inverse Hessian)
            self.variance = [abs(hess[0][0]), abs(hess[1][1])]
            self.beta_eta_covar = [abs(hess[1][0])]

        else:
            # if the fit has not converged within 200 epochs, treat it as failed
            self.converged = False
            self.shape = 0.0
            self.scale = 0.0
            self.method = Method.MLEComplete2p
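
For reference, a self-contained sketch of the same Newton ascent on a toy two-parameter log-likelihood (a Gaussian stands in for the class's private Weibull likelihood; the data and starting values are purely illustrative):

import jax.numpy as jnp
from jax import jacfwd, jacrev
from jax.numpy import linalg

def log_likelihood(params):
    # toy Gaussian log-likelihood in (location, scale); illustrative only
    mu, sigma = params
    data = jnp.array([0.8, 1.1, 1.3, 0.9, 1.2])
    return jnp.sum(-jnp.log(sigma) - 0.5 * ((data - mu) / sigma) ** 2)

J = jacfwd(log_likelihood)          # gradient
H = jacfwd(jacrev(log_likelihood))  # Hessian

params = jnp.array([1.0, 0.3])
for _ in range(50):
    grads = J(params)
    params = params - linalg.inv(H(params)) @ grads  # Newton step
    if jnp.sum(jnp.abs(grads)) < 1e-6:
        break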
Example 2
    def derivative(self, wrt):
        if isinstance(wrt, str):
            # Normalise a single argument name to a 1-tuple
            return self.derivative((wrt,))

        # Trivial case, i.e., 0-th derivative
        if wrt == ():
            return self.fun

        # Return the registered derivative, if it exists
        try:
            return self.derivatives[wrt]
        except KeyError:
            pass

        # Compute the derivative: recursively build the derivative for the
        # remaining names, then differentiate it once more wrt the first name
        assert len(wrt) >= 1
        fun = self.derivative(wrt[1:])
        argnum = self.argnum(wrt[0])
        deriv = jax.jacrev(fun, argnum)

        # Save it and return
        self.derivatives[wrt] = deriv
        return deriv
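
The recursion above simply composes jacrev calls, caching one Jacobian function per tuple of argument names. A minimal standalone sketch of the composition step, using a hypothetical two-argument function:

import jax

f = lambda x, y: x * y ** 2
d_dy = jax.jacrev(f, 1)        # derivative wrt y: 2*x*y
d2_dxdy = jax.jacrev(d_dy, 0)  # mixed partial built by composition: 2*y
print(d2_dxdy(2.0, 3.0))       # 6.0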
Example 3
    def __fitTypeICensored2pMLE(self):
        # initial guess:
        shape = 1.2
        scale = (self.failures.mean() + self.censored.mean()) / 2
        parameters = jnp.array([shape, scale])

        J = jacfwd(self.__logLikelihood2pTypeICensored)
        H = jacfwd(jacrev(self.__logLikelihood2pTypeICensored))

        epoch = 0
        total = 1
        while not (total < 0.09 or epoch > 200):
            epoch += 1
            grads = J(parameters)
            hess = linalg.inv(H(parameters))  # inverse Hessian
            # q damps the Newton step when the gradient is large
            q = 1 / (1 + jnp.sqrt(abs(grads / 8)))
            parameters -= q * hess @ grads  # Newton-Raphson maximisation
            total = abs(grads[0]) + abs(grads[1])

        if epoch <= 200:
            self.converged = True
            self.shape = parameters[0]
            self.scale = parameters[1]
            self.method = Method.MLECensored2p
            self.variance = [abs(hess[0][0]), abs(hess[1][1])]
            self.beta_eta_covar = [abs(hess[1][0])]

        else:
            # if the fit has not converged within 200 epochs, treat it as failed
            self.converged = False
            self.shape = None
            self.scale = None
            self.method = Method.MLECensored2p
            print('MLE fit did not converge within 200 epochs')
Example 4
    def __init__(self, g=10.0):
        self.max_speed = 20.
        #self.max_torque=1.
        self.max_torque = 3.  # INCREASED TORQUE
        self.dt = .05
        self.g = g
        self.action_space = (1, )
        self.observation_space = (2, )
        self.n, self.m = 2, 1

        @jax.jit
        def angle_normalize(x):
            x = np.where(x > np.pi, x - 2 * np.pi, x)
            x = np.where(x < -np.pi, x + 2 * np.pi, x)
            return x

        self.angle_normalize = angle_normalize

        @jax.jit
        def _dynamics(x, u):
            th, th_dot = x
            g = self.g
            m = 1.
            l = 1.
            dt = self.dt
            u = np.clip(u, -self.max_torque, self.max_torque)[0]
            th_dot_dot = (-3. * g) / (2. * l) * np.sin(th + np.pi) + 3. / (
                m * l**2) * u
            new_th = self.angle_normalize(th + th_dot * dt)
            new_th_dot = th_dot + th_dot_dot * dt
            new_th_dot = np.clip(new_th_dot, -self.max_speed, self.max_speed)
            return np.array([new_th, new_th_dot])

        self._dynamics = _dynamics
        jacobian = jax.jacrev(self._dynamics, argnums=(0, 1))
        self.dynamics_jacobian = jax.jit(jacobian)
Example 5
def hessian(f, argnums=0):
    return jax.jacfwd(jax.jacrev(f, argnums), argnums)
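
A quick sanity check of this helper (a minimal sketch using the definition above; the quadratic form is illustrative): the Hessian of x @ A @ x is A + A.T.

import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0], [0.0, 3.0]])
quad = lambda x: x @ A @ x
print(hessian(quad)(jnp.ones(2)))  # A + A.T: [[2. 2.] [2. 6.]]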
Example 6
@jax.jit
def single_step_loss(theta, x, t_current, T):
    # one gradient step with a learning rate that interpolates (in log space)
    # between theta[0] at t=0 and theta[1] at t=T
    g = loss_grad(x)
    lr = jnp.exp(theta[0]) * (T - t_current) / T + jnp.exp(theta[1]) * t_current / T
    x_new = x - lr * g
    return loss(x_new) * (t_current < T)  # loss counts only inside the horizon


compute_dL_dstate_old = jax.jit(jax.grad(single_step_loss, argnums=1))
compute_dL_dtheta_direct = jax.jit(jax.grad(single_step_loss, argnums=0))

compute_d_state_new_d_state_old = jax.jit(jax.jacrev(single_step, argnums=1))
compute_d_state_new_d_theta_direct = jax.jit(jax.jacrev(single_step,
                                                        argnums=0))


def rtrl_grad(theta, x, t0, T, K, dstate_dtheta=None):
    t_current = t0
    mystate = x

    if dstate_dtheta is None:
        dstate_dtheta = jnp.zeros((len(mystate), len(theta)))

    total_theta_grad = 0
    total_loss = 0.0
Example 7
        a0, a1, b0, b1 = c[i:i + 4]

        alpha_n = 1 - a0**2 + a1 * (1 - a0) * jnp.cos(w)
        beta_n = a0**2 + a1**2 + 1 + 2 * a0 * 1 * (
            2 * jnp.cos(w)**2 - 1) + 2 * a1 * (a0 + 1) * jnp.cos(w)
        alpha_d = 1 - b0**2 + b1 * (1 - b0) * jnp.cos(w)
        beta_d = b0**2 + b1**2 + 1 + 2 * b0 * 1 * (
            2 * jnp.cos(w)**2 - 1) + 2 * b1 * (b0 + 1) * jnp.cos(w)

        group_delay += -alpha_n / beta_n + alpha_d / beta_d

    return group_delay


group_delay_gradient = jax.jacrev(_group_delay)


def _group_delay_deviation(x, w):
    """
    Calculates the group delay deviation for filter with coefficients x for the given frequencies in w
    
    parameters
    ----------
    
    x: ndarray
        list of all coefficients of all seconds order stages and tau, the group delay optimization variable:
        [c tau]

    w: ndarray
        frequency bins in the range [0, π] to evaluate the group delay on
Example 8
def initialize_functions(hyper_params, phi, psi, g, get_params_from_opt, opt_update):
    psi_z = jacrev(psi, argnums=1)
    psi = jit(psi)
    g = jit(g)

    phi_x = jacrev(phi, argnums=1)

    def l1_regularization(params, penalty):
        val = 0.0
        for layer in params:
            for param in layer:
                val += jnp.sum(jnp.abs(param))
        return val * penalty

    def l2_regularization(params, penalty):
        val = 0.0
        for layer in params:
            for param in layer:
                val += jnp.sum(param ** 2)
        return val * penalty

    def T(params_all, x, dx, z_ref):
        params_phi = params_all[:hyper_params['n_phi']]
        params_psi = params_all[hyper_params['n_phi']:hyper_params['n_phi'] + hyper_params['n_psi']]
        params_g = params_all[hyper_params['n_phi'] + hyper_params['n_psi']:]

        z_opt = phi(params_phi, x)
        dz_g = g(params_g, z_opt, z_ref)
        dx_recon = jnp.dot(psi_z(params_psi, z_opt, z_ref), dz_g)
        z_opt_x = phi_x(params_phi, x)

        x_loss = jnp.sum((x - psi(params_psi, z_opt, z_ref)) ** 2)
        dx_loss = hyper_params['eta1'] * jnp.sum((dx - dx_recon) ** 2)
        dz_loss = hyper_params['eta2'] * jnp.sum((jnp.dot(z_opt_x, dx) - dz_g) ** 2)
        regul = l2_regularization(params_g, hyper_params['eta3'])

        loss = x_loss + dx_loss + dz_loss + regul
        return loss

    def T_separate(params_all, x, dx, z_ref):
        params_phi = params_all[:hyper_params['n_phi']]
        params_psi = params_all[hyper_params['n_phi']:hyper_params['n_phi'] + hyper_params['n_psi']]
        params_g = params_all[hyper_params['n_phi'] + hyper_params['n_psi']:]

        z_opt = phi(params_phi, x)
        dz_g = g(params_g, z_opt, z_ref)
        dx_recon = jnp.dot(psi_z(params_psi, z_opt, z_ref), dz_g)
        z_opt_x = phi_x(params_phi, x)

        x_loss = jnp.sum((x - psi(params_psi, z_opt, z_ref)) ** 2)
        dx_loss = hyper_params['eta1'] * jnp.sum((dx - dx_recon) ** 2)
        dz_loss = hyper_params['eta2'] * jnp.sum((jnp.dot(z_opt_x, dx) - dz_g) ** 2)
        regul = l2_regularization(params_g, hyper_params['eta3'])

        loss = x_loss + dx_loss + dz_loss + regul
        return loss, x_loss, dx_loss, dz_loss, dx_recon, regul

    T_params = grad(T, argnums=0)

    # vectorized functions
    psi_vec = jit(vmap(psi, (None, 0, 0)))
    phi_vec = jit(vmap(phi, (None, 0)))
    T_separate_vec = jit(vmap(T_separate, (None, 0, 0, 0)))
    T_params_vec = jit(vmap(T_params, (None, 0, 0, 0)))

    @jit
    def update(i, opt_state, x, dx, z_ref):
        params = get_params_from_opt(opt_state)
        grads = T_params_vec(params, x, dx, z_ref)
        grads_mean = []
        # note: do not reuse `i` here; it would shadow the step counter
        for layer in grads:
            if len(layer) == 0:
                grads_mean.append(())
            else:
                grads_mean.append(tuple(jnp.mean(weight, axis=0) for weight in layer))
        return opt_update(i, grads_mean, opt_state)

    return update, T_separate_vec, phi_vec, psi_vec, phi, T, T_params, g
Example 9
def hessian(f):
    return jacfwd(jacrev(f))
Example 10
    return (aa * bb).mean()


random_features_kernel, grad_random_features_kernel = get_mapped_kernel_grad(
    random_features_kernel_basic, n_params=3)


def matrix_valued_kernel_basic(a, b, Q, h):
    diff = (a - b)
    return np.linalg.inv(Q) * np.exp(-0.5 / h * diff.T @ Q @ diff)


mv = vmap(matrix_valued_kernel_basic, (0, None, None, None), 0)
matrix_valued_kernel = jit(vmap(mv, (None, 0, None, None), 1))

jac_exp = jacrev(matrix_valued_kernel_basic)


def grad_update(a, b, Q, h, g, repulsion):
    K = matrix_valued_kernel_basic(a, b, Q, h)
    K_der = jac_exp(a, b, Q, h)
    res = np.zeros(a.shape[-1])
    for l in range(res.shape[-1]):
        x = 0
        for m in range(Q.shape[-1]):
            x += K[l, m] * g[m] + repulsion * K_der[l, m, m]
        res = index_update(res, index[l], x)
    return res


mv = vmap(grad_update, (0, None, None, None, 0, None), 0)
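
Note that jax.ops.index_update and index, used in grad_update above, were removed in later JAX releases; the same functional update now goes through the .at[] property. A minimal sketch of the equivalent call:

import jax.numpy as jnp

res = jnp.zeros(3)
res = res.at[1].set(5.0)  # old API: res = index_update(res, index[1], 5.0)
print(res)                # [0. 5. 0.]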
Example 11
    def test_logdet_hess_da(self):
        self.check_logdet_jac(lambda f: jax.jacfwd(jax.jacrev(f)),
                              hess=True,
                              da=True)
Example 12
    def test_quad_matrix_matrix_hess_da(self):
        self.check_quad_jac(lambda f: jax.jacfwd(jax.jacrev(f)),
                            self.randmat,
                            self.randmat,
                            hess=True,
                            da=True)
Example 13
    def test_solve_matrix_hess_da(self):
        self.check_solve_jac(self.randmat,
                             lambda f: jax.jacfwd(jax.jacrev(f)),
                             hess=True,
                             da=True)
Example 14
    normaly = dzdphi * dxdtheta - dxdphi * dzdtheta
    normalz = dxdphi * dydtheta - dydphi * dxdtheta
    norm_normal = jnp.sqrt(normalx * normalx + normaly * normaly +
                           normalz * normalz)
    area = nfp * dtheta * dphi * jnp.sum(norm_normal)
    # Compute plasma volume using \int (1/2) R^2 dZ dphi
    # = \int (1/2) R^2 (dZ/dtheta) dtheta dphi
    volume = 0.5 * nfp * dtheta * dphi * jnp.sum(r * r * dzdtheta)
    return jnp.array([area, volume])


# area_volume_pure(rc, rs, zc, zs, stellsym, nfp, mpol, ntor, ntheta, nphi)
# jit_area_volume_pure = jit(area_volume_pure, static_argnums=(4, 5, 6, 7, 8, 9))
# jit_area_volume_pure = jit(area_volume_pure, static_argnums=(8, 9))
jit_area_volume_pure = area_volume_pure
darea_volume_pure = jacrev(area_volume_pure, argnums=(0, 1, 2, 3))

jit_dataset_area_volume = dataset_area_volume


class Surface(abc.ABC):
    """
    Surface is a base class for various representations of toroidal
    surfaces in simsopt.
    """
    def __init__(self, nfp=1, stellsym=True):
        #if not isinstance(nfp, int):
        #    raise TypeError('nfp must be an integer')
        #if not isbool(stellsym):
        #    raise TypeError('stellsym must be a bool')
        self.nfp = nfp
Example 15
    def __init__(self, neglog, d, metric_fun):
        self.__neglog = neglog
        self.__hessian_fun = jax.jit(jax.jacfwd(jax.jacrev(self.__neglog)))
        self.__metric_fun = metric_fun
        self.__d = d
        self.__softabs_const = 1e0
Example 16
    def dlZ_dm(self, y, x, w, cav_mean, cav_var, power):
        return jacrev(self.log_expected_likelihood,
                      argnums=3)(y, x, w, cav_mean, cav_var, power)
Example 17
    def jacobian(self, func, state, action):
        return jax.jacrev(func, argnums=(0, 1))(state, action)
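
When argnums is a tuple, jax.jacrev returns one Jacobian per selected argument. A self-contained check with a toy affine dynamics function (names are illustrative):

import jax
import jax.numpy as jnp

def toy_dynamics(state, action):
    return 2.0 * state + action

d_dstate, d_daction = jax.jacrev(toy_dynamics, argnums=(0, 1))(jnp.ones(3), jnp.ones(3))
# d_dstate is 2*I (3x3); d_daction is I (3x3)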
Example 18
phi_vars = np.array([np.pi / 2., np.pi / 3.])
# phi_vars = np.array([np.pi/2.])
loc = 1.0

print(encode_point(loc, phi_vars))
print("")
print(grad_encode_point_x(loc, phi_vars))
print("")
print(grad_encode_point_phis(loc, phi_vars))
print("")

assert False  # intentional stop: the batched experiment below never runs

# trying it with a batch dimension
batch_encode_point = vmap(encode_point, (0, None))
grad_encode_point_x = vmap(jacrev(encode_point, argnums=0), (0, None))
grad_encode_point_phis = vmap(jacrev(encode_point, argnums=1), (0, None))

loc = numpy.array([
    [0.0],
    [1.2],
    [1.4],
    [1.5],
])

# phi_vars = numpy.array([
#     [np.pi/2., np.pi/3.],
#     [np.pi/2., np.pi/3.],
#     [np.pi/8., 1.0],
#     [-np.pi/7., np.pi/4.],
# ])
Example 19
    def __init__(self, **kwargs):
        self.rollout_controller = None
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = (self.masspole + self.masscart)
        self.length = 0.5  # actually half the pole's length
        self.polemass_length = (self.masspole * self.length)
        self.force_mag = 10.0
        self.tau = 0.02  # seconds between state updates

        # Angle at which to fail the episode
        self.theta_threshold_radians = 15 * 2 * np.pi / 360
        self.x_threshold = 2.4

        self.action_space = (1, )
        self.observation_space = (4, )
        self.n, self.m = 4, 1
        self.viewer = None
        self._state = None

        def _dynamics(x_0, u):  # dynamics
            # x_0, u = np.squeeze(x_0, axis=1), np.squeeze(u, axis=1)
            x, x_dot, theta, theta_dot = np.split(x_0, 4)
            force = self.force_mag * np.clip(
                u, -1.0,
                1.0)  # iLQR may struggle with clipping due to lack of gradient
            costh = np.cos(theta)
            sinth = np.sin(theta)
            temp = (force + self.polemass_length * theta_dot * theta_dot *
                    sinth) / self.total_mass
            thetaacc = (self.gravity * sinth - costh * temp) / (
                self.length *
                (4.0 / 3.0 - self.masspole * costh * costh / self.total_mass))
            xacc = temp - self.polemass_length * thetaacc * costh / self.total_mass
            x = x + self.tau * x_dot  # use euler integration by default
            x_dot = x_dot + self.tau * xacc
            theta = theta + self.tau * theta_dot
            theta_dot = theta_dot + self.tau * thetaacc
            state = np.hstack((x, x_dot, theta, theta_dot))
            return state

        self._dynamics = jax.jit(
            _dynamics
        )  # MUST store as self._dynamics for default rollout implementation to work
        # C_x, C_u = (np.diag(np.array([0.2, 0.05, 1.0, 0.05])), np.diag(np.array([0.05])))
        # self._loss = jax.jit(lambda x, u: x.T @ C_x @ x + u.T @ C_u @ u) # MUST store as self._loss
        self._loss = kwargs.get('loss')

        # stack the jacobians of environment dynamics gradient
        jacobian = jax.jacrev(self._dynamics, argnums=(0, 1))
        # jacobian[0], jacobian[1] = np.squeeze(jacobian[0], axis=(1,3)), np.squeeze(jacobian[1], axis=(1,3))
        self._dynamics_jacobian = jax.jit(
            lambda x, u: np.hstack(jacobian(x, u)))
        '''
        def dyn_jac_aux(x,u):
            j = jacobian(x,u)
            return np.hstack((np.squeeze(j[0], axis=(1,3)), np.squeeze(j[1], axis=(1,3))))
        self._dynamics_jacobian = jax.jit(dyn_jac_aux)'''

        # stack the gradients of environment loss
        loss_grad = jax.grad(self._loss, argnums=(0, 1))
        self._loss_grad = jax.jit(lambda x, u: np.hstack(loss_grad(x, u)))

        # block the hessian of environment loss
        block_hessian = lambda A: np.vstack(
            [np.hstack([A[0][0], A[0][1]]),
             np.hstack([A[1][0], A[1][1]])])
        hessian = jax.hessian(self._loss, argnums=(0, 1))
        self._loss_hessian = jax.jit(lambda x, u: block_hessian(hessian(x, u)))
        '''
Example 20
from tfc.utils import TFCDict, Latex

# Initial solution
X0 = TFCDict({
    'x': np.array([2. / 3.]),
    'y': np.array([1. / 3.]),
    'z': np.array([1. / 3.])
})

# Create function, Jacobian, and Hessian
f = lambda X: np.squeeze(
    X['x'] * X['y'] * (X['x'] * X['y'] + 6. * X['y'] - 8. * X['x'] - 48.)
    + X['z']**2 - 8. * X['z'] + 9. * X['y']**2 - 72. * X['y']
    + 16. * X['x']**2 + 96. * X['x'] + 160)
J = lambda X: np.hstack([val for val in jacfwd(f)(X).values()])
H = lambda X: np.hstack([val for val in jacrev(J)(X).values()])

# Equality constraint
A = np.array([[1., 2., -1.], [1., 0., 1.]])
b = np.array([[1.], [1.]])
N = sp.linalg.null_space(A)

# Iterate to find the solution (use a jax for loop)
X = X0
val = {'s': np.array([0.]), 'X0': X0, 'X': X, 'N': N}


def body(k, val):
    val['s'] += np.linalg.multi_dot([
        np.linalg.inv(np.linalg.multi_dot([val['N'].T,
                                           H(val['X']), val['N']])),
Example 21
else:
    model = pystan.StanModel(file=path + model_name + '.stan')
    with open(path + model_name + '.pkl', 'wb') as file:
        pickle.dump(model, file)

# load hmc warmup
inv_metric = pickle.load(open('stan_traces/inv_metric.pkl', 'rb'))
stepsize = pickle.load(open('stan_traces/step_size.pkl', 'rb'))
last_pos = pickle.load(open('stan_traces/last_pos.pkl', 'rb'))

# define MPC cost, gradient and hessian functions and prepare to compile
cost = jit(log_barrier_cosine_cost,
           static_argnums=(11, 12, 13, 14, 15))  # static argnums means it will recompile if N changes
gradient = jit(grad(log_barrier_cosine_cost, argnums=0),
               static_argnums=(11, 12, 13, 14, 15))  # compiled gradient with respect to z (uc, s)
hessian = jit(jacfwd(jacrev(log_barrier_cosine_cost, argnums=0)),
              static_argnums=(11, 12, 13, 14, 15))

mu = 1e4
gamma = 1
delta = 0.05
max_iter = 5000

# declare some variables for storing the ongoing results
xt_est_save = np.zeros((Ns, Nx, T))
theta_est_save = np.zeros((Ns, 6, T))
q_est_save = np.zeros((Ns, 4, T))
r_est_save = np.zeros((Ns, 3, T))
uc_save = np.zeros((1, Nh, T))
mpc_result_save = []
hmc_traces_save = []
accept_rates = []
Example 22
def hessian(fun):
    return jit(jacfwd(jacrev(fun)))
Example 23
        temp = np.sin(np.power(np.subtract(var[0], args[:, 0]), 3))
        ele1 = np.sum(temp)
        # sin(var[1]**3 - var[2]**2) * args[:,1]
        temp = np.sin(np.subtract(np.power(var[1], 3), np.power(var[2], 2)))
        ele2 = np.sum(np.dot(temp, args[:, 1]))
        # log(args[:,1]*args[:,2]*var[0]*var[1]*var[2])
        temp = np.multiply(args[:, 1], args[:, 2])
        ele3 = np.sum(
            np.log(1 + np.abs(np.dot(var[0] * var[1] * var[2], temp))))
        # print("\n\n", ele1, "\n", ele2, "\n", ele3)
        # print(res, type(res))
        return ele1 + ele2 + ele3

    # print("-------------Reverse mode---------------")
    # jax.jit(func)
    grad = jax.jit(jax.jacrev(func))

    for j in range(10):
        f.write("\tSubstep " + str(j + 1) + "\n")
        var = np.array([0.001 * j, 0.001 * 2 * j, 0.001 * 3 * j])
        print("\tSubstep:", str(j))
        start = time.time()
        gradient = grad(var)
        end = time.time()
        f.write("\t\t" + str(gradient) + "\n")
        print("\t\tGradient:", gradient)
        print("\t\tExecution time:", float(end - start))

# for ele in exetime:
#     for x in ele:
#         f.write(str(x)+" ")
Example 24
  def test_autodiff(self, get, same_inputs, phi):
    x1 = np.cos(random.normal(random.PRNGKey(1), (3, 1, 2, 3)))
    if same_inputs is None:
      x2 = None
    elif same_inputs is True:
      x2 = x1
    else:
      x2 = np.cos(random.normal(random.PRNGKey(2), (4, 1, 2, 3)))

    name = phi.__name__
    if name == 'LeakyRelu':
      phi = phi(0.1)
    elif name == 'ElementwiseNumerical':
      phi = phi(fn=np.cos, deg=25)
    else:
      phi = phi()

    _, _, kernel_fn = stax.serial(stax.Dense(1, 2., 0.01), phi,
                                  stax.Dense(1, 2., 0.01), phi)

    def k(x1, x2):
      return kernel_fn(x1, x2, get)

    dx1 = random.normal(random.PRNGKey(3), x1.shape) * 0.01
    if x2 is None:
      dx2 = None
    else:
      dx2 = random.normal(random.PRNGKey(4), x2.shape) * 0.01

    def dk(x1, x2):
      return jvp(k, (x1, x2), (dx1, dx2))[1]

    def d2k(x1, x2):
      return jvp(dk, (x1, x2), (dx1, dx2))[1]

    _dk = dk(x1, x2)
    _d2k = d2k(x1, x2)

    if same_inputs is not False and get == 'ntk' and 'Relu' in name:
      tol = 8e-3
    else:
      tol = 2e-3 if name == 'ElementwiseNumerical' else 1e-4

    def assert_close(x, y, tol=3e-5):
      if default_backend() == 'tpu':
        # TODO(romann): understand why TPUs have high errors.
        tol = 0.21
      self.assertLess(
          np.max(np.abs(x - y)) / (np.mean(np.abs(x)) + np.mean(np.abs(y))),
          tol)

    # k(x + dx) ~ k(x) + dk(x) dx + dx^T d2k(x) dx
    assert_close(k(x1 + dx1, None if same_inputs is None else x2 + dx2),
                 k(x1, x2) + _dk + _d2k / 2,
                 tol=tol)

    # d/dx1
    k_fwd_0 = jacfwd(k)(x1, x2)
    k_rev_0 = jacrev(k)(x1, x2)
    assert_close(k_fwd_0, k_rev_0)

    if same_inputs is not None:
      # d/dx2
      k_fwd_1 = jacfwd(k, 1)(x1, x2)
      k_rev_1 = jacrev(k, 1)(x1, x2)
      assert_close(k_fwd_1, k_rev_1)

      # dk(x2, x1)/dx2 = dk(x1, x2)/dx1
      k_fwd_01 = jacfwd(k, 1)(x2, x1)
      k_rev_01 = jacrev(k, 1)(x2, x1)
      assert_close(np.moveaxis(k_fwd_0, (0, 2, 4), (1, 3, 5)), k_fwd_01)
      assert_close(np.moveaxis(k_rev_0, (0, 2, 4), (1, 3, 5)), k_rev_01)

      # dk(x2, x1)/dx1 = dk(x1, x2)/dx2
      k_fwd_10 = jacfwd(k)(x2, x1)
      k_rev_10 = jacrev(k)(x2, x1)
      assert_close(np.moveaxis(k_fwd_1, (0, 2, 4), (1, 3, 5)), k_fwd_10)
      assert_close(np.moveaxis(k_rev_1, (0, 2, 4), (1, 3, 5)), k_rev_10)
Example 25
File: api_test.py  Project: yyht/jax
    def test_complex_output_jacrev_raises_error(self):
        self.assertRaises(TypeError,
                          lambda: jacrev(lambda x: np.sin(x))(1 + 2j))
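
jacrev rejects complex outputs by default because reverse mode is only well-defined for holomorphic functions of a complex argument; for those, JAX provides an explicit opt-in. A minimal sketch:

from jax import jacrev
import jax.numpy as jnp

print(jacrev(jnp.sin, holomorphic=True)(1.0 + 2.0j))  # cos(1 + 2j)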
Example 26
    def fn1(th):
        xi = jp.einsum('ii...->i...', jax.jacrev(Ls)(th))
        _, prod = jax.jvp(Ls, (th,), (xi,))
        return prod
Example 27
    def lagrangian_grad_real_flat(x, u, H):
        z = real2comp(x).reshape(K, M)
        gg = jax.jacrev(lagrangian)
        grad = gg(z, u, H).conj()
        return comp2real(grad.reshape(K * M))
Example 28
    def fn1(th):
        xi = jp.lax.stop_gradient(jp.einsum('ii...->i...', jax.jacrev(Ls)(th)))
        _, prod = jax.jvp(Ls, (th,), (xi,))
        return (prod - jp.einsum('ij,ij->i', xi, xi))


    xf = np.zeros((dim, ), dtype='complex64')
    xf = index_update(xf, index[0], 1)
    xf = index_update(xf, index[1:-1], np.exp(1.j * phis[0, :]))
    xf = index_update(xf, index[-1], 1)

    yf = np.zeros((dim, ), dtype='complex64')
    yf = index_update(yf, index[0], 1)
    yf = index_update(yf, index[1:-1], np.exp(1.j * phis[1, :]))
    yf = index_update(yf, index[-1], 1)
    # jax version of irfft assumes there is a nyquist frequency
    # they have not implemented it for odd dimensions
    ret = np.fft.irfft(xf**pos[0] * yf**pos[1])
    return ret


# grad_encode_point_x = grad(encode_2d_point_even, argnums=0)
# grad_encode_point_phis = grad(encode_2d_point_even, argnums=1)
grad_encode_point_x = jacrev(encode_2d_point_even, argnums=0)
grad_encode_point_phis = jacrev(encode_2d_point_even, argnums=1)

#print(encode_point(2., np.array([np.pi/2., np.pi/3.])))
#print(grad_encode_point_x(2., np.array([np.pi/2., np.pi/3.])))
#print(grad_encode_point_phis(2., np.array([np.pi/2., np.pi/3.])))

phi_vars = np.array([
    [np.pi / 2., 0.],
    [0., np.pi / 3.],
])
# phi_vars = np.array([np.pi/2.])
loc = (1.0, 0.5)

print(encode_2d_point_even(loc, phi_vars))
print("")
def cost_func_jacrev(bb, u):
    direcmat = jax.jacrev(cost_func, argnums=(0, ))(bb, u)
    return direcmat[0]