Example #1
 def update(opt_state, step_size):
     _, update_fun, get_params = optimizers.sgd(step_size)
     x = get_params(opt_state)
     g = grad(loss)(x)
     return update_fun(0, g, opt_state)
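 # Hedged usage sketch: assumes `loss` is a scalar function of the parameters
 # and that `opt_state` was created by the matching init_fun, e.g.
 #   init_fun, update_fun, get_params = optimizers.sgd(step_size)
 #   opt_state = init_fun(initial_params)
 #   opt_state = update(opt_state, step_size)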
Example #2
 def test_grad(self, with_jit=False):
   x = np.float32(3.)
   res = _maybe_jit(with_jit, jax.grad(jax2tf.call_tf(tf.math.sin)))(x)
   self.assertAllClose(np.cos(x), res)
Example #3
 def update(params, opt_state, problem):
     g = jax.grad(prediction_loss)(params, problem)
     updates, opt_state = opt_update(g, opt_state)
     return optax.apply_updates(params, updates), opt_state
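 # Hedged note: assumes `opt_update` is bound to an optax optimizer's update
 # method (e.g. opt_update = optax.adam(1e-3).update) and that
 # prediction_loss(params, problem) returns a scalar.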
Example #4
def test_gradient_log_normalizer(
        generator: Generator,
        distribution_info: DistributionInfo[Any, Any, Any]) -> None:
    """
    Tests that the gradient log-normalizer evaluates to the same value as the
    gradient of the log-normalizer.
    """
    # pylint: disable=too-many-locals, protected-access
    cls = type(distribution_info.nat_parameter_generator(generator, shape=()))
    original_ln = cls._original_log_normalizer
    original_gln = jit(grad(cls._original_log_normalizer))
    optimized_ln = cls.log_normalizer
    optimized_gln = jit(grad(optimized_ln))

    for _ in range(20):
        nat_parameters = distribution_info.nat_parameter_generator(generator,
                                                                   shape=())
        kw_nat_parameters = nat_parameters.fixed_parameters_mapping()
        exp_parameters = nat_parameters.to_exp()  # Regular transformation.
        nat_cls = type(nat_parameters)
        ep_cls = type(exp_parameters)

        # Original GLN.
        original_nat_parameters = original_gln(nat_parameters)
        f_original_nat_parameters = original_nat_parameters.flattened()
        original_exp_parameters = ep_cls.unflattened(f_original_nat_parameters,
                                                     **kw_nat_parameters)

        # Optimized GLN.
        optimized_nat_parameters = optimized_gln(nat_parameters)
        f_optimized_nat_parameters = optimized_nat_parameters.flattened()
        optimized_exp_parameters = ep_cls.unflattened(
            f_optimized_nat_parameters, **kw_nat_parameters)

        # Test primal evaluation.
        assert_jax_allclose(exp_parameters, original_exp_parameters, rtol=1e-5)
        assert_jax_allclose(exp_parameters,
                            optimized_exp_parameters,
                            rtol=1e-5)

        # Test JVP.
        ones_like_nat_parameters = nat_cls(
            **{
                name: jnp.zeros_like(value)
                for name, value in
                nat_parameters.fixed_parameters_mapping().items()
            }, **{
                name: jnp.ones_like(value)
                for name, value in nat_parameters.parameters_name_value()
            })
        original_gradients = jvp(original_ln, (nat_parameters, ),
                                 (ones_like_nat_parameters, ))
        optimized_gradients = jvp(optimized_ln, (nat_parameters, ),
                                  (ones_like_nat_parameters, ))
        assert_allclose(original_gradients, optimized_gradients, rtol=1.5e-5)

        # Test VJP.
        original_ln_of_nat, original_vjp = vjp(original_ln, nat_parameters)
        original_gln_of_nat, = original_vjp(1.0)
        optimized_ln_of_nat, optimized_vjp = vjp(optimized_ln, nat_parameters)
        optimized_gln_of_nat, = optimized_vjp(1.0)
        assert_jax_allclose(original_ln_of_nat, optimized_ln_of_nat, rtol=1e-5)
        for name, original_value in original_gln_of_nat.parameters_name_value():
            optimized_value = getattr(optimized_gln_of_nat, name)
            assert_jax_allclose(original_value, optimized_value, rtol=1e-5)
Example #5
 def update(params, opt_state, x, y_true, dropout_key):
     grads = grad(mean_cross_entropy)(params, x, y_true, dropout_key)
     updates, opt_state = opt.update(grads, opt_state, params)
     params = optax.apply_updates(params, updates)
     return params, opt_state
Example #6
    def testIssue1789(self):
        def f(x):
            return random.gamma(random.PRNGKey(0), x)

        grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2))
Example #7
def main(unused_argv):

  numpts = 7
  key = random.PRNGKey(0)
  eye = jnp.eye(numpts)

  def cov_map(cov_func, xs, xs2=None):
    """Compute a covariance matrix from a covariance function and data points.

    Args:
      cov_func: callable function, maps pairs of data points to scalars.
      xs: array of data points, stacked along the leading dimension.
      xs2: optional second array of data points; if None, covariances are
        computed within ``xs``.
    Returns:
      A 2d array `a` such that `a[i, j] = cov_func(xs[i], xs[j])`.
    """
    if xs2 is None:
      return vmap(lambda x: vmap(lambda y: cov_func(x, y))(xs))(xs)
    else:
      return vmap(lambda x: vmap(lambda y: cov_func(x, y))(xs))(xs2).T

  def softplus(x):
    return jnp.logaddexp(x, 0.)

  # Note, writing out the vectorized form of the identity
  # ||x-y||^2 = <x-y,x-y> = ||x||^2 + ||y||^2 - 2<x,y>
  # for computing squared distances would be more efficient (but less succinct).
  def exp_quadratic(x1, x2):
    return jnp.exp(-jnp.sum((x1 - x2)**2))
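  # Sketch of the vectorized form mentioned in the comment above (illustrative
  # only; not used below): all pairwise squared distances at once via the
  # identity ||x-y||^2 = ||x||^2 + ||y||^2 - 2<x,y>.
  def pairwise_sq_dists(xs, xs2):
    sq1 = jnp.sum(xs ** 2, axis=-1)[:, None]
    sq2 = jnp.sum(xs2 ** 2, axis=-1)[None, :]
    return sq1 + sq2 - 2. * jnp.dot(xs, xs2.T)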

  def gp(params, x, y, xtest=None, compute_marginal_likelihood=False):
    noise = softplus(params['noise'])
    amp = softplus(params['amplitude'])
    ls = softplus(params['lengthscale'])
    ymean = jnp.mean(y)
    y = y - ymean
    x = x / ls
    train_cov = amp*cov_map(exp_quadratic, x) + eye * (noise + 1e-6)
    chol = scipy.linalg.cholesky(train_cov, lower=True)
    kinvy = scipy.linalg.solve_triangular(
        chol.T, scipy.linalg.solve_triangular(chol, y, lower=True))
    if compute_marginal_likelihood:
      log2pi = jnp.log(2. * 3.1415)
      ml = jnp.sum(
          -0.5 * jnp.dot(y.T, kinvy) -
          jnp.sum(jnp.log(jnp.diag(chol))) -
          (numpts / 2.) * log2pi)
      ml -= jnp.sum(-0.5 * jnp.log(2 * 3.1415) - jnp.log(amp)**2) # lognormal prior
      return -ml

    if xtest is not None:
      xtest = xtest / ls
    cross_cov = amp*cov_map(exp_quadratic, x, xtest)
    mu = jnp.dot(cross_cov.T, kinvy) + ymean
    v = scipy.linalg.solve_triangular(chol, cross_cov, lower=True)
    var = (amp * cov_map(exp_quadratic, xtest) - jnp.dot(v.T, v))
    return mu, var

  marginal_likelihood = partial(gp, compute_marginal_likelihood=True)
  predict = partial(gp, compute_marginal_likelihood=False)
  grad_fun = jit(grad(marginal_likelihood))

  # Covariance hyperparameters to be learned
  params = {"amplitude": jnp.zeros((1, 1)),
            "noise": jnp.zeros((1, 1)) - 5.,
            "lengthscale": jnp.zeros((1, 1))}
  momentums = dict([(k, p * 0.) for k, p in params.items()])
  scales = dict([(k, p * 0. + 1.) for k, p in params.items()])

  lr = 0.01  # Learning rate
  def train_step(params, momentums, scales, x, y):
    grads = grad_fun(params, x, y)
    for k in params:
      momentums[k] = 0.9 * momentums[k] + 0.1 * grads[k][0]
      scales[k] = 0.9 * scales[k] + 0.1 * grads[k][0]**2
      params[k] -= lr * momentums[k]/jnp.sqrt(scales[k] + 1e-5)
    return params, momentums, scales

  # Create a really simple toy 1D function
  y_fun = lambda x: jnp.sin(x) + 0.1 * random.normal(key, shape=(x.shape[0], 1))
  x = (random.uniform(key, shape=(numpts, 1)) * 4.) + 1
  y = y_fun(x)
  xtest = jnp.linspace(0, 6., 200)[:, None]

  for i in range(1000):
    params, momentums, scales = train_step(params, momentums, scales, x, y)
    if i % 50 == 0:
      ml = marginal_likelihood(params, x, y)
      print("Step: %d, neg marginal likelihood: %f" % (i, ml))

  print(params)
  mu, var = predict(params, x, y, xtest)
  std = jnp.sqrt(jnp.diag(var))
  plt.plot(x, y, "k.")
  plt.plot(xtest, mu)
  plt.fill_between(xtest.flatten(),
                    mu.flatten() - std * 2, mu.flatten() + std * 2)
Example #8
    def test_grad_and_aux_nested(self):
        def f(x):
            g, aux = grad(lambda x: (x**3, [x**3]), has_aux=True)(x)
            return aux[0]

        f2 = lambda x: x**3

        self.assertEqual(grad(f)(4.), grad(f2)(4.))
        self.assertEqual(jit(grad(f))(4.), grad(f2)(4.))
        self.assertEqual(jit(grad(jit(f)))(4.), grad(f2)(4.))

        def f(x):
            g, aux = grad(lambda x: (x**3, [x**3]), has_aux=True)(x)
            return aux[0] * np.sin(x)

        f2 = lambda x: x**3 * np.sin(x)

        self.assertEqual(grad(f)(4.), grad(f2)(4.))
        self.assertEqual(jit(grad(f))(4.), grad(f2)(4.))
        self.assertEqual(jit(grad(jit(f)))(4.), grad(f2)(4.))
Example #9
def grads_real(params, x):
    r = jax.grad(lambda pars, v: f_real_flat_scalar(pars, v).real)(params, x)
    i = jax.grad(lambda pars, v: f_real_flat_scalar(pars, v).imag)(params, x)
    return jax.lax.complex(r, i)
Example #10
def compute_gradients_jax(w, X_test, y_test):
    print("Starting JAX demo")
    grad_jax = jax.grad(NLL)(w, (X_test, y_test))
    print("grad {}".format(grad_jax))
    return grad_jax
Example #11
    eq = dfdx_vect(params, x, t, bc1) + dfdt_vect(params, x, t,
                                                  bc1) - (3 * x) - t
    bc1_res = f_vect(params, x - x, t, bc1) - bc1
    return np.mean(eq**2) + np.mean(bc1_res**2)


##########
## Main ##
##########

key = random.PRNGKey(0)
params = random.normal(key, shape=(401, ))

#-- Setting up the functions and derivatives --#

dfdx = grad(f, 1)
dfdt = grad(f, 2)
f_vect = vmap(f, (None, 0, 0, 0))
dfdx_vect = vmap(dfdx, (None, 0, 0, 0))
dfdt_vect = vmap(dfdt, (None, 0, 0, 0))
grad_loss = jit(grad(loss, 0))

#-- Defining the domain of x, t, and bc1 --#

x_values = np.linspace(-1, 1, num=40)
t_values = np.linspace(-1, 1, num=40)
bc1_values = np.linspace(1, 4, num=4)

x = []
t = []
bc1 = []
Example #12
 def body_fun(i, opt_state):
     elbo_rng, data_rng = random.split(random.fold_in(rng, i))
     batch = binarize_batch(data_rng, i, train_images)
     loss = lambda params: -elbo(elbo_rng, params, batch) / batch_size
     g = grad(loss)(get_params(opt_state))
     return opt_update(i, g, opt_state)
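 # Hedged usage note: `body_fun` matches the jax.lax.fori_loop signature, e.g.
 #   opt_state = lax.fori_loop(0, num_batches, body_fun, opt_state)
 # assuming `rng`, `train_images`, `batch_size`, `elbo`, `binarize_batch`,
 # `opt_update` and `get_params` exist in the enclosing scope.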
Example #13
    def initialize(self, n, m, h=64):
        """
        Description: Randomly initialize the RNN.
        Args:
            n (int): Input dimension.
            m (int): Observation/output dimension.
            h (int): Default value 64. Hidden dimension of RNN.
        Returns:
            The first value in the time-series
        """
        self.T = 0
        self.initialized = True
        self.n, self.m, self.h = n, m, h

        glorot_init = stax.glorot()  # returns a function that initializes weights
        self.W_h = glorot_init(generate_key(), (h, h))
        self.W_u = glorot_init(generate_key(), (h, n))
        self.W_out = glorot_init(generate_key(), (m, h))
        self.b_h = np.zeros(h)
        self.hid = np.zeros(h)

        self.rollout_controller = None
        self.target = jax.random.uniform(generate_key(),
                                         shape=(self.m, ),
                                         minval=-1,
                                         maxval=1)
        '''
        def _step(x, hid):
            next_hid = np.tanh(np.dot(self.W_h, hid) + np.dot(self.W_x, x) + self.b_h)
            y = np.dot(self.W_out, next_hid)
            return (next_hid, y)'''
        def _dynamics(hid, u):
            next_hid = np.tanh(
                np.dot(self.W_h, hid) + np.dot(self.W_u, u) + self.b_h)
            y = np.dot(self.W_out, next_hid)
            return (next_hid, y)

        # self._step = jax.jit(_step)
        self._dynamics = jax.jit(_dynamics)

        self._loss = lambda x, u: (self.target - self._dynamics(x, u))**2

        # stack the jacobians of environment dynamics gradient
        jacobian = jax.jacrev(self._dynamics, argnums=(0, 1))
        self._dynamics_jacobian = jax.jit(
            lambda x, u: np.hstack(jacobian(x, u)))

        # stack the gradients of environment loss
        loss_grad = jax.grad(self._loss, argnums=(0, 1))
        self._loss_grad = jax.jit(lambda x, u: np.hstack(loss_grad(x, u)))

        # block the hessian of environment loss
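        # (jax.hessian with argnums=(0, 1) returns a nested structure A in
        #  which A[i][j] holds the second derivatives w.r.t. arguments i and
        #  j; block_hessian below stacks those four blocks into one matrix.)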
        block_hessian = lambda A: np.vstack(
            [np.hstack([A[0][0], A[0][1]]),
             np.hstack([A[1][0], A[1][1]])])
        hessian = jax.hessian(self._loss, argnums=(0, 1))
        self._loss_hessian = jax.jit(lambda x, u: block_hessian(hessian(x, u)))

        def _rollout(act, dyn, x_0, T):
            def f(x, i):
                u = act(x)
                x_next = dyn(x, u)
                return x_next, np.hstack((x, u))

            _, trajectory = jax.lax.scan(f, x_0, np.arange(T))
            return trajectory

        self._rollout = jax.jit(_rollout, static_argnums=(0, 1, 3))
        return np.dot(self.W_out, self.hid)
Example #14
 def step(i, opt_state):
     p = get_params(opt_state)
     g = grad(self.negative_log_evidence)(p)
     return opt_update(i, g, opt_state)
Example #15
def solve_sdp_dual_simple(verif_instance, key=None, opt=None, num_steps=10000,
                          eval_every=1000, verbose=False,
                          use_exact_eig_eval=True, use_exact_eig_train=False,
                          n_iter_lanczos=100,
                          kappa_reg_weight=None, kappa_zero_after=None,
                          device_type=None):
  """Compute verified lower bound via dual of SDP relaxation.

  Args:
    verif_instance: a utils.SdpDualVerifInstance
    key: jax.random.PRNGKey, used for Lanczos
    opt: an optax.GradientTransformation instance, the optimizer.
      If None, defaults to Adam with learning rate 1e-3.
    num_steps: int, the number of outer loop optimization steps
    eval_every: int, frequency of running evaluation step
    verbose: bool, enables verbose logging
    use_exact_eig_eval: bool, whether to use exact eigendecomposition instead of
      Lanczos when computing evaluation loss
    use_exact_eig_train: bool, whether to use exact eigendecomposition instead
      of Lanczos during training
    n_iter_lanczos: int, number of Lanczos iterations
    kappa_reg_weight: float, adds a penalty of sum(abs(kappa_{1:N})) to loss,
      which regularizes kappa_{1:N} towards zero. Default None is disabled.
    kappa_zero_after: int, clamps kappa_{1:N} to zero after ``kappa_zero_after``
      steps. Default None is disabled.
    device_type: string, used to clamp to a particular hardware device. Default
      None uses JAX default device placement

  Returns:
    A pair. The first element is a float, the final dual loss, which forms a
    valid upper bound on the objective specified by ``verif_instance``. The
    second element is a dict containing various debug info.
  """
  assert device_type in (None, 'cpu', 'gpu'), 'invalid device_type'
  assert isinstance(verif_instance, utils.SdpDualVerifInstance), 'invalid type'

  key = key if key is not None else jax.random.PRNGKey(0)
  opt = opt if opt is not None else optax.adam(1e-3)
  dual_vars = jax.tree_map(
      lambda s: None if s is None else jnp.zeros(s), verif_instance.dual_shapes)
  dual_vars = init_duals_ibp(verif_instance, dual_vars)

  # Define loss function
  def loss(dual_vars, exact=use_exact_eig_train):
    return _loss(dual_vars, exact)

  @functools.partial(jax.jit, static_argnums=(1,), backend=device_type)
  def _loss(dual_var, exact):
    loss_val, step_info = dual_fun(
        verif_instance, dual_var, key, n_iter=n_iter_lanczos, exact=exact,
        include_info=True)
    step_info['loss_val'] = loss_val
    return loss_val, step_info

  # Define a compiled update step
  grad = jax.jit(jax.grad(loss, has_aux=True), backend=device_type)

  @functools.partial(jax.jit, backend=device_type)
  def grad_step(params, opt_state):
    g, info = grad(params)
    updates, new_opt_state = opt.update(g, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, info

  # Optimize parameters in a loop
  opt_state = opt.init(dual_vars)
  info = collections.defaultdict(list)
  loss_log = []
  best_loss = 1e9

  # Main loop
  for i in range(num_steps):
    dual_vars, opt_state, step_info = grad_step(dual_vars, opt_state)
    loss_val = step_info['loss_val']
    print(f'Iter {i}: Loss {loss_val}')
    best_loss = min(best_loss, loss_val)
    loss_log.append(loss_val)

    # Regularization of kappa
    if kappa_reg_weight is not None and kappa_reg_weight >= 0:
      onehot = jax.nn.one_hot([0], dual_vars[-1].shape[1])
      mask = jnp.ones_like(onehot) - onehot
      dual_vars[-1] -= mask * kappa_reg_weight
    if (kappa_zero_after is not None and kappa_zero_after >= 0 and
        i > kappa_zero_after):
      onehot = jax.nn.one_hot([0], dual_vars[-1].shape[1])
      dual_vars[-1] *= onehot

    dual_vars = project_duals(dual_vars, verif_instance.dual_types)

    if i % eval_every == 0:
      dual_val, _ = loss(dual_vars, exact=use_exact_eig_eval)
      info['steps'].append(i)
      info['loss_vals'].append(float(dual_val))
      if verbose:
        print(f'Dual iter {i}: Train loss: {loss_val} Loss {dual_val}')

  final_loss = float(loss(dual_vars, exact=use_exact_eig_eval)[0])
  info['final_dual_vars'] = dual_vars
  info['final_loss'] = final_loss
  info['loss_log'] = loss_log
  info['best_train_loss'] = best_loss
  return final_loss, info
Example #16
 def test_grad_and_aux_basic(self):
     g, aux = grad(lambda x: (x**3, [x**2]), has_aux=True)(3.)
     self.assertAllClose(g, grad(lambda x: x**3)(3.), check_dtypes=True)
     self.assertAllClose(aux, [9.], check_dtypes=True)
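     # With has_aux=True, grad differentiates only the first element of the
     # returned pair; the second element ([x**2] here) is passed through
     # unchanged as the auxiliary output.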
Example #17
def solve_sdp_dual(verif_instance, key=None, opt=None, num_steps=10000,
                   verbose=False, eval_every=1000, use_exact_eig_eval=True,
                   use_exact_eig_train=False, n_iter_lanczos=30, scl=-1.0,
                   lr_init=1e-3, steps_per_anneal=100, anneal_factor=1.0,
                   num_anneals=3, opt_name='adam', gd_momentum=0.9,
                   add_diagnostic_stats=False,
                   opt_multiplier_fn=None, init_dual_vars=None,
                   init_opt_state=None, opt_dual_vars=None,
                   kappa_reg_weight=None, kappa_zero_after=None,
                   device_type=None, save_best_k=1):
  # pylint: disable=g-doc-return-or-yield, g-doc-args
  """Compute verified lower bound via dual of SDP relaxation.

  NOTE: This method exposes many hyperparameter options, and the method
  signature is subject to change. We suggest using ``solve_sdp_dual_simple``
  instead if you need a stable interface.
  """
  # NB: Whereas the rest of the code in this library is fairly top-down
  # readable, avoids excessive `if` statements, tries to make the code look
  # like the formalism, etc, this is not the case for this method.
  # This is essentially the outer loop, and includes all the debugging/logging/
  # optimization tricks we need to get/debug good results.
  #
  # NB: Time profiling: On toy VerifInstances, JIT compilation dominates time
  # cost: JIT compilation takes ~12s, then we do ~3000 steps/sec.
  assert device_type in (None, 'cpu', 'gpu'), 'invalid device_type'
  assert isinstance(verif_instance, utils.SdpDualVerifInstance), 'invalid type'
  key = key if key is not None else jax.random.PRNGKey(0)

  dual_vars = jax.tree_map(
      lambda s: None if s is None else jnp.zeros(s), verif_instance.dual_shapes)
  dual_vars = init_duals_ibp(verif_instance, dual_vars)

  if init_dual_vars is not None:
    # Casting, here for Colab. Essentially same as `dual_vars = init_dual_vars`
    dual_vars = utils.structure_like(init_dual_vars, dual_vars)
  if opt_dual_vars is not None:
    opt_dual_vars = utils.structure_like(opt_dual_vars, dual_vars)

  # Create optimizer
  if opt is None:
    if isinstance(steps_per_anneal, (float, int)):
      anneal_steps = [steps_per_anneal*(i+1) for i in range(num_anneals)]
    else:
      anneal_steps = np.cumsum(steps_per_anneal)
    anneal_steps = jnp.array(anneal_steps)
    def lr_schedule(t):
      cur_epoch = jnp.minimum(num_anneals, jnp.sum(t > anneal_steps))
      return lr_init * jnp.float_power(anneal_factor, cur_epoch)
    opt_class = getattr(optax, opt_name)
    base_opt = (opt_class(1., momentum=gd_momentum) if opt_name == 'sgd' else
                opt_class(1.))
    opt = optax.chain(base_opt, optax.scale_by_schedule(lr_schedule))
    if opt_multiplier_fn:
      # NB: Interface very specific to tree.map_structure_with_path
      # Example: opt_multiplier_fn=lambda path: 0.1 if 'lam' in path else 1.0
      opt_multipliers = tree.map_structure_with_path(
          lambda path, v: opt_multiplier_fn(path), dual_vars)
      opt = optax.chain(base_opt, optax.scale_by_schedule(lr_schedule),
                        utils.scale_by_variable_opt(opt_multipliers))
    else:
      opt = optax.chain(base_opt, optax.scale_by_schedule(lr_schedule))

  # Define loss function
  def loss(dual_vars, loss_scl=scl, exact=use_exact_eig_train):
    return _loss(dual_vars, loss_scl, exact)

  @functools.partial(jax.jit, static_argnums=(1, 2), backend=device_type)
  def _loss(dual_var, loss_scl, exact):
    loss_val, step_info = dual_fun(
        verif_instance, dual_var, key, n_iter=n_iter_lanczos, exact=exact,
        scl=loss_scl, include_info=True)
    step_info['loss_val'] = loss_val
    return loss_val, step_info

  # Define a compiled update step
  grad = jax.jit(jax.grad(loss, has_aux=True), backend=device_type)

  @functools.partial(jax.jit, backend=device_type)
  def grad_step(params, opt_state):
    g, info = grad(params)
    updates, new_opt_state = opt.update(g, opt_state)
    new_params = optax.apply_updates(params, updates)
    info['g'] = g
    info['updates'] = updates
    return new_params, new_opt_state, info

  # Optimize parameters in a loop
  opt_state = opt.init(dual_vars)
  if init_opt_state:
    opt_state = utils.structure_like(init_opt_state, opt_state)
  info = collections.defaultdict(list)
  loss_log = []
  store_best = []
  recent_eig_vecs = collections.deque(maxlen=10)
  best_loss = 1e9
  last_H = None
  start_i = 0

  # Main loop
  for i in range(start_i, num_steps):
    dual_vars_prev = dual_vars
    dual_vars, opt_state, step_info = grad_step(dual_vars, opt_state)
    loss_val = step_info['loss_val']
    print(f'Iter {i}: Loss {loss_val}')
    best_loss = min(best_loss, loss_val)
    if add_diagnostic_stats:
      info['dual_vars'].append(dual_vars_prev)
      eig_vec = step_info['eig_vec']
      cosine_sims = []
      for prev_eig_vec in recent_eig_vecs:
        denom = jnp.sqrt(jnp.linalg.norm(eig_vec)*jnp.linalg.norm(prev_eig_vec))
        eig_sim = jnp.sum(prev_eig_vec * eig_vec) / denom
        cosine_sims.append(abs(float(eig_sim)))
      info['c_lambda'].append(float(step_info['c_lambda']))
      info['past_10_cosine_sims'].append(np.array(cosine_sims))
      info['g'].append(step_info['g'])
      info['updates'].append(step_info['updates'])
      if use_exact_eig_train:
        # The info is for -H, so to get smallest for H, take negative of max
        eig_vals = -step_info['eig_info'][0][-1:-20:-1]
        cur_H = step_info['eig_info'][2]
        diff_H = 0 if last_H is None else np.linalg.norm(cur_H - last_H)
        last_H = cur_H
        info['diff_H'].append(float(diff_H))
        info['smallest_20_eig_vals'].append(eig_vals)
      recent_eig_vecs.appendleft(eig_vec)

    loss_log.append(loss_val)
    if len(store_best) < save_best_k:
      store_best.append((loss_val, dual_vars_prev))
      store_best.sort(key=lambda x: x[0])
    elif loss_val < store_best[-1][0]:
      store_best[-1] = (loss_val, dual_vars_prev)
      store_best.sort(key=lambda x: x[0])

    # Regularization of kappa
    if kappa_reg_weight is not None and kappa_reg_weight >= 0:
      onehot = jax.nn.one_hot([0], dual_vars[-1].shape[1])
      mask = jnp.ones_like(onehot) - onehot
      dual_vars[-1] -= mask * kappa_reg_weight
    if (kappa_zero_after is not None and kappa_zero_after >= 0 and
        i > kappa_zero_after):
      onehot = jax.nn.one_hot([0], dual_vars[-1].shape[1])
      dual_vars[-1] *= onehot

    dual_vars = project_duals(dual_vars, verif_instance.dual_types)

    if opt_dual_vars:
      distance_to_opt = jax.tree_multimap(lambda x, y: jnp.linalg.norm(x - y),
                                          dual_vars, opt_dual_vars)
      info['distance_to_opt'].append(distance_to_opt)

    if i % eval_every == 0:
      dual_val, _ = loss(dual_vars, loss_scl=-1, exact=use_exact_eig_eval)
      info['steps'].append(i)
      info['loss_vals'].append(float(dual_val))
      if verbose:
        print(f'Dual iter {i}: Train loss: {loss_val} Loss {dual_val}')

  final_loss = float(loss(dual_vars, loss_scl=-1, exact=use_exact_eig_eval)[0])
  info['final_dual_vars'] = dual_vars
  info['final_opt_state'] = opt_state
  info['final_loss'] = final_loss
  info['loss_log'] = loss_log
  info['store_best'] = store_best

  return final_loss, info
Example #18
 def f(x):
     g, aux = grad(lambda x: (x**3, [x**3]), has_aux=True)(x)
     return aux[0] * np.sin(x)
Example #19
def dual_fun(verif_instance, dual_vars, key=None, n_iter=30, scl=-1,
             exact=False, dynamic_unroll=True, include_info=False):
  # pylint: disable=invalid-name
  """Returns the dual objective value.

  Args:
    verif_instance: a utils.SdpDualVerifInstance, the verification problem
    dual_vars: A list of dual variables at each layer
    key: PRNGKey passed to Lanczos
    n_iter: Number of Lanczos iterations to use
    scl: Inverse temperature in softmax over eigenvalues to smooth optimization
        problem (if negative treat as hardmax)
    exact: Whether to use exact eigendecomposition instead of Lanczos
    dynamic_unroll: bool. Whether to use jax.fori_loop for Lanczos for faster
      JIT compilation. Default is True.
    include_info: if True, also return an `info` dict of various other
      values computed for the objective

  Returns:
    Either a single float, the dual upper bound, or if ``include_info=True``,
    returns a pair, the dual bound and a dict containing debugging info
  """
  key = key if key is not None else jax.random.PRNGKey(0)
  assert isinstance(verif_instance, utils.SdpDualVerifInstance)
  bounds = verif_instance.bounds
  layer_sizes = utils.layer_sizes_from_bounds(bounds)
  layer_sizes_1d = [np.prod(np.array(i)) for i in layer_sizes]
  N = sum(layer_sizes_1d) + 1
  info = {}

  # Mean activations at each layer
  activations_center = [(b.lb + b.ub) / 2 for b in bounds]
  # Maximum deviation from mean activations
  radius = [(b.ub - b.lb) / 2 for b in bounds]

  inner_lagrangian = verif_instance.make_inner_lagrangian(dual_vars)
  lagrangian = _make_transformed_lagrangian(
      inner_lagrangian, activations_center, radius)

  # Construct c_lambda and g_lambda terms
  zeros = [jnp.zeros(sz) for sz in layer_sizes]
  c_lambda = lagrangian(zeros)
  g_lambda = jax.grad(lagrangian)(zeros)
  g_lambda = flatten(g_lambda)
  info['c_lambda'] = c_lambda

  def Hv(v):
    """Hessian-vector product for H_lambda - refer to docstring for `Av()`."""
    lag_grad = lambda v2: flatten(jax.grad(lagrangian)(v2))
    hv_v = jax.grad(lambda v2: jnp.vdot(lag_grad(v2), v))(zeros)
    hv_flat = flatten(hv_v)
    return hv_flat

  def Av(v):
    """Matrix-vector product.

    Args:
      v: vector, DeviceArray

    Returns:
      Av: vector, Device array. A is defined as diag(kappa) - M(lambda) where
          M(lambda) = [0, g_lambda';
                       g_lambda, H_lambda], and these terms correspond to
          L~(z) = c_lambda + g_lambda' z + z' H_lambda z
    """
    # Expand Mv=[0 g'; g H] [v0;v1] = [g'v1; v0*g + H(v1)] = [Mv0;Mv1]
    # Compute Mv0 term
    mv_zero = jnp.reshape(jnp.vdot(g_lambda, v[1:]), (1,))
    # Compute Mv1 term
    mv_rest = Hv(v[1:]) + v[0] * g_lambda
    mv = jnp.concatenate([mv_zero, mv_rest], axis=0)
    diag_kappa_v = jnp.reshape(dual_vars[-1], mv.shape) * v
    av = diag_kappa_v - mv
    return jnp.reshape(av, v.shape)

  # Construct dual function (dual_vars[-1]=kappa)
  if exact:
    eig_vec, eig_info = eigenvector_utils.min_eigenvector_exact(
        Av, N, scl=scl, report_all=True)
    info['eig_info'] = eig_info
  else:
    eig_vec = eigenvector_utils.min_eigenvector_lanczos(
        Av, N, min(N, n_iter), key, scl, dynamic_unroll=dynamic_unroll)
  info['eig_vec'] = eig_vec
  info['kappa'] = dual_vars[-1]
  hess_val = jnp.vdot(eig_vec, Av(eig_vec))/(jnp.vdot(eig_vec, eig_vec))
  hess_val = jnp.reshape(hess_val, ())

  # Form dual objective
  lambda_minus = jnp.minimum(hess_val, 0.)
  kappa_hat = jnp.maximum(0, dual_vars[-1] - lambda_minus)
  dual_val = c_lambda + 0.5 * jnp.sum(kappa_hat)
  if include_info:
    return dual_val, info
  return dual_val
Example #20
 def update_derivative(i, opt_state, batch, l2reg):
     params = get_params(opt_state)
     return opt_update(i,
                       jax.grad(loss, 0)(params, batch, l2reg),
                       opt_state), params
Example #21
 def Hv(v):
   """Hessian-vector product for H_lambda - refer to docstring for `Av()`."""
   lag_grad = lambda v2: flatten(jax.grad(lagrangian)(v2))
   hv_v = jax.grad(lambda v2: jnp.vdot(lag_grad(v2), v))(zeros)
   hv_flat = flatten(hv_v)
   return hv_flat
Example #22
def gradients(scalar, variables):
    """Compute the gradients of a scalar w.r.t to a given list of variables.

    Arguments
    ---------
    scalar: :class:`symjax.tensor.base.Tensor`
        the variable to differentiate

    variables: List or Tuple
        the variables used to compute the derivative.

    Returns
    -------

        gradients: Tuple
            the sequence of gradients ordered as given in the input variables

    Example
    -------

    .. doctest::

        >>> import symjax
        >>> w = symjax.tensor.ones(3)
        >>> x = symjax.tensor.Variable(2., name='x', dtype='float32')
        >>> l = (w ** 2).sum() * x
        >>> g = symjax.gradients(l, [w])[0]
        >>> f = symjax.function(outputs=g, updates={x:x + 1})
        >>> for i in range(2):
        ...    print(f())
        [4. 4. 4.]
        [6. 6. 6.]

    """
    if numpy.prod(scalar.shape) != 1:
        raise RuntimeError("the variable to differentiate is not a scalar")
    if not isinstance(scalar, t.Tensor):
        raise RuntimeError(
            "the variable used in gradients should be a Tensor type")

    if scalar.shape != ():
        scalar = scalar.sum()
    if isinstance(variables, t.Tensor):
        input_variables = [variables]
        input_list = False
    else:
        input_variables = variables
        input_list = True

    # get the argnums of the variables that we differentiate with respect to
    argnums = list(range(len(input_variables)))

    # get all the roots of the scalar; this is needed because otherwise they
    # would not be inputs of the gradient function, and a change in their
    # value would not affect the gradient computation. We also ensure
    # uniqueness.
    input_variables += [
        i for i in current_graph().roots(scalar) if i not in input_variables
    ]

    # create a dummy function that jax needs in order to compute a gradient
    # function. This function builds the graph of computation from all roots
    # to the scalar variable so that automatic differentiation can be applied.

    def fn(*args):
        return current_graph().get(scalar,
                                   dict(zip(input_variables, list(args))))

    # now we obtain the grad function. JAX returns a function that, when
    # called, returns the gradient values; this function is then used to
    # generate the tuple of symbolic variables
    grad_fn = jax.grad(fn, argnums)
    wrap_fn = t.jax_wrap(grad_fn)
    if input_list:
        return wrap_fn(*input_variables)
    else:
        return wrap_fn(*input_variables)[0]
Example #23

def id_func(x):
    return lambda u: jnp.dot(jnp.identity(x.shape[0]), u)


def hvp(f, primals, tangents):
    return jvp(grad(f), primals, tangents)[1]
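# Forward-over-reverse Hessian-vector product: applying jvp to grad(f) avoids
# materializing the full Hessian. Hedged sanity check: for f(x) = 0.5 * x @ x
# the Hessian is the identity, so
#   hvp(lambda y: 0.5 * jnp.dot(y, y), (jnp.ones(3),), (jnp.arange(3.),))
# returns [0., 1., 2.].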


def breg_bound(vec, lb=-1.0, ub=1.0, *args, **kwargs):
    return jnp.sum((-vec + ub) * jnp.log(-vec + ub) +
                   (vec - lb) * jnp.log(vec - lb))


DP_bound = jax.grad(breg_bound, 0)


def DP_inv_bound(vec, lb=-1.0, ub=1.0):
    return (ub * jnp.exp(vec) + lb) / (1 + jnp.exp(vec))


def D2P_bound(vec, lb=-1.0, ub=1.0):
    def out(u):
        return jvp(lambda x: DP_bound(x, lb, ub), (vec, ), (u, ))[1]

    return out
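# (D2P_bound(vec) returns a Hessian-vector product u -> H(breg_bound)(vec) @ u,
#  computed forward-over-reverse by applying jvp to the gradient DP_bound.)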


def inv_D2P_bound(vec, lb=-1.0, ub=1.0):
    if len(jnp.shape(vec)) <= 1:
Example #24
    print("The input is", val)
    print("The result of Rosenbrock's is ", result)


fun_driver(10)
"""### 3. First look at derivatives: `jax.grad()`
1. https://jax.readthedocs.io/en/latest/jax.html#jax.grad
2. Computes $v\cdot J$ for a function that computes a scalar value ($R^n \rightarrow R$).
3. The seed is internally set to `1.0`
   (the shape of the seed must match the primal output).   
"""

#Create a function that computes the derivatives.
#This needs to happen only once.
grad_rosenbrock = jax.grad(rosenbrock)
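# Sketch of the "seed" remark above (assumes `rosenbrock` maps an array to a
# scalar): jax.grad(f)(x) equals the VJP of f at x applied to a seed of ones
# with the shape of the primal output.
_val = jnp.full(3, 0.5)
_primal, _rosen_vjp = jax.vjp(rosenbrock, _val)
(_g,) = _rosen_vjp(jnp.ones_like(_primal))  # matches grad_rosenbrock(_val)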


def grad_driver(n):
    """
    Input: n array length
    Output: Result of Rosenbrock's banana function
    """
    #create the input array
    val = jnp.full(n, 0.5)

    #compute the derivatives
    result = grad_rosenbrock(val)

    plot_vals(val, grad=result)
    print("The input is", val)
Example #25
def hvp(f, primals, tangents):
    return jvp(grad(f), primals, tangents)[1]
Example #26
import jax.random as random
key = random.PRNGKey(0)

if __name__ == '__main__':

    # 4. Minimize for a few steps
    initial_radii = np.ones(len(all_elements)) * 0.12
    initial_scales = np.ones(len(all_elements)) * 0.85

    theta0 = pack(initial_radii, initial_scales)
    print('initial theta', theta0)
    initial_log_prob = log_prob(theta0)
    print('initial log prob', initial_log_prob)

    # TODO: potentially jit() this
    grad_log_prob = grad(log_prob)
    #grad_log_prob = parallel_grad_log_prob

    print('initial gradient norm = {}'.format(
        np.linalg.norm(grad_log_prob(theta0))))

    minimized_theta_fname = os.path.join(
        data_path, 'elemental_types_l-bfgs_freesolv_{}.npy'.format(name))

    print('minimizing...')
    from scipy.optimize import minimize

    loss = lambda theta: -log_prob(theta)
    grad_loss = lambda theta: -grad_log_prob(theta)
    bounds = [(0.001, 2.0)] * len(theta0)
    result = minimize(loss,
Example #27
 def test_grad_tuple_output(self):
     jtu.check_raises(
         lambda: grad(lambda x: (x, x))(1.0), TypeError,
         "Gradient only defined for scalar-output functions. ")
Example #28
    def test_system_derivatives(self):
        sdf_file = open("examples/host-acd.mol2").read()
        smirnoff = ForceField("test_forcefields/smirnoff99Frosst.offxml")
        mol = Chem.MolFromMol2Block(sdf_file,
                                    sanitize=True,
                                    removeHs=False,
                                    cleanupSubstructures=True)

        guest_potentials, guest_params, guest_param_groups, guest_conf, guest_masses = forcefield.parameterize(
            mol, smirnoff)

        ref_nrgs = []
        test_nrgs = []

        for potential, params in guest_potentials:
            jax_potential = potential_map[potential]
            if potential == timemachine.lib.custom_ops.HarmonicBond_f64:
                jp = functools.partial(jax_potential,
                                       box=None,
                                       bond_idxs=params[0],
                                       param_idxs=params[1])
            elif potential == timemachine.lib.custom_ops.HarmonicAngle_f64:
                jp = functools.partial(jax_potential,
                                       box=None,
                                       angle_idxs=params[0],
                                       param_idxs=params[1])
            elif potential == timemachine.lib.custom_ops.PeriodicTorsion_f64:
                jp = functools.partial(jax_potential,
                                       box=None,
                                       torsion_idxs=params[0],
                                       param_idxs=params[1])
            elif potential == timemachine.lib.custom_ops.LennardJones_f64:
                jp = functools.partial(jax_potential,
                                       box=None,
                                       scale_matrix=params[0],
                                       param_idxs=params[1])
            elif potential == timemachine.lib.custom_ops.Electrostatics_f64:
                jp = functools.partial(jax_potential,
                                       box=None,
                                       scale_matrix=params[0],
                                       param_idxs=params[1])
            else:
                raise ValueError("unknown functional form")

            test_nrgs.append(potential(*params))
            ref_nrgs.append(jp)

        def ref_total_nrg(conf, params):
            nrgs = []
            for p in ref_nrgs:
                nrgs.append(p(conf, params))
            return jnp.sum(nrgs)

        dp_idxs = onp.arange(len(params)).astype(onp.int32)

        def test_total_nrg(conf, params):
            nrgs = []
            for p in test_nrgs:
                res = p.derivatives(onp.expand_dims(conf, axis=0), params,
                                    dp_idxs)
                nrgs.append(res[0])
            return onp.sum(nrgs)

        num_atoms = guest_conf.shape[0]

        ref_e = ref_total_nrg(guest_conf, guest_params)
        test_e = test_total_nrg(guest_conf, guest_params)

        onp.testing.assert_almost_equal(ref_e, test_e)

        dt = 1e-3
        ca = 0.5
        cb = onp.random.rand(num_atoms) / 10
        cc = onp.zeros(num_atoms)

        intg = ReferenceLangevin(dt, ca, cb, cc)

        ref_dE_dx_fn = jax.grad(ref_total_nrg, argnums=(0, ))
        ref_dE_dx_fn = jax.jit(ref_dE_dx_fn)

        def integrate(x_t, v_t, params):
            for _ in range(100):
                x_t, v_t = intg.step(x_t, v_t, ref_dE_dx_fn(x_t, params)[0])
            return x_t, v_t

        v0 = onp.random.rand(num_atoms * 3).reshape(num_atoms, 3)

        x_f, v_f = integrate(guest_conf, v0, guest_params)

        lo = custom_ops.LangevinOptimizer_f64(dt, ca, cb, cc)

        ctxt = custom_ops.Context_f64(test_nrgs, lo, guest_params, guest_conf,
                                      v0, dp_idxs)

        for _ in range(100):
            ctxt.step()

        onp.testing.assert_almost_equal(x_f, ctxt.get_x())
        onp.testing.assert_almost_equal(v_f, ctxt.get_v())

        grad_fn = jax.jacfwd(integrate, argnums=(2))
        dx_dp_f, dv_dp_f = grad_fn(guest_conf, v0, guest_params)
        dx_dp_f = onp.asarray(onp.transpose(dx_dp_f, (2, 0, 1)))
        dv_dp_f = onp.asarray(onp.transpose(dv_dp_f, (2, 0, 1)))

        onp.testing.assert_almost_equal(dx_dp_f[dp_idxs], ctxt.get_dx_dp())
        onp.testing.assert_almost_equal(dv_dp_f[dp_idxs], ctxt.get_dv_dp())
Example #29
 def test_grad_nonscalar_output(self):
     jtu.check_raises(
         lambda: grad(lambda x: x)(onp.zeros(3)), TypeError,
         "Gradient only defined for scalar-output functions. ")
Example #30
import jax.numpy as np
from jax import grad, jit, vmap
from functools import partial

def predict(params, inputs):
  for W, b in params:
    outputs = np.dot(inputs, W) + b
    inputs = np.tanh(outputs)
  return outputs

def logprob_fun(params, inputs, targets):
  preds = predict(params, inputs)
  return np.sum((preds - targets)**2)

grad_fun = jit(grad(logprob_fun))  # compiled gradient evaluation function
perex_grads = jit(lambda params, inputs, targets:  # fast per-example gradients
                  vmap(partial(grad_fun, params))(inputs, targets))
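# Equivalent vmap formulation (a hedged alternative): map over the example
# axis of inputs/targets while broadcasting params:
#   perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0)))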
Example #31
 def test_holomorphic_grad(self):
     out = grad(lambda x: np.sin(x), holomorphic=True)(1 + 2j)
     expected = 2.0327230070196656 - 3.0518977991518j
     self.assertAllClose(out, expected, check_dtypes=False)
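     # The holomorphic derivative of sin is cos, and cos(1 + 2j) is
     # approximately 2.0327 - 3.0519j, matching `expected` above.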