def update(opt_state, step_size):
    _, update_fun, get_params = optimizers.sgd(step_size)
    x = get_params(opt_state)
    g = grad(loss)(x)
    return update_fun(0, g, opt_state)
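# A minimal, illustrative driver for the update step above. The quadratic
# `loss` and the starting point are placeholders, and `optimizers` is assumed
# to be jax.example_libraries.optimizers (older code imports it from
# jax.experimental.optimizers).
import jax.numpy as jnp
from jax import grad
from jax.example_libraries import optimizers

def loss(x):
    return jnp.sum(x ** 2)  # placeholder scalar objective

step_size = 0.1
opt_init, _, _ = optimizers.sgd(step_size)
opt_state = opt_init(jnp.ones(3))
for _ in range(10):
    opt_state = update(opt_state, step_size)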
def test_grad(self, with_jit=False):
    x = np.float32(3.)
    res = _maybe_jit(with_jit, jax.grad(jax2tf.call_tf(tf.math.sin)))(x)
    self.assertAllClose(np.cos(x), res)
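# The `_maybe_jit` helper is defined elsewhere in the test suite; a plausible
# sketch of what it does, assuming it simply wraps the function in jax.jit
# when requested (this is an assumption, not the original helper).
def _maybe_jit(with_jit, func):
    return jax.jit(func) if with_jit else func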
def update(params, opt_state, problem):
    g = jax.grad(prediction_loss)(params, problem)
    updates, opt_state = opt_update(g, opt_state)
    return optax.apply_updates(params, updates), opt_state
def test_gradient_log_normalizer(
        generator: Generator,
        distribution_info: DistributionInfo[Any, Any, Any]) -> None:
    """Test that the gradient log-normalizer evaluates to the same value as the gradient of the log-normalizer."""
    # pylint: disable=too-many-locals, protected-access
    cls = type(distribution_info.nat_parameter_generator(generator, shape=()))
    original_ln = cls._original_log_normalizer
    original_gln = jit(grad(cls._original_log_normalizer))
    optimized_ln = cls.log_normalizer
    optimized_gln = jit(grad(optimized_ln))

    for _ in range(20):
        nat_parameters = distribution_info.nat_parameter_generator(generator, shape=())
        kw_nat_parameters = nat_parameters.fixed_parameters_mapping()
        exp_parameters = nat_parameters.to_exp()  # Regular transformation.
        nat_cls = type(nat_parameters)
        ep_cls = type(exp_parameters)

        # Original GLN.
        original_nat_parameters = original_gln(nat_parameters)
        f_original_nat_parameters = original_nat_parameters.flattened()
        original_exp_parameters = ep_cls.unflattened(f_original_nat_parameters,
                                                     **kw_nat_parameters)

        # Optimized GLN.
        optimized_nat_parameters = optimized_gln(nat_parameters)
        f_optimized_nat_parameters = optimized_nat_parameters.flattened()
        optimized_exp_parameters = ep_cls.unflattened(f_optimized_nat_parameters,
                                                      **kw_nat_parameters)

        # Test primal evaluation.
        assert_jax_allclose(exp_parameters, original_exp_parameters, rtol=1e-5)
        assert_jax_allclose(exp_parameters, optimized_exp_parameters, rtol=1e-5)

        # Test JVP.
        ones_like_nat_parameters = nat_cls(
            **{name: jnp.zeros_like(value)
               for name, value in nat_parameters.fixed_parameters_mapping().items()},
            **{name: jnp.ones_like(value)
               for name, value in nat_parameters.parameters_name_value()})
        original_gradients = jvp(original_ln, (nat_parameters,),
                                 (ones_like_nat_parameters,))
        optimized_gradients = jvp(optimized_ln, (nat_parameters,),
                                  (ones_like_nat_parameters,))
        assert_allclose(original_gradients, optimized_gradients, rtol=1.5e-5)

        # Test VJP.
        original_ln_of_nat, original_vjp = vjp(original_ln, nat_parameters)
        original_gln_of_nat, = original_vjp(1.0)
        optimized_ln_of_nat, optimized_vjp = vjp(optimized_ln, nat_parameters)
        optimized_gln_of_nat, = optimized_vjp(1.0)
        assert_jax_allclose(original_ln_of_nat, optimized_ln_of_nat, rtol=1e-5)
        for name, original_value in original_gln_of_nat.parameters_name_value():
            optimized_value = getattr(optimized_gln_of_nat, name)
            assert_jax_allclose(original_value, optimized_value, rtol=1e-5)
def update(params, opt_state, x, y_true, dropout_key):
    grads = grad(mean_cross_entropy)(params, x, y_true, dropout_key)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state
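# Illustrative training loop for the update step above. `init_params`,
# `mean_cross_entropy`, and the batch arrays are assumed to exist elsewhere;
# the optimizer choice here is a placeholder.
import jax
import optax

opt = optax.adam(1e-3)
params = init_params
opt_state = opt.init(params)
key = jax.random.PRNGKey(0)
for step in range(num_steps):
    key, dropout_key = jax.random.split(key)
    params, opt_state = update(params, opt_state, x_batch, y_batch, dropout_key)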
def testIssue1789(self):
    def f(x):
        return random.gamma(random.PRNGKey(0), x)

    grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2))
def main(unused_argv):
    numpts = 7
    key = random.PRNGKey(0)
    eye = jnp.eye(numpts)

    def cov_map(cov_func, xs, xs2=None):
        """Compute a covariance matrix from a covariance function and data points.

        Args:
          cov_func: callable function, maps pairs of data points to scalars.
          xs: array of data points, stacked along the leading dimension.
        Returns:
          A 2d array `a` such that `a[i, j] = cov_func(xs[i], xs[j])`.
        """
        if xs2 is None:
            return vmap(lambda x: vmap(lambda y: cov_func(x, y))(xs))(xs)
        else:
            return vmap(lambda x: vmap(lambda y: cov_func(x, y))(xs))(xs2).T

    def softplus(x):
        return jnp.logaddexp(x, 0.)

    # Note, writing out the vectorized form of the identity
    # ||x-y||^2 = <x-y,x-y> = ||x||^2 + ||y||^2 - 2<x,y>
    # for computing squared distances would be more efficient (but less succinct).
    def exp_quadratic(x1, x2):
        return jnp.exp(-jnp.sum((x1 - x2)**2))

    def gp(params, x, y, xtest=None, compute_marginal_likelihood=False):
        noise = softplus(params['noise'])
        amp = softplus(params['amplitude'])
        ls = softplus(params['lengthscale'])
        ymean = jnp.mean(y)
        y = y - ymean
        x = x / ls
        train_cov = amp * cov_map(exp_quadratic, x) + eye * (noise + 1e-6)
        chol = scipy.linalg.cholesky(train_cov, lower=True)
        kinvy = scipy.linalg.solve_triangular(
            chol.T, scipy.linalg.solve_triangular(chol, y, lower=True))
        if compute_marginal_likelihood:
            log2pi = jnp.log(2. * 3.1415)
            ml = jnp.sum(
                -0.5 * jnp.dot(y.T, kinvy) -
                jnp.sum(jnp.log(jnp.diag(chol))) -
                (numpts / 2.) * log2pi)
            ml -= jnp.sum(-0.5 * jnp.log(2 * 3.1415) - jnp.log(amp)**2)  # lognormal prior
            return -ml

        if xtest is not None:
            xtest = xtest / ls
        cross_cov = amp * cov_map(exp_quadratic, x, xtest)
        mu = jnp.dot(cross_cov.T, kinvy) + ymean
        v = scipy.linalg.solve_triangular(chol, cross_cov, lower=True)
        var = amp * cov_map(exp_quadratic, xtest) - jnp.dot(v.T, v)
        return mu, var

    marginal_likelihood = partial(gp, compute_marginal_likelihood=True)
    predict = partial(gp, compute_marginal_likelihood=False)
    grad_fun = jit(grad(marginal_likelihood))

    # Covariance hyperparameters to be learned
    params = {"amplitude": jnp.zeros((1, 1)),
              "noise": jnp.zeros((1, 1)) - 5.,
              "lengthscale": jnp.zeros((1, 1))}
    momentums = dict([(k, p * 0.) for k, p in params.items()])
    scales = dict([(k, p * 0. + 1.) for k, p in params.items()])

    lr = 0.01  # Learning rate

    def train_step(params, momentums, scales, x, y):
        grads = grad_fun(params, x, y)
        for k in params:
            momentums[k] = 0.9 * momentums[k] + 0.1 * grads[k][0]
            scales[k] = 0.9 * scales[k] + 0.1 * grads[k][0]**2
            params[k] -= lr * momentums[k] / jnp.sqrt(scales[k] + 1e-5)
        return params, momentums, scales

    # Create a really simple toy 1D function
    y_fun = lambda x: jnp.sin(x) + 0.1 * random.normal(key, shape=(x.shape[0], 1))
    x = (random.uniform(key, shape=(numpts, 1)) * 4.) + 1
    y = y_fun(x)
    xtest = jnp.linspace(0, 6., 200)[:, None]

    for i in range(1000):
        params, momentums, scales = train_step(params, momentums, scales, x, y)
        if i % 50 == 0:
            ml = marginal_likelihood(params, x, y)
            print("Step: %d, neg marginal likelihood: %f" % (i, ml))

    print(params)
    mu, var = predict(params, x, y, xtest)
    std = jnp.sqrt(jnp.diag(var))
    plt.plot(x, y, "k.")
    plt.plot(xtest, mu)
    plt.fill_between(xtest.flatten(),
                     mu.flatten() - std * 2,
                     mu.flatten() + std * 2)
def test_grad_and_aux_nested(self):
    def f(x):
        g, aux = grad(lambda x: (x**3, [x**3]), has_aux=True)(x)
        return aux[0]

    f2 = lambda x: x**3

    self.assertEqual(grad(f)(4.), grad(f2)(4.))
    self.assertEqual(jit(grad(f))(4.), grad(f2)(4.))
    self.assertEqual(jit(grad(jit(f)))(4.), grad(f2)(4.))

    def f(x):
        g, aux = grad(lambda x: (x**3, [x**3]), has_aux=True)(x)
        return aux[0] * np.sin(x)

    f2 = lambda x: x**3 * np.sin(x)

    self.assertEqual(grad(f)(4.), grad(f2)(4.))
    self.assertEqual(jit(grad(f))(4.), grad(f2)(4.))
    self.assertEqual(jit(grad(jit(f)))(4.), grad(f2)(4.))
def grads_real(params, x):
    r = jax.grad(lambda pars, v: f_real_flat_scalar(pars, v).real)(params, x)
    i = jax.grad(lambda pars, v: f_real_flat_scalar(pars, v).imag)(params, x)
    return jax.lax.complex(r, i)
def compute_gradients_jax(w, X_test, y_test):
    print("Starting JAX demo")
    grad_jax = jax.grad(NLL)(w, (X_test, y_test))
    print("grad {}".format(grad_jax))
    return grad_jax
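# The NLL objective is defined elsewhere; a hypothetical stand-in that matches
# the call signature above, assuming a binary logistic-regression negative
# log-likelihood over a (X, y) tuple.
import jax.numpy as jnp

def NLL(w, batch):
    X, y = batch
    logits = X @ w
    # Numerically stable binary cross-entropy via log-sum-exp.
    return jnp.sum(jnp.logaddexp(0., logits) - y * logits)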
    eq = dfdx_vect(params, x, t, bc1) + dfdt_vect(params, x, t, bc1) - (3 * x) - t
    bc1_res = f_vect(params, x - x, t, bc1) - bc1
    return np.mean(eq**2) + np.mean(bc1_res**2)


##########
## Main ##
##########

key = random.PRNGKey(0)
params = random.normal(key, shape=(401,))

#-- Setting up the functions and derivatives --#
dfdx = grad(f, 1)
dfdt = grad(f, 2)
f_vect = vmap(f, (None, 0, 0, 0))
dfdx_vect = vmap(dfdx, (None, 0, 0, 0))
dfdt_vect = vmap(dfdt, (None, 0, 0, 0))
grad_loss = jit(grad(loss, 0))

#-- Defining the domain of x, t, and bc1 --#
x_values = np.linspace(-1, 1, num=40)
t_values = np.linspace(-1, 1, num=40)
bc1_values = np.linspace(1, 4, num=4)

x = []
t = []
bc1 = []
def body_fun(i, opt_state):
    elbo_rng, data_rng = random.split(random.fold_in(rng, i))
    batch = binarize_batch(data_rng, i, train_images)
    loss = lambda params: -elbo(elbo_rng, params, batch) / batch_size
    g = grad(loss)(get_params(opt_state))
    return opt_update(i, g, opt_state)
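# In the surrounding example this step function is typically driven by
# jax.lax.fori_loop over the batches of an epoch; a minimal sketch, assuming
# `num_batches` and an initialized `opt_state` already exist.
import jax
from jax import lax

@jax.jit
def run_epoch(opt_state):
    return lax.fori_loop(0, num_batches, body_fun, opt_state)

opt_state = run_epoch(opt_state)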
def initialize(self, n, m, h=64):
    """
    Description: Randomly initialize the RNN.
    Args:
        n (int): Input dimension.
        m (int): Observation/output dimension.
        h (int): Default value 64. Hidden dimension of RNN.
    Returns:
        The first value in the time-series
    """
    self.T = 0
    self.initialized = True
    self.n, self.m, self.h = n, m, h

    glorot_init = stax.glorot()  # returns a function that initializes weights
    self.W_h = glorot_init(generate_key(), (h, h))
    self.W_u = glorot_init(generate_key(), (h, n))
    self.W_out = glorot_init(generate_key(), (m, h))
    self.b_h = np.zeros(h)
    self.hid = np.zeros(h)
    self.rollout_controller = None
    self.target = jax.random.uniform(generate_key(), shape=(self.m,),
                                     minval=-1, maxval=1)

    '''
    def _step(x, hid):
        next_hid = np.tanh(np.dot(self.W_h, hid) + np.dot(self.W_x, x) + self.b_h)
        y = np.dot(self.W_out, next_hid)
        return (next_hid, y)
    '''

    def _dynamics(hid, u):
        next_hid = np.tanh(np.dot(self.W_h, hid) + np.dot(self.W_u, u) + self.b_h)
        y = np.dot(self.W_out, next_hid)
        return (next_hid, y)

    # self._step = jax.jit(_step)
    self._dynamics = jax.jit(_dynamics)
    self._loss = lambda x, u: (self.target - self._dynamics(x, u))**2

    # stack the jacobians of environment dynamics gradient
    jacobian = jax.jacrev(self._dynamics, argnums=(0, 1))
    self._dynamics_jacobian = jax.jit(lambda x, u: np.hstack(jacobian(x, u)))

    # stack the gradients of environment loss
    loss_grad = jax.grad(self._loss, argnums=(0, 1))
    self._loss_grad = jax.jit(lambda x, u: np.hstack(loss_grad(x, u)))

    # block the hessian of environment loss
    block_hessian = lambda A: np.vstack(
        [np.hstack([A[0][0], A[0][1]]), np.hstack([A[1][0], A[1][1]])])
    hessian = jax.hessian(self._loss, argnums=(0, 1))
    self._loss_hessian = jax.jit(lambda x, u: block_hessian(hessian(x, u)))

    def _rollout(act, dyn, x_0, T):
        def f(x, i):
            u = act(x)
            x_next = dyn(x, u)
            return x_next, np.hstack((x, u))
        _, trajectory = jax.lax.scan(f, x_0, np.arange(T))
        return trajectory

    self._rollout = jax.jit(_rollout, static_argnums=(0, 1, 3))

    return np.dot(self.W_out, self.hid)
def step(i, opt_state):
    p = get_params(opt_state)
    g = grad(self.negative_log_evidence)(p)
    return opt_update(i, g, opt_state)
def solve_sdp_dual_simple(verif_instance, key=None, opt=None, num_steps=10000,
                          eval_every=1000, verbose=False,
                          use_exact_eig_eval=True, use_exact_eig_train=False,
                          n_iter_lanczos=100, kappa_reg_weight=None,
                          kappa_zero_after=None, device_type=None):
    """Compute verified lower bound via dual of SDP relaxation.

    Args:
      verif_instance: a utils.SdpDualVerifInstance
      key: jax.random.PRNGKey, used for Lanczos
      opt: an optax.GradientTransformation instance, the optimizer.
        If None, defaults to Adam with learning rate 1e-3.
      num_steps: int, the number of outer loop optimization steps
      eval_every: int, frequency of running evaluation step
      verbose: bool, enables verbose logging
      use_exact_eig_eval: bool, whether to use exact eigendecomposition instead
        of Lanczos when computing evaluation loss
      use_exact_eig_train: bool, whether to use exact eigendecomposition instead
        of Lanczos during training
      n_iter_lanczos: int, number of Lanczos iterations
      kappa_reg_weight: float, adds a penalty of sum(abs(kappa_{1:N})) to loss,
        which regularizes kappa_{1:N} towards zero. Default None is disabled.
      kappa_zero_after: int, clamps kappa_{1:N} to zero after
        ``kappa_zero_after`` steps. Default None is disabled.
      device_type: string, used to clamp to a particular hardware device.
        Default None uses JAX default device placement.

    Returns:
      A pair. The first element is a float, the final dual loss, which forms a
      valid upper bound on the objective specified by ``verif_instance``. The
      second element is a dict containing various debug info.
    """
    assert device_type in (None, 'cpu', 'gpu'), 'invalid device_type'
    assert isinstance(verif_instance, utils.SdpDualVerifInstance), 'invalid type'
    key = key if key is not None else jax.random.PRNGKey(0)
    opt = opt if opt is not None else optax.adam(1e-3)
    dual_vars = jax.tree_map(
        lambda s: None if s is None else jnp.zeros(s), verif_instance.dual_shapes)
    dual_vars = init_duals_ibp(verif_instance, dual_vars)

    # Define loss function
    def loss(dual_vars, exact=use_exact_eig_train):
        return _loss(dual_vars, exact)

    @functools.partial(jax.jit, static_argnums=(1,), backend=device_type)
    def _loss(dual_var, exact):
        loss_val, step_info = dual_fun(
            verif_instance, dual_var, key, n_iter=n_iter_lanczos, exact=exact,
            include_info=True)
        step_info['loss_val'] = loss_val
        return loss_val, step_info

    # Define a compiled update step
    grad = jax.jit(jax.grad(loss, has_aux=True), backend=device_type)

    @functools.partial(jax.jit, backend=device_type)
    def grad_step(params, opt_state):
        g, info = grad(params)
        updates, new_opt_state = opt.update(g, opt_state)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state, info

    # Optimize parameters in a loop
    opt_state = opt.init(dual_vars)
    info = collections.defaultdict(list)
    loss_log = []
    best_loss = 1e9

    # Main loop
    for i in range(num_steps):
        dual_vars, opt_state, step_info = grad_step(dual_vars, opt_state)
        loss_val = step_info['loss_val']
        print(f'Iter {i}: Loss {loss_val}')
        best_loss = min(best_loss, loss_val)
        loss_log.append(loss_val)

        # Regularization of kappa
        if kappa_reg_weight is not None and kappa_reg_weight >= 0:
            onehot = jax.nn.one_hot([0], dual_vars[-1].shape[1])
            mask = jnp.ones_like(onehot) - onehot
            dual_vars[-1] -= mask * kappa_reg_weight
        if (kappa_zero_after is not None and kappa_zero_after >= 0 and
                i > kappa_zero_after):
            onehot = jax.nn.one_hot([0], dual_vars[-1].shape[1])
            dual_vars[-1] *= onehot

        dual_vars = project_duals(dual_vars, verif_instance.dual_types)

        if i % eval_every == 0:
            dual_val, _ = loss(dual_vars, exact=use_exact_eig_eval)
            info['steps'].append(i)
            info['loss_vals'].append(float(dual_val))
            if verbose:
                print(f'Dual iter {i}: Train loss: {loss_val} Loss {dual_val}')

    final_loss = float(loss(dual_vars, exact=use_exact_eig_eval)[0])
    info['final_dual_vars'] = dual_vars
    info['final_loss'] = final_loss
    info['loss_log'] = loss_log
    info['best_train_loss'] = best_loss
    return final_loss, info
def test_grad_and_aux_basic(self):
    g, aux = grad(lambda x: (x**3, [x**2]), has_aux=True)(3.)
    self.assertAllClose(g, grad(lambda x: x**3)(3.), check_dtypes=True)
    self.assertAllClose(aux, [9.], check_dtypes=True)
def solve_sdp_dual(verif_instance, key=None, opt=None, num_steps=10000,
                   verbose=False, eval_every=1000, use_exact_eig_eval=True,
                   use_exact_eig_train=False, n_iter_lanczos=30, scl=-1.0,
                   lr_init=1e-3, steps_per_anneal=100, anneal_factor=1.0,
                   num_anneals=3, opt_name='adam', gd_momentum=0.9,
                   add_diagnostic_stats=False, opt_multiplier_fn=None,
                   init_dual_vars=None, init_opt_state=None, opt_dual_vars=None,
                   kappa_reg_weight=None, kappa_zero_after=None,
                   device_type=None, save_best_k=1):
    # pylint: disable=g-doc-return-or-yield, g-doc-args
    """Compute verified lower bound via dual of SDP relaxation.

    NOTE: This method exposes many hyperparameter options, and the method
    signature is subject to change. We suggest using ``solve_sdp_dual_simple``
    instead if you need a stable interface.
    """
    # NB: Whereas the rest of the code in this library is fairly top-down
    # readable, avoids excessive `if` statements, tries to make the code look
    # like the formalism, etc, this is not the case for this method.
    # This is essentially the outer loop, and includes all the debugging/
    # logging/optimization tricks we need to get/debug good results.
    #
    # NB: Time profiling: On toy VerifInstances, JIT compilation dominates time
    # cost: JIT compilation takes ~12s, then we do ~3000 steps/sec.
    assert device_type in (None, 'cpu', 'gpu'), 'invalid device_type'
    assert isinstance(verif_instance, utils.SdpDualVerifInstance), 'invalid type'
    key = key if key is not None else jax.random.PRNGKey(0)

    dual_vars = jax.tree_map(
        lambda s: None if s is None else jnp.zeros(s), verif_instance.dual_shapes)
    dual_vars = init_duals_ibp(verif_instance, dual_vars)

    if init_dual_vars is not None:
        # Casting, here for Colab. Essentially same as `dual_vars = init_dual_vars`
        dual_vars = utils.structure_like(init_dual_vars, dual_vars)
    if opt_dual_vars is not None:
        opt_dual_vars = utils.structure_like(opt_dual_vars, dual_vars)

    # Create optimizer
    if opt is None:
        if (isinstance(steps_per_anneal, float) or
                isinstance(steps_per_anneal, int)):
            anneal_steps = [steps_per_anneal * (i + 1) for i in range(num_anneals)]
        else:
            anneal_steps = np.cumsum(steps_per_anneal)
        anneal_steps = jnp.array(anneal_steps)

        def lr_schedule(t):
            cur_epoch = jnp.minimum(num_anneals, jnp.sum(t > anneal_steps))
            return lr_init * jnp.float_power(anneal_factor, cur_epoch)

        opt_class = getattr(optax, opt_name)
        base_opt = (opt_class(1., momentum=gd_momentum) if opt_name == 'sgd'
                    else opt_class(1.))
        opt = optax.chain(base_opt, optax.scale_by_schedule(lr_schedule))
        if opt_multiplier_fn:
            # NB: Interface very specific to tree.map_structure_with_path
            # Example: opt_multiplier_fn=lambda path: 0.1 if 'lam' in path else 1.0
            opt_multipliers = tree.map_structure_with_path(
                lambda path, v: opt_multiplier_fn(path), dual_vars)
            opt = optax.chain(base_opt, optax.scale_by_schedule(lr_schedule),
                              utils.scale_by_variable_opt(opt_multipliers))
        else:
            opt = optax.chain(base_opt, optax.scale_by_schedule(lr_schedule))

    # Define loss function
    def loss(dual_vars, loss_scl=scl, exact=use_exact_eig_train):
        return _loss(dual_vars, loss_scl, exact)

    @functools.partial(jax.jit, static_argnums=(1, 2), backend=device_type)
    def _loss(dual_var, loss_scl, exact):
        loss_val, step_info = dual_fun(
            verif_instance, dual_var, key, n_iter=n_iter_lanczos, exact=exact,
            scl=loss_scl, include_info=True)
        step_info['loss_val'] = loss_val
        return loss_val, step_info

    # Define a compiled update step
    grad = jax.jit(jax.grad(loss, has_aux=True), backend=device_type)

    @functools.partial(jax.jit, backend=device_type)
    def grad_step(params, opt_state):
        g, info = grad(params)
        updates, new_opt_state = opt.update(g, opt_state)
        new_params = optax.apply_updates(params, updates)
        info['g'] = g
        info['updates'] = updates
        return new_params, new_opt_state, info

    # Optimize parameters in a loop
    opt_state = opt.init(dual_vars)
    if init_opt_state:
        opt_state = utils.structure_like(init_opt_state, opt_state)

    info = collections.defaultdict(list)
    loss_log = []
    store_best = []
    recent_eig_vecs = collections.deque(maxlen=10)
    best_loss = 1e9
    last_H = None
    start_i = 0

    # Main loop
    for i in range(start_i, num_steps):
        dual_vars_prev = dual_vars
        dual_vars, opt_state, step_info = grad_step(dual_vars, opt_state)
        loss_val = step_info['loss_val']
        print(f'Iter {i}: Loss {loss_val}')
        best_loss = min(best_loss, loss_val)

        if add_diagnostic_stats:
            info['dual_vars'].append(dual_vars_prev)
            eig_vec = step_info['eig_vec']
            cosine_sims = []
            for prev_eig_vec in recent_eig_vecs:
                denom = jnp.sqrt(jnp.linalg.norm(eig_vec) *
                                 jnp.linalg.norm(prev_eig_vec))
                eig_sim = jnp.sum(prev_eig_vec * eig_vec) / denom
                cosine_sims.append(abs(float(eig_sim)))
            info['c_lambda'].append(float(step_info['c_lambda']))
            info['past_10_cosine_sims'].append(np.array(cosine_sims))
            info['g'].append(step_info['g'])
            info['updates'].append(step_info['updates'])
            if use_exact_eig_train:
                # The info is for -H, so to get smallest for H, take negative of max
                eig_vals = -step_info['eig_info'][0][-1:-20:-1]
                cur_H = step_info['eig_info'][2]
                diff_H = 0 if last_H is None else np.linalg.norm(cur_H - last_H)
                last_H = cur_H
                info['diff_H'].append(float(diff_H))
                info['smallest_20_eig_vals'].append(eig_vals)
            recent_eig_vecs.appendleft(eig_vec)

        loss_log.append(loss_val)
        if len(store_best) < save_best_k:
            store_best.append((loss_val, dual_vars_prev))
            store_best.sort(key=lambda x: x[0])
        elif loss_val < store_best[-1][0]:
            store_best[-1] = (loss_val, dual_vars_prev)
            store_best.sort(key=lambda x: x[0])

        # Regularization of kappa
        if kappa_reg_weight is not None and kappa_reg_weight >= 0:
            onehot = jax.nn.one_hot([0], dual_vars[-1].shape[1])
            mask = jnp.ones_like(onehot) - onehot
            dual_vars[-1] -= mask * kappa_reg_weight
        if (kappa_zero_after is not None and kappa_zero_after >= 0 and
                i > kappa_zero_after):
            onehot = jax.nn.one_hot([0], dual_vars[-1].shape[1])
            dual_vars[-1] *= onehot

        dual_vars = project_duals(dual_vars, verif_instance.dual_types)

        if opt_dual_vars:
            distance_to_opt = jax.tree_multimap(
                lambda x, y: jnp.linalg.norm(x - y), dual_vars, opt_dual_vars)
            info['distance_to_opt'].append(distance_to_opt)

        if i % eval_every == 0:
            dual_val, _ = loss(dual_vars, loss_scl=-1, exact=use_exact_eig_eval)
            info['steps'].append(i)
            info['loss_vals'].append(float(dual_val))
            if verbose:
                print(f'Dual iter {i}: Train loss: {loss_val} Loss {dual_val}')

    final_loss = float(loss(dual_vars, loss_scl=-1, exact=use_exact_eig_eval)[0])
    info['final_dual_vars'] = dual_vars
    info['final_opt_state'] = opt_state
    info['final_loss'] = final_loss
    info['loss_log'] = loss_log
    info['store_best'] = store_best
    return final_loss, info
def f(x):
    g, aux = grad(lambda x: (x**3, [x**3]), has_aux=True)(x)
    return aux[0] * np.sin(x)
def dual_fun(verif_instance, dual_vars, key=None, n_iter=30, scl=-1,
             exact=False, dynamic_unroll=True, include_info=False):
    # pylint: disable=invalid-name
    """Returns the dual objective value.

    Args:
      verif_instance: a utils.SdpDualVerifInstance, the verification problem
      dual_vars: A list of dual variables at each layer
      key: PRNGKey passed to Lanczos
      n_iter: Number of Lanczos iterations to use
      scl: Inverse temperature in softmax over eigenvalues to smooth
        optimization problem (if negative treat as hardmax)
      exact: Whether to use exact eigendecomposition instead of Lanczos
      dynamic_unroll: bool. Whether to use jax.fori_loop for Lanczos for faster
        JIT compilation. Default is True.
      include_info: if True, also return an `info` dict of various other values
        computed for the objective

    Returns:
      Either a single float, the dual upper bound, or if ``include_info=True``,
      returns a pair, the dual bound and a dict containing debugging info
    """
    key = key if key is not None else jax.random.PRNGKey(0)
    assert isinstance(verif_instance, utils.SdpDualVerifInstance)
    bounds = verif_instance.bounds
    layer_sizes = utils.layer_sizes_from_bounds(bounds)
    layer_sizes_1d = [np.prod(np.array(i)) for i in layer_sizes]
    N = sum(layer_sizes_1d) + 1
    info = {}

    # Mean activations at each layer
    activations_center = [(b.lb + b.ub) / 2 for b in bounds]
    # Maximum deviation from mean activations
    radius = [(b.ub - b.lb) / 2 for b in bounds]

    inner_lagrangian = verif_instance.make_inner_lagrangian(dual_vars)
    lagrangian = _make_transformed_lagrangian(
        inner_lagrangian, activations_center, radius)

    # Construct c_lambda and g_lambda terms
    zeros = [jnp.zeros(sz) for sz in layer_sizes]
    c_lambda = lagrangian(zeros)
    g_lambda = jax.grad(lagrangian)(zeros)
    g_lambda = flatten(g_lambda)
    info['c_lambda'] = c_lambda

    def Hv(v):
        """Hessian-vector product for H_lambda - refer to docstring for `Av()`."""
        lag_grad = lambda v2: flatten(jax.grad(lagrangian)(v2))
        hv_v = jax.grad(lambda v2: jnp.vdot(lag_grad(v2), v))(zeros)
        hv_flat = flatten(hv_v)
        return hv_flat

    def Av(v):
        """Matrix-vector product.

        Args:
          v: vector, DeviceArray

        Returns:
          Av: vector, DeviceArray. A is defined as diag(kappa) - M(lambda),
          where M(lambda) = [0, g_lambda'; g_lambda, H_lambda], and these terms
          correspond to L~(z) = c_lambda + g_lambda' z + z' H_lambda z
        """
        # Expand Mv = [0 g'; g H] [v0; v1] = [g'v1; v0*g + H(v1)] = [Mv0; Mv1]
        # Compute Mv0 term
        mv_zero = jnp.reshape(jnp.vdot(g_lambda, v[1:]), (1,))
        # Compute Mv1 term
        mv_rest = Hv(v[1:]) + v[0] * g_lambda
        mv = jnp.concatenate([mv_zero, mv_rest], axis=0)
        diag_kappa_v = jnp.reshape(dual_vars[-1], mv.shape) * v
        av = diag_kappa_v - mv
        return jnp.reshape(av, v.shape)

    # Construct dual function (dual_vars[-1]=kappa)
    if exact:
        eig_vec, eig_info = eigenvector_utils.min_eigenvector_exact(
            Av, N, scl=scl, report_all=True)
        info['eig_info'] = eig_info
    else:
        eig_vec = eigenvector_utils.min_eigenvector_lanczos(
            Av, N, min(N, n_iter), key, scl, dynamic_unroll=dynamic_unroll)
    info['eig_vec'] = eig_vec
    info['kappa'] = dual_vars[-1]

    hess_val = jnp.vdot(eig_vec, Av(eig_vec)) / jnp.vdot(eig_vec, eig_vec)
    hess_val = jnp.reshape(hess_val, ())

    # Form dual objective
    lambda_minus = jnp.minimum(hess_val, 0.)
    kappa_hat = jnp.maximum(0, dual_vars[-1] - lambda_minus)
    dual_val = c_lambda + 0.5 * jnp.sum(kappa_hat)
    if include_info:
        return dual_val, info
    return dual_val
def update_derivative(i, opt_state, batch, l2reg):
    params = get_params(opt_state)
    return opt_update(i, jax.grad(loss, 0)(params, batch, l2reg), opt_state), params
def Hv(v):
    """Hessian-vector product for H_lambda - refer to docstring for `Av()`."""
    lag_grad = lambda v2: flatten(jax.grad(lagrangian)(v2))
    hv_v = jax.grad(lambda v2: jnp.vdot(lag_grad(v2), v))(zeros)
    hv_flat = flatten(hv_v)
    return hv_flat
def gradients(scalar, variables):
    """Compute the gradients of a scalar w.r.t. a given list of variables.

    Arguments
    ---------
    scalar: :class:`symjax.tensor.base.Tensor`
        the variable to differentiate

    variables: List or Tuple
        the variables used to compute the derivative.

    Returns
    -------
    gradients: Tuple
        the sequence of gradients ordered as given in the input variables

    Example
    -------

    .. doctest::

        >>> import symjax
        >>> w = symjax.tensor.ones(3)
        >>> x = symjax.tensor.Variable(2., name='x', dtype='float32')
        >>> l = (w ** 2).sum() * x
        >>> g = symjax.gradients(l, [w])[0]
        >>> f = symjax.function(outputs=g, updates={x: x + 1})
        >>> for i in range(2):
        ...     print(f())
        [4. 4. 4.]
        [6. 6. 6.]
    """
    if numpy.prod(scalar.shape) != 1:
        raise RuntimeError("the variable to differentiate is not a scalar")
    if not isinstance(scalar, t.Tensor):
        raise RuntimeError(
            "the variable used in gradients should be a Tensor type")

    if scalar.shape != ():
        scalar = scalar.sum()

    if isinstance(variables, t.Tensor):
        input_variables = [variables]
        input_list = False
    else:
        input_variables = variables
        input_list = True

    # get the argnums of the variables that we differentiate
    argnums = list(range(len(input_variables)))

    # get all the roots of the scalar; this is needed as otherwise they are
    # not inputs of the gradient function, and thus a change of their value
    # would not change the gradient computation. We also ensure uniqueness.
    input_variables += [
        i for i in current_graph().roots(scalar) if i not in input_variables
    ]

    # create a dummy function that is needed for jax to compute a gradient
    # function; this function builds the graph of computation from all roots
    # to the scalar variable s.t. automatic differentiation can be applied
    def fn(*args):
        return current_graph().get(scalar, dict(zip(input_variables, list(args))))

    # now we obtain the grad function. In fact, Jax returns a function that,
    # when called, returns the gradient values; this function is then used to
    # generate the Tuple of symbolic variables
    grad_fn = jax.grad(fn, argnums)
    wrap_fn = t.jax_wrap(grad_fn)
    if input_list:
        return wrap_fn(*input_variables)
    else:
        return wrap_fn(*input_variables)[0]
def id_func(x):
    return lambda u: jnp.dot(jnp.identity(x.shape[0]), u)


def hvp(f, primals, tangents):
    return jvp(grad(f), primals, tangents)[1]


def breg_bound(vec, lb=-1.0, ub=1.0, *args, **kwargs):
    return jnp.sum((-vec + ub) * jnp.log(-vec + ub) +
                   (vec - lb) * jnp.log(vec - lb))


DP_bound = jax.grad(breg_bound, 0)


def DP_inv_bound(vec, lb=-1.0, ub=1.0):
    return (ub * jnp.exp(vec) + lb) / (1 + jnp.exp(vec))


def D2P_bound(vec, lb=-1.0, ub=1.0):
    def out(u):
        return jvp(lambda x: DP_bound(x, lb, ub), (vec,), (u,))[1]
    return out


def inv_D2P_bound(vec, lb=-1.0, ub=1.0):
    if len(jnp.shape(vec)) <= 1:
print("The input is", val) print("The result of Rosenbrock's is ", result) fun_driver(10) """### 3. First look at derivatives: `jax.grad()` 1. https://jax.readthedocs.io/en/latest/jax.html#jax.grad 2. Computes $v\cdot J$ for a function that computes a scalar value ($R^n \rightarrow R$). 3. The seed is internally set to `1.0` (the shape of the seed must match the primal output). """ #Create a function that computes the derivatives. #This needs to happen only once. grad_rosenbrock = jax.grad(rosenbrock) def grad_driver(n): """ Input: n array length Output: Result of Rosenbrock's banana function """ #create the input array val = jnp.full(n, 0.5) #compute the derivatives result = grad_rosenbrock(val) plot_vals(val, grad=result) print("The input is", val)
def hvp(f, primals, tangents):
    return jvp(grad(f), primals, tangents)[1]
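# Quick usage sketch of the forward-over-reverse Hessian-vector product above;
# the cubic objective and the vectors are made up for illustration.
import jax.numpy as jnp
from jax import grad, jvp

def f(x):
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 0.0])
# Hessian of sum(x**3) is diag(6 * x), so H @ v = [6., 0.]
print(hvp(f, (x,), (v,)))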
import jax.random as random

key = random.PRNGKey(0)

if __name__ == '__main__':
    # 4. Minimize for a few steps
    initial_radii = np.ones(len(all_elements)) * 0.12
    initial_scales = np.ones(len(all_elements)) * 0.85
    theta0 = pack(initial_radii, initial_scales)
    print('initial theta', theta0)

    initial_log_prob = log_prob(theta0)
    print('initial log prob', initial_log_prob)

    # TODO: potentially jit() this
    grad_log_prob = grad(log_prob)
    # grad_log_prob = parallel_grad_log_prob

    print('initial gradient norm = {}'.format(
        np.linalg.norm(grad_log_prob(theta0))))

    minimized_theta_fname = os.path.join(
        data_path, 'elemental_types_l-bfgs_freesolv_{}.npy'.format(name))

    print('minimizing...')
    from scipy.optimize import minimize
    loss = lambda theta: -log_prob(theta)
    grad_loss = lambda theta: -grad_log_prob(theta)
    bounds = [(0.001, 2.0)] * len(theta0)

    result = minimize(loss,
def test_grad_tuple_output(self):
    jtu.check_raises(lambda: grad(lambda x: (x, x))(1.0), TypeError,
                     "Gradient only defined for scalar-output functions. ")
def test_system_derivatives(self):
    sdf_file = open("examples/host-acd.mol2").read()
    smirnoff = ForceField("test_forcefields/smirnoff99Frosst.offxml")
    mol = Chem.MolFromMol2Block(sdf_file, sanitize=True, removeHs=False,
                                cleanupSubstructures=True)
    guest_potentials, guest_params, guest_param_groups, guest_conf, guest_masses = \
        forcefield.parameterize(mol, smirnoff)

    ref_nrgs = []
    test_nrgs = []
    for potential, params in guest_potentials:
        jax_potential = potential_map[potential]
        if potential == timemachine.lib.custom_ops.HarmonicBond_f64:
            jp = functools.partial(jax_potential, box=None,
                                   bond_idxs=params[0], param_idxs=params[1])
        elif potential == timemachine.lib.custom_ops.HarmonicAngle_f64:
            jp = functools.partial(jax_potential, box=None,
                                   angle_idxs=params[0], param_idxs=params[1])
        elif potential == timemachine.lib.custom_ops.PeriodicTorsion_f64:
            jp = functools.partial(jax_potential, box=None,
                                   torsion_idxs=params[0], param_idxs=params[1])
        elif potential == timemachine.lib.custom_ops.LennardJones_f64:
            jp = functools.partial(jax_potential, box=None,
                                   scale_matrix=params[0], param_idxs=params[1])
        elif potential == timemachine.lib.custom_ops.Electrostatics_f64:
            jp = functools.partial(jax_potential, box=None,
                                   scale_matrix=params[0], param_idxs=params[1])
        else:
            raise ValueError("unknown functional form")
        test_nrgs.append(potential(*params))
        ref_nrgs.append(jp)

    def ref_total_nrg(conf, params):
        nrgs = []
        for p in ref_nrgs:
            nrgs.append(p(conf, params))
        return jnp.sum(nrgs)

    dp_idxs = onp.arange(len(params)).astype(onp.int32)

    def test_total_nrg(conf, params):
        nrgs = []
        for p in test_nrgs:
            res = p.derivatives(onp.expand_dims(conf, axis=0), params, dp_idxs)
            nrgs.append(res[0])
        return onp.sum(nrgs)

    num_atoms = guest_conf.shape[0]

    ref_e = ref_total_nrg(guest_conf, guest_params)
    test_e = test_total_nrg(guest_conf, guest_params)
    onp.testing.assert_almost_equal(ref_e, test_e)

    dt = 1e-3
    ca = 0.5
    cb = onp.random.rand(num_atoms) / 10
    cc = onp.zeros(num_atoms)

    intg = ReferenceLangevin(dt, ca, cb, cc)

    ref_dE_dx_fn = jax.grad(ref_total_nrg, argnums=(0,))
    ref_dE_dx_fn = jax.jit(ref_dE_dx_fn)

    def integrate(x_t, v_t, params):
        for _ in range(100):
            x_t, v_t = intg.step(x_t, v_t, ref_dE_dx_fn(x_t, params)[0])
        return x_t, v_t

    v0 = onp.random.rand(num_atoms * 3).reshape(num_atoms, 3)

    x_f, v_f = integrate(guest_conf, v0, guest_params)

    lo = custom_ops.LangevinOptimizer_f64(dt, ca, cb, cc)
    ctxt = custom_ops.Context_f64(test_nrgs, lo, guest_params, guest_conf, v0, dp_idxs)

    for _ in range(100):
        ctxt.step()

    onp.testing.assert_almost_equal(x_f, ctxt.get_x())
    onp.testing.assert_almost_equal(v_f, ctxt.get_v())

    grad_fn = jax.jacfwd(integrate, argnums=(2))
    dx_dp_f, dv_dp_f = grad_fn(guest_conf, v0, guest_params)
    dx_dp_f = onp.asarray(onp.transpose(dx_dp_f, (2, 0, 1)))
    dv_dp_f = onp.asarray(onp.transpose(dv_dp_f, (2, 0, 1)))

    onp.testing.assert_almost_equal(dx_dp_f[dp_idxs], ctxt.get_dx_dp())
    onp.testing.assert_almost_equal(dv_dp_f[dp_idxs], ctxt.get_dv_dp())
def test_grad_nonscalar_output(self):
    jtu.check_raises(lambda: grad(lambda x: x)(onp.zeros(3)), TypeError,
                     "Gradient only defined for scalar-output functions. ")
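# For non-scalar outputs, which grad() rejects as the test above checks, the
# usual alternative is a full Jacobian; an illustrative sketch (not part of
# the test suite).
import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x)         # R^3 -> R^3, so grad(f) would raise TypeError
J = jax.jacrev(f)(jnp.zeros(3))  # 3x3 Jacobian; equals the identity since cos(0) = 1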
import jax.numpy as np
from jax import grad, jit, vmap
from functools import partial

def predict(params, inputs):
    for W, b in params:
        outputs = np.dot(inputs, W) + b
        inputs = np.tanh(outputs)
    return outputs

def logprob_fun(params, inputs, targets):
    preds = predict(params, inputs)
    return np.sum((preds - targets)**2)

grad_fun = jit(grad(logprob_fun))  # compiled gradient evaluation function
perex_grads = jit(lambda params, inputs, targets:  # fast per-example gradients
                  vmap(partial(grad_fun, params))(inputs, targets))
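# An equivalent, arguably more idiomatic way to get per-example gradients is
# to map over the batch axes directly with in_axes, avoiding the partial
# application; a sketch under the same definitions as above.
perex_grads_v2 = jit(vmap(grad_fun, in_axes=(None, 0, 0)))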
def test_holomorphic_grad(self):
    out = grad(lambda x: np.sin(x), holomorphic=True)(1 + 2j)
    expected = 2.0327230070196656 - 3.0518977991518j
    self.assertAllClose(out, expected, check_dtypes=False)