Example #1
def _solve(a, b, sym_pos, lower):
    if not sym_pos:
        return np_linalg.solve(a, b)

    a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
    lax_linalg._check_solve_shapes(a, b)

    # With custom_linear_solve, we can reuse the same factorization when
    # computing sensitivities. This is considerably faster.
    factors = cho_factor(lax.stop_gradient(a), lower=lower)
    custom_solve = partial(lax.custom_linear_solve,
                           lambda x: lax_linalg._matvec_multiply(a, x),
                           solve=lambda _, x: cho_solve(factors, x),
                           symmetric=True)
    if a.ndim == b.ndim + 1:
        # b.shape == [..., m]
        return custom_solve(b)
    else:
        # b.shape == [..., m, k]
        return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
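A hedged usage sketch (not part of the original snippet): differentiating a scalar function of the solution, so the backward pass reuses the Cholesky factorization cached via custom_linear_solve. The matrix, right-hand side, and the assumption that _solve and its helpers are importable are all illustrative.

import jax
import jax.numpy as jnp

a = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])            # symmetric positive-definite (illustrative)
b = jnp.array([1.0, 2.0])

def objective(a):
    x = _solve(a, b, sym_pos=True, lower=True)
    return jnp.sum(x ** 2)

grad_a = jax.grad(objective)(a)        # sensitivities reuse the cached factors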
Example #2
def unsigned_int_scale(x, *, prec, axis=None):
    """Computes a scale s, s.t.

  0 <= s * x <= 2**prec, where min(x) >= 0.
  Does not propagate gradients.

  Args:
    x: The input to be scaled.
    prec: Unsigned int precision of the scaled result.
    axis: Dimensions of input to consider for scaling.

  Returns:
    The scaling value.
  """
    max_x = jnp.max(x, axis=axis, keepdims=True)
    if not DISABLE_EPSILON_IN_SCALE_FUN_FOR_TESTING:
        max_x += jnp.finfo(jnp.float32).eps  # to avoid div by 0

    scale = unsigned_int_bound(prec) / max_x
    scale = lax.stop_gradient(scale)
    return scale
Example #3
def signed_int_scale(x, *, prec, axis=None):
    """Computes a scale s, s.t.

  -2**(prec - 1) + 1 <= s * x <= 2**(prec - 1) - 1.

  Does not propagate gradients.

  Args:
    x: The input to be scaled.
    prec: Signed int precision of the scaled result.
    axis: Dimensions of input to consider for scaling.

  Returns:
    The scaling value.
  """
    abs_max_x = max_abs_weights(x, axis=axis)
    if not DISABLE_EPSILON_IN_SCALE_FUN_FOR_TESTING:
        abs_max_x += jnp.finfo(jnp.float32).eps  # to avoid div by 0
    scale = signed_int_bound(prec) / abs_max_x
    scale = lax.stop_gradient(scale)
    return scale
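A hedged sketch of how such a scale is typically consumed for fake quantization. The helper below is hypothetical and assumes signed_int_scale behaves as documented above; a real quantizer would usually also wrap the rounding in a straight-through estimator.

import jax.numpy as jnp

def fake_quant_signed(x, *, prec, axis=None):
    scale = signed_int_scale(x, prec=prec, axis=axis)  # gradients stopped inside
    q = jnp.round(x * scale)   # integers in [-2**(prec - 1) + 1, 2**(prec - 1) - 1]
    return q / scale           # dequantize back to floating point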
Example #4
def SBL(
        X,
        y,
        prior_init=None,
        hyper_prior=((1e-6, 1e-6), (1e-6, 1e-6)),
        tol=1e-5,
        max_iter=1000,
):

    n_samples, n_features = X.shape
    if prior_init is None:
        prior_init = jnp.ones((n_features + 1, ))
        # jax.ops.index_update has been removed from JAX; .at[...].set is the current equivalent.
        prior_init = prior_init.at[n_features].set(
            1 / (jnp.var(y) + 1e-6))  # setting initial noise value

    # Adding term for gradient of loss
    prior_init = jnp.concatenate(
        [prior_init, jnp.ones((n_features, ))], axis=0)

    gram = jnp.dot(X.T, X)
    XT_y = jnp.dot(X.T, y)

    prior_params, iterations = fixed_point_solver(
        update,
        (X, y, gram, XT_y, hyper_prior),
        prior_init,
        lambda z_prev, z: jnp.linalg.norm(z_prev[-n_features:] - z[-n_features:
                                                                   ]) > tol,
        max_iter=max_iter,
    )

    prior = stop_gradient(prior_params)
    loss, mn = evidence(X, y, gram, XT_y, prior[:-n_features], hyper_prior)
    metrics = (
        iterations,
        jnp.linalg.norm(mn.squeeze() - prior[-n_features:]),
        prior[-n_features:],
    )
    return loss, mn, prior[:-n_features], metrics
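The fixed_point_solver helper is not shown in this snippet. Purely as an illustrative sketch, and under the assumption that update(args, z) produces the next iterate and the predicate returns True while iteration should continue, it could look like:

from jax import lax

def fixed_point_solver(update, args, z_init, cond, max_iter=1000):
    # Iterate z <- update(args, z) until cond(z_prev, z) is False or max_iter is reached.
    def keep_going(carry):
        it, z_prev, z = carry
        return (it < max_iter) & cond(z_prev, z)

    def body(carry):
        it, _, z = carry
        return it + 1, z, update(args, z)

    it, _, z = lax.while_loop(keep_going, body, (0, z_init, update(args, z_init)))
    return z, it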
Example #5
  def _compute_loss_and_stats(params, model_out, use_elastic_loss=False):
    rgb_loss = ((model_out['rgb'] - batch['rgb'][..., :3])**2).mean()
    stats = {
        'loss/rgb': rgb_loss,
    }
    loss = rgb_loss
    if use_elastic_loss:
      v_elastic_fn = jax.jit(vmap(vmap(compute_elastic_loss)))
      weights = lax.stop_gradient(model_out['weights'])
      jacobian = model_out['warp_jacobian']
      # Pick the median point Jacobian.
      if elastic_reduce_method == 'median':
        depth_indices = model_utils.compute_depth_index(weights)
        jacobian = jnp.take_along_axis(
            # Unsqueeze axes: sample axis, Jacobian row, Jacobian col.
            jacobian, depth_indices[..., None, None, None], axis=-3)
      # Compute loss using Jacobian.
      elastic_loss, elastic_residual = v_elastic_fn(jacobian)
      # Multiply weight if weighting by density.
      if elastic_reduce_method == 'weight':
        elastic_loss = weights * elastic_loss
      elastic_loss = elastic_loss.sum(axis=-1).mean()
      stats['loss/elastic'] = elastic_loss
      stats['residual/elastic'] = jnp.mean(elastic_residual)
      loss += scalar_params.elastic_loss_weight * elastic_loss

    if 'warp_jacobian' in model_out:
      jacobian = model_out['warp_jacobian']
      jacobian_det = jnp.linalg.det(jacobian)
      jacobian_div = utils.jacobian_to_div(jacobian)
      jacobian_curl = utils.jacobian_to_curl(jacobian)
      stats['metric/jacobian_det'] = jnp.mean(jacobian_det)
      stats['metric/jacobian_div'] = jnp.mean(jacobian_div)
      stats['metric/jacobian_curl'] = jnp.mean(
          jnp.linalg.norm(jacobian_curl, axis=-1))

    stats['loss/total'] = loss
    stats['metric/psnr'] = utils.compute_psnr(rgb_loss)
    return loss, stats
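For reference, utils.compute_psnr above is applied to the MSE value. A common definition (a sketch only; the project's helper may differ) is:

import jax.numpy as jnp

def compute_psnr(mse):
    # Peak signal-to-noise ratio for signals in [0, 1]: -10 * log10(MSE).
    return -10.0 * jnp.log10(mse)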
Example #6
def loss_fn_pinn_bayes_mse_hyperprior(params,
                                      state,
                                      model,
                                      x,
                                      y,
                                      prior_params_mse=(0.0, 0.0)):

    variables = {"params": params, **state}
    (prediction, dt, theta,
     coeffs), updated_state = model.apply(variables,
                                          x,
                                          mutable=list(state.keys()))
    n_samples, n_features = theta.shape
    # Calculating precision of mse
    tau = precision(y, prediction, *prior_params_mse)
    p_mse, MSE = normal_LL(prediction, y, tau)
    p_mse += gamma_LL(tau, *prior_params_mse)  # adding prior

    # Calculating precision of reg

    hyper_prior_nu = (n_samples / 2, n_samples / stop_gradient(tau))
    nu = precision(dt, theta @ coeffs, *hyper_prior_nu)
    # calculates nu given gamma prior
    p_reg, reg = normal_LL(dt, theta @ coeffs, nu)
    p_reg += gamma_LL(nu, *hyper_prior_nu)  # adding prior

    loss = -(p_mse + p_reg)

    metrics = {
        "loss": loss,
        "p_mse": p_mse,
        "mse": MSE,
        "p_reg": p_reg,
        "reg": reg,
        "coeff": coeffs,
        "tau": tau,
        "nu": nu,
    }
    return loss, (updated_state, metrics, (prediction, dt, theta, coeffs))
Example #7
    def create_positive(cls, *, bounds, prec):
        """Create QuantOps for positive activations clipped to [0, bounds].

    Args:
      bounds: The upper bound to clip the activations.
      prec: Unsigned int precision for the QuantOps.

    Returns:
      QuantOps for quantizing/dequantizing unsigned activations.
    """
        initial_bounds = bounds
        bounds = jnp.asarray(bounds, SCALE_DTYPE)
        if not DISABLE_EPSILON_IN_SCALE_FUN_FOR_TESTING:
            bounds += jnp.finfo(SCALE_DTYPE).eps  # to avoid div by 0
        scale = primitives.unsigned_int_bound(prec=prec) / bounds
        # NOTE: stop_gradient is needed here to prevent gradient flow through scale
        # when scale is not a constant, but computed as a function of activations.
        scale = lax.stop_gradient(scale)
        return cls(prec=prec,
                   scale=scale,
                   symmetric=False,
                   bounds=initial_bounds)
Example #8
def _solve(a, b):
  _check_solve_shapes(a, b)

  # Broadcast leading dimensions of b to the shape of a, as is required by
  # custom_linear_solve.
  out_shape = tuple(d_a if d_b == 1 else d_b
                    for d_a, d_b in zip(a.shape[:-1] + (1,), b.shape))
  b = jnp.broadcast_to(b, out_shape)

  # With custom_linear_solve, we can reuse the same factorization when
  # computing sensitivities. This is considerably faster.
  lu_, _, permutation = lu(lax.stop_gradient(a))
  custom_solve = partial(
      lax.custom_linear_solve,
      lambda x: _matvec_multiply(a, x),
      solve=lambda _, x: lu_solve(lu_, permutation, x, trans=0),
      transpose_solve=lambda _, x: lu_solve(lu_, permutation, x, trans=1))
  if a.ndim == b.ndim + 1:
    # b.shape == [..., m]
    return custom_solve(b)
  else:
    # b.shape == [..., m, k]
    return api.vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
Example #9
        def elbo_samples(samples, stop_grads=False):
            log_prior = np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
                samples, 0.5 * np.ones(
                    (batch.shape[0], posterior_params.shape[-1]))),
                               axis=(1, 2))  # SxBxD

            if stop_grads:
                log_posterior = np.sum(vmap(
                    bernoulli.logpmf,
                    in_axes=(0, None))(samples,
                                       lax.stop_gradient(posterior_params)),
                                       axis=(1, 2))
            else:
                log_posterior = np.sum(vmap(bernoulli.logpmf,
                                            in_axes=(0,
                                                     None))(samples,
                                                            posterior_params),
                                       axis=(1, 2))
            log_likelihood = vmap(bernoulli_log_likelihood,
                                  in_axes=(None, 0, None))(decoder_params,
                                                           samples, batch)
            elbo_samples = log_likelihood - log_posterior + log_prior
            return elbo_samples
Example #10
def Allreduce(x, op, comm=_MPI.COMM_WORLD, token=None):
    """
    Performs the Allreduce operation `op` on the input `x` using the
    communicator `comm` which defaults to the world communicator.
    An optional token can be passed, which is used to force jax to execute
    MPI operations in the correct order.

    Arguments:
        x: Array or scalar input.
        op: The reduction operation `MPI.Op` (e.g: MPI.SUM)
        comm: The communicator (defaults to MPI.COMM_WORLD)
        token: token to force a sequential order in the operations (default=None)

    Returns:
        res: result of the allreduce operation
        new_token: a new, modified token that depends on this operation.
            This can be ignored if the result itself already enforces a data dependency.
    """
    if token is None:
        token = create_token(stop_gradient(x))

    op = wrap_as_hashable(op)
    comm = wrap_as_hashable(comm)
    return mpi_allreduce_p.bind(x, token, op=op, comm=comm)
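A hedged usage sketch of the token-chaining pattern described in the docstring. The helper below is hypothetical; it simply threads the token through two reductions so they execute in a fixed order.

from mpi4py import MPI

def global_sum_and_max(x, token=None):
    total, token = Allreduce(x, op=MPI.SUM, token=token)
    largest, token = Allreduce(x, op=MPI.MAX, token=token)
    return total, largest, token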
Example #11
    def create_symmetric_fp(
        cls,
        *,
        bounds,
        fp_quant,
    ):
        """Create QuantOps for symmetric clipping to floating-point bounds.

    Args:
      bounds: The upper (and absolute lower) bound to clip the inputs.
      fp_quant: quantization floating-point specification of the target format.

    Returns:
      QuantOps for quantizing/dequantizing signed activations.
    """
        if bounds is None:
            if fp_quant.is_scaled:
                raise ValueError(
                    'bounds can only be None if fp_quant.is_scaled is False.')
            return cls(prec=fp_quant, scale=None, symmetric=True, bounds=None)
        else:
            initial_bounds = bounds
            bounds = jnp.asarray(bounds, SCALE_DTYPE)
            if not DISABLE_EPSILON_IN_SCALE_FUN_FOR_TESTING:
                bounds += jnp.finfo(SCALE_DTYPE).eps  # to avoid log2(0)
            scale = jnp.exp2(
                -jnp.floor(jnp.log2(bounds)))  # Scale to unit binade.
            # NOTE: stop_gradient is needed here to prevent gradient flow through
            # scale when scale is not a constant, but computed as a function of
            # activations or weights.
            scale = lax.stop_gradient(scale)

            return cls(prec=fp_quant,
                       scale=scale,
                       symmetric=True,
                       bounds=initial_bounds)
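A small numeric illustration (not from the source) of the scale-to-unit-binade step: with bounds = 6.0, the scale is 2**(-floor(log2(6))) = 0.25, so scale * bounds = 1.5 lands in [1, 2).

import jax.numpy as jnp

bounds = jnp.asarray(6.0)
scale = jnp.exp2(-jnp.floor(jnp.log2(bounds)))
print(scale)           # 0.25
print(scale * bounds)  # 1.5, i.e. inside the unit binade [1, 2)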
Example #12
    def score_function_objective(encoder_params,
                                 decoder_params,
                                 batch,
                                 prng_key,
                                 num_samples=1):
        """
        Computes the score function objective of a discrete VAE.
        The gradient of this objective matches the score-function gradient of the ELBO.
        :param encoder_params: encoder parameters (list)
        :param decoder_params: decoder parameters (list)
        :param batch: batch of data (jax.numpy array)
        :param prng_key: PRNG key
        :param num_samples: number of samples
        """

        encoder1, encoder2 = encoder
        encoder_params1, encoder_params2 = encoder_params
        decoder1, decoder2 = decoder
        decoder_params1, decoder_params2 = decoder_params

        first_layer_key, second_layer_key = random.split(prng_key)

        # Outer layer latents
        first_layer_params = encoder1(encoder_params1, batch)  # BxD
        first_layer_samples = bernoulli.sample(first_layer_params,
                                               first_layer_key, num_samples)

        # Inner layer latents
        second_layer_params = encoder2(encoder_params2, first_layer_samples)
        second_layer_samples = bernoulli.sample(second_layer_params,
                                                second_layer_key,
                                                num_samples=1)[0, ...]

        # Inner layer prior
        log_prior = np.sum(bernoulli.logpmf(
            second_layer_samples, decoder2(decoder_params2,
                                           first_layer_samples)),
                           axis=(1, 2))  # SxBxD
        # Outer layer prior
        log_prior += np.sum(vmap(bernoulli.logpmf,
                                 in_axes=(0,
                                          None))(first_layer_samples,
                                                 0.5 * np.ones(
                                                     (batch.shape[0], 200))),
                            axis=(1, 2))  # SxBxD

        # Inner layer posterior
        log_posterior = np.sum(bernoulli.logpmf(second_layer_samples,
                                                second_layer_params),
                               axis=(1, 2))
        # Outer layer posterior
        log_posterior += np.sum(vmap(bernoulli.logpmf,
                                     in_axes=(0, None))(first_layer_samples,
                                                        first_layer_params),
                                axis=(1, 2))

        # Likelihood
        log_likelihood = vmap(bernoulli_log_likelihood,
                              in_axes=(None, 0, None))(decoder_params[0],
                                                       first_layer_samples,
                                                       batch)

        elbo_samples = log_likelihood - log_posterior + log_prior
        elbo_samples = lax.stop_gradient(elbo_samples) * log_posterior
        return -np.mean(elbo_samples, axis=0) / batch.shape[0]
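The lax.stop_gradient(elbo_samples) * log_posterior term is the score-function (REINFORCE) surrogate: the ELBO value is held constant while the gradient flows through the log-probability, whose derivative is the score. A minimal illustration of the same pattern on a toy Bernoulli expectation (illustrative only, not from the source):

import jax
import jax.numpy as jnp
from jax import lax, random

def surrogate_loss(logit, key, num_samples=1000):
    p = jax.nn.sigmoid(logit)
    z = random.bernoulli(key, p, (num_samples,)).astype(jnp.float32)
    f = (z - 0.3) ** 2                                   # arbitrary downstream objective
    log_prob = z * jnp.log(p) + (1 - z) * jnp.log(1 - p)
    # grad of E[f(z)] is estimated by weighting the score with stop_gradient(f).
    return jnp.mean(lax.stop_gradient(f) * log_prob)

grad_estimate = jax.grad(surrogate_loss)(0.2, random.PRNGKey(0))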
Example #13
    def relax_plus_rebar_objective(encoder_params,
                                   decoder_params,
                                   surrogate_params,
                                   log_temperature,
                                   batch,
                                   prng_key,
                                   num_samples=1):
        """
        Computes the REBAR objective function of a discrete VAE.

        :param encoder_params: encoder parameters (list)
        :param decoder_params: decoder parameters (list)
        :param surrogate_params: surrogate parameters (list)
        :param log_temperature: log of inverse temperature
        :param batch: batch of data (jax.numpy array)
        :param prng_key: PRNG key
        :param num_samples: number of samples
        """
        temperature = np.exp(log_temperature)

        def concrete(x):
            return jax.nn.sigmoid(x * temperature)

        encoder1, encoder2 = encoder
        encoder_params1, encoder_params2 = encoder_params
        decoder1, decoder2 = decoder
        decoder_params1, decoder_params2 = decoder_params

        # Sampling
        sampling_key, control_variate_key = random.split(prng_key)
        first_layer_sampling_key, second_layer_sampling_key = random.split(
            sampling_key)
        first_layer_control_variate_key, second_layer_control_variate_key = random.split(
            control_variate_key)

        # Outer Layer
        first_layer_params = encoder1(encoder_params1, batch)  # BxD
        first_layer_relaxed_samples = binary_relaxed.sample(
            first_layer_params, first_layer_sampling_key, num_samples)
        first_layer_posterior_samples = np.heaviside(
            first_layer_relaxed_samples, 0)
        first_layer_conditional_relaxed_samples = binary_relaxed.conditional_sample(
            first_layer_params, first_layer_posterior_samples,
            first_layer_control_variate_key)

        # Inner Layer
        second_layer_params = encoder2(encoder_params2,
                                       first_layer_posterior_samples)
        second_layer_relaxed_samples = binary_relaxed.sample(
            second_layer_params, second_layer_sampling_key, num_samples=1)[0,
                                                                           ...]
        second_layer_posterior_samples = np.heaviside(
            second_layer_relaxed_samples, 0)
        second_layer_conditional_relaxed_samples = binary_relaxed.conditional_sample(
            second_layer_params, second_layer_posterior_samples,
            second_layer_control_variate_key)

        # Inner layer posterior
        log_posterior = np.sum(bernoulli.logpmf(second_layer_posterior_samples,
                                                second_layer_params),
                               axis=(1, 2))
        # Outer layer posterior
        log_posterior += np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
            first_layer_posterior_samples, first_layer_params),
                                axis=(1, 2))

        def elbo_samples(first_layer_samples,
                         second_layer_samples,
                         stop_grads=False):
            # Inner layer prior
            log_prior = np.sum(bernoulli.logpmf(
                second_layer_samples,
                decoder2(decoder_params2, first_layer_samples)),
                               axis=(1, 2))  # SxBxD
            # Outer layer prior
            log_prior += np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
                first_layer_samples, 0.5 * np.ones((batch.shape[0], 200))),
                                axis=(1, 2))  # SxBxD

            if stop_grads:
                # Inner layer posterior
                log_posterior = np.sum(bernoulli.logpmf(
                    second_layer_samples,
                    lax.stop_gradient(second_layer_params)),
                                       axis=(1, 2))
                # Outer layer posterior
                log_posterior += np.sum(vmap(
                    bernoulli.logpmf,
                    in_axes=(0, None))(first_layer_samples,
                                       lax.stop_gradient(first_layer_params)),
                                        axis=(1, 2))
            else:
                # Inner layer posterior
                log_posterior = np.sum(bernoulli.logpmf(
                    second_layer_samples, second_layer_params),
                                       axis=(1, 2))
                # Outer layer posterior
                log_posterior += np.sum(vmap(
                    bernoulli.logpmf, in_axes=(0, None))(first_layer_samples,
                                                         first_layer_params),
                                        axis=(1, 2))

            # Likelihood
            log_likelihood = vmap(bernoulli_log_likelihood,
                                  in_axes=(None, 0, None))(decoder_params1,
                                                           first_layer_samples,
                                                           batch)

            elbo_samples = log_likelihood - log_posterior + log_prior
            return elbo_samples

        elbo_evaluation = elbo_samples(first_layer_posterior_samples,
                                       second_layer_posterior_samples)
        unconditional_evaluation = elbo_samples(
            concrete(first_layer_relaxed_samples),
            concrete(second_layer_relaxed_samples))
        conditional_evaluation = elbo_samples(
            concrete(first_layer_conditional_relaxed_samples),
            concrete(second_layer_conditional_relaxed_samples))
        frozen_conditional_evaluation = elbo_samples(
            concrete(
                lax.stop_gradient(first_layer_conditional_relaxed_samples)),
            concrete(
                lax.stop_gradient(second_layer_conditional_relaxed_samples)),
            stop_grads=True)

        relaxed_surrogate_inputs = np.concatenate(
            [first_layer_relaxed_samples, second_layer_relaxed_samples],
            axis=-1)

        conditional_relaxed_surrogate_inputs = np.concatenate([
            first_layer_conditional_relaxed_samples,
            second_layer_conditional_relaxed_samples
        ],
                                                              axis=-1)

        obj = (lax.stop_gradient(elbo_evaluation) -
               frozen_conditional_evaluation -
               np.sum(surrogate(
                   surrogate_params,
                   lax.stop_gradient(conditional_relaxed_surrogate_inputs)),
                      axis=1).squeeze()) * log_posterior

        obj += unconditional_evaluation + np.sum(surrogate(
            surrogate_params, relaxed_surrogate_inputs),
                                                 axis=1).squeeze()
        obj -= conditional_evaluation + np.sum(surrogate(
            surrogate_params, conditional_relaxed_surrogate_inputs),
                                               axis=1).squeeze()

        return -np.mean(obj, axis=0) / batch.shape[0]
Example #14
    def arm_objective(encoder_params,
                      decoder_params,
                      batch,
                      prng_key,
                      num_samples=1):
        """
        Computes the ARM objective of a discrete VAE.
        The gradient of this objective matches the ARM estimate of the ELBO gradient.
        :param encoder_params: encoder parameters (list)
        :param decoder_params: decoder parameters (list)
        :param batch: batch of data (jax.numpy array)
        :param prng_key: PRNG key
        :param num_samples: number of samples
        """
        encoder1, encoder2 = encoder
        encoder_params1, encoder_params2 = encoder_params
        decoder1, decoder2 = decoder
        decoder_params1, decoder_params2 = decoder_params

        bernoulli_key, uniform_key1, uniform_key2 = random.split(prng_key,
                                                                 num=3)

        first_layer_params = encoder1(encoder_params1, batch)  # BxD
        first_layer_bernoulli_samples = bernoulli.sample(
            first_layer_params, bernoulli_key, num_samples)
        second_layer_params = encoder2(encoder_params2,
                                       first_layer_bernoulli_samples)

        first_layer_uniform_samples = random.uniform(
            key=uniform_key1, shape=(num_samples, *first_layer_params.shape))
        first_layer_antithetic_samples = 1 - first_layer_uniform_samples

        second_layer_uniform_samples = random.uniform(
            key=uniform_key2, shape=second_layer_params.shape)
        second_layer_antithetic_samples = 1 - second_layer_uniform_samples

        first_layer_uniform_condition = first_layer_uniform_samples <= first_layer_params
        first_layer_antithetic_condition = first_layer_antithetic_samples <= first_layer_params

        second_layer_uniform_condition = second_layer_uniform_samples <= second_layer_params
        second_layer_antithetic_condition = second_layer_antithetic_samples <= second_layer_params

        def elbo_samples(first_layer_samples, second_layer_samples):
            log_prior = np.sum(bernoulli.logpmf(
                second_layer_samples,
                decoder2(decoder_params2, first_layer_bernoulli_samples)),
                               axis=(1, 2))  # SxBxD
            log_prior += np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
                first_layer_samples, 0.5 * np.ones((batch.shape[0], 200))),
                                axis=(1, 2))  # SxBxD

            log_posterior = np.sum(bernoulli.logpmf(second_layer_samples,
                                                    second_layer_params),
                                   axis=(1, 2))
            log_posterior += np.sum(vmap(bernoulli.logpmf,
                                         in_axes=(0,
                                                  None))(first_layer_samples,
                                                         first_layer_params),
                                    axis=(1, 2))

            log_likelihood = vmap(bernoulli_log_likelihood,
                                  in_axes=(None, 0, None))(decoder_params1,
                                                           first_layer_samples,
                                                           batch)

            elbo = log_likelihood - log_posterior + log_prior
            return elbo

        sample_elbo = elbo_samples(first_layer_uniform_condition,
                                   second_layer_uniform_condition)
        antithetic_sample_elbo = elbo_samples(
            first_layer_antithetic_condition,
            second_layer_antithetic_condition)

        loss = lax.stop_gradient(antithetic_sample_elbo - sample_elbo)
        loss = loss * (np.sum(first_layer_params *
                              (first_layer_uniform_samples - 0.5)) +
                       np.sum(second_layer_params *
                              (second_layer_uniform_samples - 0.5)))

        return -np.mean(loss, axis=0) / batch.shape[0]
Example #15
    def tunable_rebar_objective(encoder_params,
                                decoder_params,
                                batch,
                                prng_key,
                                cv_coeff=1.0,
                                log_temperature=0.5,
                                num_samples=1):
        """
        Computes the REBAR objective function of a discrete VAE.

        :param encoder_params: encoder parameters (list)
        :param decoder_params: decoder parameters (list)
        :param batch: batch of data (jax.numpy array)
        :param prng_key: PRNG key
        :param cv_coeff: control variate coefficient
        :param log_temperature: log_temperature parameter for rebar control variate.
        :param num_samples: number of samples
        """
        temperature = np.exp(log_temperature)
        encoder1, encoder2 = encoder
        encoder_params1, encoder_params2 = encoder_params
        decoder1, decoder2 = decoder
        decoder_params1, decoder_params2 = decoder_params

        # Sampling
        sampling_key, control_variate_key = random.split(prng_key)
        first_layer_sampling_key, second_layer_sampling_key = random.split(
            sampling_key)
        first_layer_control_variate_key, second_layer_control_variate_key = random.split(
            control_variate_key)

        # Outer Layer
        first_layer_params = encoder1(encoder_params1, batch)  # BxD
        first_layer_relaxed_samples = binary_relaxed.sample(
            first_layer_params, first_layer_sampling_key, num_samples)
        first_layer_posterior_samples = np.heaviside(
            first_layer_relaxed_samples, 0)
        first_layer_concrete_samples = jax.nn.sigmoid(
            first_layer_relaxed_samples * temperature)
        first_layer_conditional_relaxed_samples = binary_relaxed.conditional_sample(
            first_layer_params, first_layer_posterior_samples,
            first_layer_control_variate_key)
        first_layer_cv_samples = jax.nn.sigmoid(
            first_layer_conditional_relaxed_samples * temperature)
        first_layer_frozen_cv_samples = jax.nn.sigmoid(
            lax.stop_gradient(first_layer_conditional_relaxed_samples) *
            temperature)

        # Inner Layer
        second_layer_params = encoder2(encoder_params2,
                                       first_layer_posterior_samples)
        second_layer_relaxed_samples = binary_relaxed.sample(
            second_layer_params, second_layer_sampling_key, num_samples=1)[0,
                                                                           ...]
        second_layer_posterior_samples = np.heaviside(
            second_layer_relaxed_samples, 0)
        second_layer_concrete_samples = jax.nn.sigmoid(
            second_layer_relaxed_samples * temperature)
        second_layer_conditional_relaxed_samples = binary_relaxed.conditional_sample(
            second_layer_params, second_layer_posterior_samples,
            second_layer_control_variate_key)
        second_layer_cv_samples = jax.nn.sigmoid(
            second_layer_conditional_relaxed_samples * temperature)
        second_layer_frozen_cv_samples = jax.nn.sigmoid(
            lax.stop_gradient(second_layer_conditional_relaxed_samples) *
            temperature)

        # Inner layer posterior
        log_posterior = np.sum(bernoulli.logpmf(second_layer_posterior_samples,
                                                second_layer_params),
                               axis=(1, 2))
        # Outer layer posterior
        log_posterior += np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
            first_layer_posterior_samples, first_layer_params),
                                axis=(1, 2))

        def elbo_samples(first_layer_samples,
                         second_layer_samples,
                         stop_grads=False):
            # Inner layer prior
            log_prior = np.sum(bernoulli.logpmf(
                second_layer_samples,
                decoder2(decoder_params2, first_layer_samples)),
                               axis=(1, 2))  # SxBxD
            # Outer layer prior
            log_prior += np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
                first_layer_samples, 0.5 * np.ones((batch.shape[0], 200))),
                                axis=(1, 2))  # SxBxD

            if stop_grads:
                # Inner layer posterior
                log_posterior = np.sum(bernoulli.logpmf(
                    second_layer_samples,
                    lax.stop_gradient(second_layer_params)),
                                       axis=(1, 2))
                # Outer layer posterior
                log_posterior += np.sum(vmap(
                    bernoulli.logpmf,
                    in_axes=(0, None))(first_layer_samples,
                                       lax.stop_gradient(first_layer_params)),
                                        axis=(1, 2))
            else:
                # Inner layer posterior
                log_posterior = np.sum(bernoulli.logpmf(
                    second_layer_samples, second_layer_params),
                                       axis=(1, 2))
                # Outer layer posterior
                log_posterior += np.sum(vmap(
                    bernoulli.logpmf, in_axes=(0, None))(first_layer_samples,
                                                         first_layer_params),
                                        axis=(1, 2))

            # Likelihood
            log_likelihood = vmap(bernoulli_log_likelihood,
                                  in_axes=(None, 0, None))(decoder_params1,
                                                           first_layer_samples,
                                                           batch)

            elbo_samples = log_likelihood - log_posterior + log_prior
            return elbo_samples

        elbo_evaluation = elbo_samples(first_layer_posterior_samples,
                                       second_layer_posterior_samples)
        concrete_evaluation = elbo_samples(first_layer_concrete_samples,
                                           second_layer_concrete_samples)
        cv_evaluation = elbo_samples(first_layer_cv_samples,
                                     second_layer_cv_samples)
        frozen_cv_evaluation = elbo_samples(first_layer_frozen_cv_samples,
                                            second_layer_frozen_cv_samples,
                                            True)

        obj = (lax.stop_gradient(elbo_evaluation) -
               cv_coeff * frozen_cv_evaluation) * log_posterior

        obj += cv_coeff * (concrete_evaluation - cv_evaluation)

        return -np.mean(obj, axis=0) / batch.shape[0]
Example #16
 def network(inputs: jnp.ndarray) -> jnp.ndarray:
     """Simple Q-network with randomized prior function."""
     net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
     prior_net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
     x = hk.Flatten()(inputs)
     return net(x) + prior_scale * lax.stop_gradient(prior_net(x))
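A hedged sketch of wiring this network up with Haiku. Names such as hidden_sizes, action_spec, prior_scale, and dummy_observation come from surrounding code that is not shown and are assumptions here; because of the stop_gradient, the prior network contributes to the Q-values but never receives parameter updates.

import haiku as hk
import jax

forward = hk.without_apply_rng(hk.transform(network))
params = forward.init(jax.random.PRNGKey(0), dummy_observation)
q_values = forward.apply(params, dummy_observation)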
Example #17
 def f(x):
   return lax.sin(x) * lax.cos(lax.stop_gradient(x))
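Because stop_gradient treats its argument as a constant, differentiating f gives cos(x) * cos(x); the -sin(x) * sin(x) term from the cosine factor is dropped. A quick check:

import jax
import jax.numpy as jnp

x = 1.3
print(jax.grad(f)(x))   # ≈ cos(1.3) ** 2
print(jnp.cos(x) ** 2)  # same value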
Example #18
 def explicit_jacobian_solve(matvec, b):
   return lax.stop_gradient(np.linalg.solve(api.jacobian(matvec)(b), b))
Example #19
def dot_product_attention(query,
                          key,
                          value,
                          dtype=jnp.float32,
                          bias=None,
                          axis=None,
                          broadcast_dropout=True,
                          dropout_rng=None,
                          dropout_rate=0.,
                          deterministic=False,
                          precision=None):
    """Computes dot-product attention given query, key, and value.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights. This
  function supports multi-dimensional inputs. This version is modified to
  move the softmax division after the dot product.


  Args:
    query: queries for calculating attention with shape of `[batch_size, dim1,
      dim2, ..., dimN, num_heads, mem_channels]`.
    key: keys for calculating attention with shape of `[batch_size, dim1, dim2,
      ..., dimN, num_heads, mem_channels]`.
    value: values to be used in attention with shape of `[batch_size, dim1,
      dim2,..., dimN, num_heads, value_channels]`.
    dtype: the dtype of the computation (default: float32)
    bias: bias for the attention weights. This can be used for incorporating
      autoregressive mask, padding mask, proximity bias.
    axis: axes over which the attention is applied.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey to be used for dropout.
    dropout_rate: dropout rate
    deterministic: bool, deterministic or not (to apply dropout)
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.

  Returns:
    Output of shape `[bs, dim1, dim2, ..., dimN, num_heads, value_channels]`.
  """
    assert key.shape[:-1] == value.shape[:-1]
    assert (query.shape[0:1] == key.shape[0:1]
            and query.shape[-1] == key.shape[-1])

    if axis is None:
        axis = tuple(range(1, key.ndim - 2))
    if not isinstance(axis, Iterable):
        axis = (axis, )
    assert key.ndim == query.ndim
    assert key.ndim == value.ndim
    for ax in axis:
        if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
            raise ValueError('Attention axis must be between the batch '
                             'axis and the last-two axes.')
    depth = query.shape[-1]
    n = key.ndim
    # batch_dims is <bs, <non-attention dims>, num_heads>
    batch_dims = tuple(np.delete(range(n), axis + (n - 1, )))
    # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
    qk_perm = batch_dims + axis + (n - 1, )
    key = key.transpose(qk_perm)
    query = query.transpose(qk_perm)
    # v -> (bs, <non-attention dims>, num_heads, channels, <attention dims>)
    v_perm = batch_dims + (n - 1, ) + axis
    value = value.transpose(v_perm)

    query = query / jnp.sqrt(depth).astype(dtype)
    batch_dims_t = tuple(range(len(batch_dims)))
    attn_weights = lax.dot_general(query,
                                   key, (((n - 1, ), (n - 1, )),
                                         (batch_dims_t, batch_dims_t)),
                                   precision=precision)

    # apply attention bias: masking, dropout, proximity bias, etc.
    if bias is not None:
        attn_weights = attn_weights + bias

    # normalize the attention weights
    norm_dims = tuple(range(attn_weights.ndim - len(axis), attn_weights.ndim))
    decoding = attn_weights.shape[-2] != 256
    if decoding:
        attn_weights = lax.exp(attn_weights - jax.scipy.special.logsumexp(
            attn_weights, axis=norm_dims, keepdims=True))
    else:
        # move the division by the softmax denominator to after the dot product
        attn_weights = jnp.exp(attn_weights - lax.stop_gradient(
            jnp.max(attn_weights, axis=norm_dims, keepdims=True)))
        softmax_denominator = jnp.sum(attn_weights,
                                      axis=norm_dims,
                                      keepdims=False)
    attn_weights = attn_weights.astype(dtype)

    # apply dropout
    if not deterministic and dropout_rate > 0.:
        if dropout_rng is None:
            dropout_rng = nn.make_rng()
        keep_prob = jax.lax.tie_in(attn_weights, 1.0 - dropout_rate)
        if broadcast_dropout:
            # dropout is broadcast across the batch+head+non-attention dimension
            dropout_dims = attn_weights.shape[-(2 * len(axis)):]
            dropout_shape = (tuple([1] * len(batch_dims_t)) + dropout_dims)
            keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
        else:
            keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
        multiplier = (keep.astype(attn_weights.dtype) /
                      jnp.asarray(keep_prob, dtype=dtype))
        attn_weights = attn_weights * multiplier

    # compute the new values given the attention weights
    wv_contracting_dims = (norm_dims, range(value.ndim - len(axis),
                                            value.ndim))
    y = lax.dot_general(attn_weights,
                        value,
                        (wv_contracting_dims, (batch_dims_t, batch_dims_t)),
                        precision=precision)
    if not decoding:
        # divide by the denominator of the attention softmax now, when the array is
        # O(N*H) rather than O(N^2)
        y = y / jnp.expand_dims(softmax_denominator, -1)

    # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
    perm_inv = _invert_perm(qk_perm)
    y = y.transpose(perm_inv)
    return y
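A small numerical check (illustrative, not from the source) of the equivalence this modification relies on: dividing by the softmax denominator after the weighted sum gives the same result as normalizing the attention weights first.

import jax
import jax.numpy as jnp

w = jnp.array([0.5, 1.5, -0.2])            # unnormalized attention logits
v = jnp.array([[1.0], [2.0], [3.0]])        # values

standard = jax.nn.softmax(w) @ v
unnorm = jnp.exp(w - jnp.max(w))            # max subtracted for stability, as above
postponed = (unnorm @ v) / jnp.sum(unnorm)
print(jnp.allclose(standard, postponed))    # True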
Example #20
 def explicit_jacobian_solve_aux(matvec, b):
   x = lax.stop_gradient(jnp.linalg.solve(jax.jacobian(matvec)(b), b))
   return x, array_aux
Example #21
    def __call__(self, query, train=True, emb_mask=None, padding_mask=None):
        """Quantizes query array using VQ discretization bottleneck."""
        flat_query = jnp.reshape(query, [-1, query.shape[-1]])

        is_initialized = self.has_variable('vqvae', 'counter')
        if not is_initialized:
            self.ema_emb.value = self.vq_emb.value
        embeddings = self.vq_emb.value

        distances = (jnp.sum(flat_query**2, 1, keepdims=True) -
                     2 * jnp.dot(flat_query, embeddings) +
                     jnp.sum(embeddings**2, 0, keepdims=True))
        if emb_mask is not None:  # Mask out some embeddings i.e. pad, BOS, EOS.
            # emb_mask shape == [batch_size, num_embeddings,]
            distances += INF * (1 - emb_mask)

        encoding_indices = jnp.argmin(distances, axis=1)
        encodings = jax.nn.one_hot(encoding_indices,
                                   self.num_embeddings,
                                   dtype=distances.dtype)

        encoding_indices = jnp.reshape(encoding_indices, query.shape[:-1])
        quantized = embeddings.T[encoding_indices]

        e_latent_loss = jnp.mean(
            jnp.square(lax.stop_gradient(quantized) - query),
            axis=-1,
            keepdims=True)

        if train and is_initialized:
            self.counter.value += 1

            dw = jnp.matmul(flat_query.T, encodings)

            decay = lax.convert_element_type(self.decay, dw.dtype)
            # Update ema_cluster_size and ema_emb
            one = jnp.ones([], dw.dtype)
            self.cluster_sizes.value = (self.cluster_sizes.value * self.decay +
                                        jnp.sum(encodings, axis=0) *
                                        (one - decay))
            self.ema_emb.value = self.ema_emb.value * decay + dw * (one -
                                                                    decay)

            # Assign updated ema_emb to emb
            updated_ema_emb = self.ema_emb.value

            n = jnp.sum(self.cluster_sizes.value)
            updated_ema_cluster_size = ((self.cluster_sizes.value + EPS) /
                                        (n + self.num_embeddings * EPS) * n)

            normalised_updated_ema_w = (
                updated_ema_emb /
                jnp.reshape(updated_ema_cluster_size, [1, -1]))
            self.vq_emb.value = normalised_updated_ema_w

        if padding_mask is not None:
            encoding_indices *= padding_mask
            e_latent_loss *= padding_mask[Ellipsis, None]

        loss = self.commitment_cost * e_latent_loss.sum()
        quantized = query + lax.stop_gradient(quantized - query)

        indices = lax.stop_gradient(encoding_indices).astype(jnp.int32)
        return {
            'latents': quantized,
            'loss': loss,
            'latent_indices': indices,
        }
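The line quantized = query + lax.stop_gradient(quantized - query) is the straight-through estimator: the forward value is the quantized code, while the backward pass treats quantization as the identity so gradients reach the encoder. A minimal, self-contained illustration of the same trick:

import jax
import jax.numpy as jnp
from jax import lax

def straight_through_round(x):
    # Forward: round(x). Backward: identity, because the rounding residual is a constant.
    return x + lax.stop_gradient(jnp.round(x) - x)

print(straight_through_round(0.7))            # 1.0
print(jax.grad(straight_through_round)(0.7))  # 1.0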
Example #22
    def tunable_rebar_objective(encoder_params,
                                decoder_params,
                                batch,
                                prng_key,
                                cv_coeff=1.0,
                                log_temperature=0.5,
                                num_samples=1):
        """
        Computes the REBAR objective function of a discrete VAE.

        :param encoder_params: encoder parameters (list)
        :param decoder_params: decoder parameters (list)
        :param batch: batch of data (jax.numpy array)
        :param prng_key: PRNG key
        :param cv_coeff: control variate coefficient
        :param log_temperature: log_temperature parameter for rebar control variate.
        :param num_samples: number of samples
        """
        temperature = np.exp(log_temperature)
        posterior_params = encoder(encoder_params, batch)  # BxD

        def concrete(x):
            return jax.nn.sigmoid(x * temperature)

        # Sampling
        sampling_key, control_variate_key = random.split(prng_key)

        # posterior_samples = bernoulli.sample(posterior_params, sampling_key, num_samples)
        relaxed_samples = binary_relaxed.sample(posterior_params, sampling_key,
                                                num_samples)

        posterior_samples = np.heaviside(relaxed_samples, 0)
        # concrete_samples = jax.nn.sigmoid(relaxed_samples * temperature)

        conditional_relaxed_samples = binary_relaxed.conditional_sample(
            posterior_params, posterior_samples, control_variate_key)
        # cv_samples = jax.nn.sigmoid(conditional_relaxed_samples * temperature)
        # frozen_cv_samples = binary_relaxed.conditional_sample(lax.stop_gradient(posterior_params), posterior_samples,
        #                                                       control_variate_key)
        # frozen_cv_samples = jax.nn.sigmoid(lax.stop_gradient(conditional_relaxed_samples) * temperature)

        log_posterior = np.sum(vmap(bernoulli.logpmf,
                                    in_axes=(0, None))(posterior_samples,
                                                       posterior_params),
                               axis=(1, 2))

        def elbo_samples(samples, stop_grads=False):
            log_prior = np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
                samples, 0.5 * np.ones(
                    (batch.shape[0], posterior_params.shape[-1]))),
                               axis=(1, 2))  # SxBxD

            if stop_grads:
                log_posterior = np.sum(vmap(
                    bernoulli.logpmf,
                    in_axes=(0, None))(samples,
                                       lax.stop_gradient(posterior_params)),
                                       axis=(1, 2))
            else:
                log_posterior = np.sum(vmap(bernoulli.logpmf,
                                            in_axes=(0,
                                                     None))(samples,
                                                            posterior_params),
                                       axis=(1, 2))
            log_likelihood = vmap(bernoulli_log_likelihood,
                                  in_axes=(None, 0, None))(decoder_params,
                                                           samples, batch)
            elbo_samples = log_likelihood - log_posterior + log_prior
            return elbo_samples

        elbo_evaluation = elbo_samples(posterior_samples)
        concrete_evaluation = elbo_samples(concrete(relaxed_samples))
        cv_evaluation = elbo_samples((concrete(conditional_relaxed_samples)))
        frozen_cv_evaluation = elbo_samples(
            concrete(lax.stop_gradient(conditional_relaxed_samples)), True)

        obj = (lax.stop_gradient(elbo_evaluation) -
               cv_coeff * frozen_cv_evaluation) * log_posterior

        obj += cv_coeff * (concrete_evaluation - cv_evaluation)

        return -np.mean(obj, axis=0) / batch.shape[0]
Example #23
def logsumexp(x, axis=0):
    # TODO: remove when https://github.com/google/jax/pull/2260 merged upstream
    x_max = lax.stop_gradient(np.max(x, axis=axis, keepdims=True))
    return np.log(np.sum(np.exp(x - x_max),
                         axis=axis)) + x_max.squeeze(axis=axis)
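Because the subtracted maximum is wrapped in stop_gradient, it only affects numerical stability; the gradient of this logsumexp is still the softmax of the inputs. A quick illustrative check (np here is the jax.numpy alias used above):

import jax

x = np.array([1.0, 2.0, 3.0])
print(jax.grad(lambda x: logsumexp(x, axis=0))(x))  # ≈ softmax(x)
print(jax.nn.softmax(x))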
Example #24
def piecewise_constant_pdf(key, bins, weights, num_samples, randomized):
    """Piecewise-Constant PDF sampling.

  Args:
    key: jnp.ndarray(float32), [2,], random number generator.
    bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
    weights: jnp.ndarray(float32), [batch_size, num_bins].
    num_samples: int, the number of samples.
    randomized: bool, use randomized samples.

  Returns:
    z_samples: jnp.ndarray(float32), [batch_size, num_samples].
  """
    # Pad each weight vector (only if necessary) to bring its sum to `eps`. This
    # avoids NaNs when the input is zeros or small, but has no effect otherwise.
    eps = 1e-5
    weight_sum = jnp.sum(weights, axis=-1, keepdims=True)
    padding = jnp.maximum(0, eps - weight_sum)
    weights += padding / weights.shape[-1]
    weight_sum += padding

    # Compute the PDF and CDF for each weight vector, while ensuring that the CDF
    # starts with exactly 0 and ends with exactly 1.
    pdf = weights / weight_sum
    cdf = jnp.minimum(1, jnp.cumsum(pdf[Ellipsis, :-1], axis=-1))
    cdf = jnp.concatenate([
        jnp.zeros(list(cdf.shape[:-1]) + [1]), cdf,
        jnp.ones(list(cdf.shape[:-1]) + [1])
    ],
                          axis=-1)

    # Draw uniform samples.
    if randomized:
        # Note that `u` is in [0, 1) --- it can be zero, but it can never be 1.
        u = random.uniform(key, list(cdf.shape[:-1]) + [num_samples])
    else:
        # Match the behavior of random.uniform() by spanning [0, 1-eps].
        u = jnp.linspace(0., 1. - jnp.finfo('float32').eps, num_samples)
        u = jnp.broadcast_to(u, list(cdf.shape[:-1]) + [num_samples])

    # Identify the location in `cdf` that corresponds to a random sample.
    # The final `True` index in `mask` will be the start of the sampled interval.
    mask = u[Ellipsis, None, :] >= cdf[Ellipsis, :, None]

    def find_interval(x):
        # Grab the value where `mask` switches from True to False, and vice versa.
        # This approach takes advantage of the fact that `x` is sorted.
        x0 = jnp.max(jnp.where(mask, x[Ellipsis, None], x[Ellipsis, :1, None]),
                     -2)
        x1 = jnp.min(
            jnp.where(~mask, x[Ellipsis, None], x[Ellipsis, -1:, None]), -2)
        return x0, x1

    bins_g0, bins_g1 = find_interval(bins)
    cdf_g0, cdf_g1 = find_interval(cdf)

    t = jnp.clip(jnp.nan_to_num((u - cdf_g0) / (cdf_g1 - cdf_g0), 0), 0, 1)
    samples = bins_g0 + t * (bins_g1 - bins_g0)

    # Prevent gradient from backprop-ing through `samples`.
    return lax.stop_gradient(samples)
Example #25
    def relax_objective(encoder_params,
                        decoder_params,
                        surrogate_params,
                        batch,
                        prng_key,
                        num_samples=1):
        """
        Computes the REBAR objective function of a discrete VAE.

        :param encoder_params: encoder parameters (list)
        :param decoder_params: decoder parameters (list)
        :param surrogate_params: surrogate parameters (list)
        :param batch: batch of data (jax.numpy array)
        :param prng_key: PRNG key
        :param num_samples: number of samples
        """
        posterior_params = encoder(encoder_params, batch)  # BxD

        # Sampling
        sampling_key, conditional_sampling_key = random.split(prng_key)

        # posterior_samples = bernoulli.sample(posterior_params, sampling_key, num_samples)
        unconditional_samples = binary_relaxed.sample(posterior_params,
                                                      sampling_key,
                                                      num_samples)
        posterior_samples = np.heaviside(unconditional_samples, 0)
        conditional_samples = binary_relaxed.conditional_sample(
            posterior_params, posterior_samples, conditional_sampling_key)

        log_posterior = np.sum(vmap(bernoulli.logpmf,
                                    in_axes=(0, None))(posterior_samples,
                                                       posterior_params),
                               axis=(1, 2))

        def elbo_samples(samples):
            log_prior = np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
                samples, 0.5 * np.ones(
                    (batch.shape[0], posterior_params.shape[-1]))),
                               axis=(1, 2))  # SxBxD
            log_posterior = np.sum(vmap(bernoulli.logpmf,
                                        in_axes=(0, None))(samples,
                                                           posterior_params),
                                   axis=(1, 2))
            log_likelihood = vmap(bernoulli_log_likelihood,
                                  in_axes=(None, 0, None))(decoder_params,
                                                           samples, batch)
            elbo_samples = log_likelihood - log_posterior + log_prior
            return elbo_samples

        elbo_evaluation = elbo_samples(posterior_samples)

        obj = (lax.stop_gradient(elbo_evaluation) -
               np.sum(surrogate(surrogate_params,
                                lax.stop_gradient(conditional_samples)),
                      axis=1).squeeze()) * log_posterior

        obj += np.sum(surrogate(surrogate_params, unconditional_samples) -
                      surrogate(surrogate_params, conditional_samples),
                      axis=1).squeeze()

        return -np.mean(obj, axis=0) / batch.shape[0]
Example #26
    def relax_plus_rebar_objective(encoder_params,
                                   decoder_params,
                                   surrogate_params,
                                   log_temperature,
                                   batch,
                                   prng_key,
                                   num_samples=1):
        """
        Computes the REBAR objective function of a discrete VAE.

        :param encoder_params: encoder parameters (list)
        :param decoder_params: decoder parameters (list)
        :param surrogate_params: surrogate parameters (list)
        :param log_temperature: log of inverse temperature
        :param batch: batch of data (jax.numpy array)
        :param prng_key: PRNG key
        :param num_samples: number of samples
        """
        temperature = np.exp(log_temperature)

        def concrete(x):
            return jax.nn.sigmoid(x * temperature)

        posterior_params = encoder(encoder_params, batch)  # BxD

        # Sampling
        sampling_key, conditional_sampling_key = random.split(prng_key)

        # posterior_samples = bernoulli.sample(posterior_params, sampling_key, num_samples)
        unconditional_samples = binary_relaxed.sample(posterior_params,
                                                      sampling_key,
                                                      num_samples)
        posterior_samples = np.heaviside(unconditional_samples, 0)
        conditional_samples = binary_relaxed.conditional_sample(
            posterior_params, posterior_samples, conditional_sampling_key)

        log_posterior = np.sum(vmap(bernoulli.logpmf,
                                    in_axes=(0, None))(posterior_samples,
                                                       posterior_params),
                               axis=(1, 2))

        def elbo_samples(samples, stop_grads=False):
            log_prior = np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
                samples, 0.5 * np.ones(
                    (batch.shape[0], posterior_params.shape[-1]))),
                               axis=(1, 2))  # SxBxD

            if stop_grads:
                log_posterior = np.sum(vmap(
                    bernoulli.logpmf,
                    in_axes=(0, None))(samples,
                                       lax.stop_gradient(posterior_params)),
                                       axis=(1, 2))
            else:
                log_posterior = np.sum(vmap(bernoulli.logpmf,
                                            in_axes=(0,
                                                     None))(samples,
                                                            posterior_params),
                                       axis=(1, 2))
            log_likelihood = vmap(bernoulli_log_likelihood,
                                  in_axes=(None, 0, None))(decoder_params,
                                                           samples, batch)
            elbo_samples = log_likelihood - log_posterior + log_prior
            return elbo_samples

        elbo_evaluation = elbo_samples(posterior_samples)
        unconditional_evaluation = elbo_samples(
            concrete(unconditional_samples))
        conditional_evaluation = elbo_samples(concrete(conditional_samples))
        frozen_conditional_evaluation = elbo_samples(
            concrete(lax.stop_gradient(conditional_samples)), True)

        obj = (lax.stop_gradient(elbo_evaluation) -
               frozen_conditional_evaluation -
               np.sum(surrogate(surrogate_params,
                                lax.stop_gradient(conditional_samples)),
                      axis=1).squeeze()) * log_posterior

        obj += unconditional_evaluation + np.sum(surrogate(
            surrogate_params, unconditional_samples),
                                                 axis=1).squeeze()
        obj -= conditional_evaluation + np.sum(surrogate(
            surrogate_params, conditional_samples),
                                               axis=1).squeeze()

        return -np.mean(obj, axis=0) / batch.shape[0]
Example #27
def neighbor_list(displacement_or_metric: DisplacementOrMetricFn,
                  box_size: Box,
                  r_cutoff: float,
                  dr_threshold: float,
                  capacity_multiplier: float = 1.25,
                  disable_cell_list: bool = False,
                  mask_self: bool = True,
                  custom_mask_function: Optional[MaskFn] = None,
                  fractional_coordinates: bool = False,
                  format: NeighborListFormat = NeighborListFormat.Dense,
                  **static_kwargs) -> NeighborFn:
    """Returns a function that builds a list neighbors for collections of points.

  Neighbor lists must balance the need to be jit compatible with the fact that
  under a jit the maximum number of neighbors cannot change (owing to static
  shape requirements). To deal with this, our `neighbor_list` returns a
  `NeighborListFns` object that contains two functions: 1)
  `neighbor_fn.allocate` creates a new neighbor list and 2)
  `neighbor_fn.update` updates an existing neighbor list. Neighbor lists
  themselves additionally have a convenience `update` member function.

  Note that allocation of a new neighbor list cannot be jit compiled since it
  uses the positions to infer the maximum number of neighbors (along with
  additional space specified by the `capacity_multiplier`). Updating the
  neighbor list can be jit compiled; if the neighbor list capacity is not
  sufficient to store all the neighbors, the `did_buffer_overflow` bit
  will be set to `True` and a new neighbor list will need to be reallocated.

  Here is a typical example of a simulation loop with neighbor lists:

  >>> init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)
  >>> exact_init_fn, exact_apply_fn = simulate.nve(exact_energy_fn, shift, 1e-3)
  >>>
  >>> nbrs = neighbor_fn.allocate(R)
  >>> state = init_fn(random.PRNGKey(0), R, neighbor_idx=nbrs.idx)
  >>>
  >>> def body_fn(i, state):
  >>>   state, nbrs = state
  >>>   nbrs = nbrs.update(state.position)
  >>>   state = apply_fn(state, neighbor_idx=nbrs.idx)
  >>>   return state, nbrs
  >>>
  >>> step = 0
  >>> for _ in range(20):
  >>>   new_state, nbrs = lax.fori_loop(0, 100, body_fn, (state, nbrs))
  >>>   if nbrs.did_buffer_overflow:
  >>>     nbrs = neighbor_fn.allocate(state.position)
  >>>   else:
  >>>     state = new_state
  >>>     step += 1

  Args:
    displacement_or_metric: A function `d(R_a, R_b)` that computes the
      displacement (or distance) between pairs of points.
    box_size: Either a float specifying the size of the box or an array of
      shape [spatial_dim] specifying the box size in each spatial dimension.
    r_cutoff: A scalar specifying the neighborhood radius.
    dr_threshold: A scalar specifying the maximum distance particles can move
      before rebuilding the neighbor list.
    capacity_multiplier: A floating point scalar specifying the fractional
      increase in maximum neighborhood occupancy we allocate compared with the
      maximum in the example positions.
    disable_cell_list: An optional boolean. If set to True then the neighbor
      list is constructed using only distances. This can be useful for
      debugging but should generally be left as False.
    mask_self: An optional boolean. Determines whether points can consider
      themselves to be their own neighbors.
    custom_mask_function: An optional function. Takes the neighbor array
      and masks selected elements. Note: the input array to the function has
      shape (n_particles, m), where the index of particle 1 is given by its
      position in the first dimension of the array and the index of particle 2
      is given by the value in the array.
    fractional_coordinates: An optional boolean. Specifies whether positions
      will be supplied in fractional coordinates in the unit cube, [0, 1]^d.
      If this is set to True then the box_size will be set to 1.0 and the
      cell size used in the cell list will be set to cutoff / box_size.
    format: The format of the neighbor list; see the NeighborListFormat enum
      for details about the different choices for formats. Defaults to `Dense`.
    **static_kwargs: kwargs that get threaded through the calculation of
      example positions.
  Returns:
    A NeighborListFns object that contains a method to allocate a new neighbor
    list and a method to update an existing neighbor list.
  """
    is_format_valid(format)
    box_size = lax.stop_gradient(box_size)
    r_cutoff = lax.stop_gradient(r_cutoff)
    dr_threshold = lax.stop_gradient(dr_threshold)

    box_size = f32(box_size)

    cutoff = r_cutoff + dr_threshold
    cutoff_sq = cutoff**2
    threshold_sq = (dr_threshold / f32(2))**2
    metric_sq = _displacement_or_metric_to_metric_sq(displacement_or_metric)

    cell_size = cutoff
    if fractional_coordinates:
        cell_size = cutoff / box_size
        box_size = f32(1)

    use_cell_list = jnp.all(
        cell_size < box_size / 3.) and not disable_cell_list
    if use_cell_list:
        cl_fn = cell_list(box_size, cell_size, capacity_multiplier)

    @jit
    def candidate_fn(position: Array) -> Array:
        candidates = jnp.arange(position.shape[0])
        return jnp.broadcast_to(candidates[None, :],
                                (position.shape[0], position.shape[0]))

    @jit
    def cell_list_candidate_fn(cl: CellList, position: Array) -> Array:
        N, dim = position.shape

        idx = cl.id_buffer

        cell_idx = [idx]

        for dindex in _neighboring_cells(dim):
            if onp.all(dindex == 0):
                continue
            cell_idx += [_shift_array(idx, dindex)]

        cell_idx = jnp.concatenate(cell_idx, axis=-2)
        cell_idx = cell_idx[..., jnp.newaxis, :, :]
        cell_idx = jnp.broadcast_to(cell_idx,
                                    idx.shape[:-1] + cell_idx.shape[-2:])

        def copy_values_from_cell(value, cell_value, cell_id):
            scatter_indices = jnp.reshape(cell_id, (-1, ))
            cell_value = jnp.reshape(cell_value,
                                     (-1, ) + cell_value.shape[-2:])
            return value.at[scatter_indices].set(cell_value)

        neighbor_idx = jnp.zeros((N + 1, ) + cell_idx.shape[-2:], i32)
        neighbor_idx = copy_values_from_cell(neighbor_idx, cell_idx, idx)
        return neighbor_idx[:-1, :, 0]

    @jit
    def mask_self_fn(idx: Array) -> Array:
        self_mask = idx == jnp.reshape(jnp.arange(idx.shape[0], dtype=i32),
                                       (idx.shape[0], 1))
        return jnp.where(self_mask, idx.shape[0], idx)

    @jit
    def prune_neighbor_list_dense(position: Array, idx: Array,
                                  **kwargs) -> Array:
        d = partial(metric_sq, **kwargs)
        d = space.map_neighbor(d)

        N = position.shape[0]
        neigh_position = position[idx]
        dR = d(position, neigh_position)

        mask = (dR < cutoff_sq) & (idx < N)
        out_idx = N * jnp.ones(idx.shape, i32)

        cumsum = jnp.cumsum(mask, axis=1)
        index = jnp.where(mask, cumsum - 1, idx.shape[1] - 1)
        p_index = jnp.arange(idx.shape[0])[:, None]
        out_idx = out_idx.at[p_index, index].set(idx)
        max_occupancy = jnp.max(cumsum[:, -1])

        return out_idx[:, :-1], max_occupancy

    @jit
    def prune_neighbor_list_sparse(position: Array, idx: Array,
                                   **kwargs) -> Array:
        d = partial(metric_sq, **kwargs)
        d = space.map_bond(d)

        N = position.shape[0]
        sender_idx = jnp.broadcast_to(jnp.arange(N)[:, None], idx.shape)

        sender_idx = jnp.reshape(sender_idx, (-1, ))
        receiver_idx = jnp.reshape(idx, (-1, ))
        dR = d(position[sender_idx], position[receiver_idx])

        mask = (dR < cutoff_sq) & (receiver_idx < N)
        if format is NeighborListFormat.OrderedSparse:
            mask = mask & (receiver_idx < sender_idx)

        out_idx = N * jnp.ones(receiver_idx.shape, i32)

        cumsum = jnp.cumsum(mask)
        index = jnp.where(mask, cumsum - 1, len(receiver_idx) - 1)
        receiver_idx = out_idx.at[index].set(receiver_idx)
        sender_idx = out_idx.at[index].set(sender_idx)
        max_occupancy = cumsum[-1]

        return jnp.stack((receiver_idx[:-1], sender_idx[:-1])), max_occupancy

    def neighbor_list_fn(position: Array,
                         neighbors: Optional[NeighborList] = None,
                         extra_capacity: int = 0,
                         **kwargs) -> NeighborList:
        nbrs = neighbors

        def neighbor_fn(position_and_overflow, max_occupancy=None):
            position, overflow = position_and_overflow
            N = position.shape[0]

            if use_cell_list:
                if neighbors is None:
                    cl = cl_fn.allocate(position,
                                        extra_capacity=extra_capacity)
                else:
                    cl = cl_fn.update(position, neighbors.cell_list_capacity)
                overflow = overflow | cl.did_buffer_overflow
                idx = cell_list_candidate_fn(cl, position)
                cl_capacity = cl.cell_capacity
            else:
                cl_capacity = None
                idx = candidate_fn(position)

            if mask_self:
                idx = mask_self_fn(idx)
            if custom_mask_function is not None:
                idx = custom_mask_function(idx)

            if is_sparse(format):
                idx, occupancy = prune_neighbor_list_sparse(
                    position, idx, **kwargs)
            else:
                idx, occupancy = prune_neighbor_list_dense(
                    position, idx, **kwargs)

            if max_occupancy is None:
                _extra_capacity = (extra_capacity if not is_sparse(format) else
                                   N * extra_capacity)
                max_occupancy = int(occupancy * capacity_multiplier +
                                    _extra_capacity)
                if max_occupancy > position.shape[0] and not is_sparse(format):
                    max_occupancy = position.shape[0]
                if max_occupancy > occupancy:
                    padding = max_occupancy - occupancy
                    pad = N * jnp.ones(
                        (idx.shape[0], padding), dtype=idx.dtype)
                    idx = jnp.concatenate([idx, pad], axis=1)
            idx = idx[:, :max_occupancy]
            update_fn = (neighbor_list_fn
                         if neighbors is None else neighbors.update_fn)
            return NeighborList(idx, position,
                                overflow | (occupancy >= max_occupancy),
                                cl_capacity, max_occupancy, format, update_fn)  # pytype: disable=wrong-arg-count

        if nbrs is None:
            return neighbor_fn((position, False))

        neighbor_fn = partial(neighbor_fn, max_occupancy=nbrs.max_occupancy)

        d = partial(metric_sq, **kwargs)
        d = vmap(d)
        return lax.cond(
            jnp.any(d(position, nbrs.reference_position) > threshold_sq),
            (position, nbrs.did_buffer_overflow), neighbor_fn, nbrs,
            lambda x: x)

    def allocate_fn(position: Array,
                    extra_capacity: int = 0,
                    **kwargs) -> NeighborList:
        return neighbor_list_fn(position,
                                extra_capacity=extra_capacity,
                                **kwargs)

    def update_fn(position: Array, neighbors: NeighborList,
                  **kwargs) -> NeighborList:
        return neighbor_list_fn(position, neighbors, **kwargs)

    return NeighborListFns(allocate_fn, update_fn)  # pytype: disable=wrong-arg-count
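
A hedged usage sketch for the function above, outside a simulation loop. It assumes jax_md's `space.periodic` helper (which returns a displacement/shift pair) and random initial positions; the allocate/update/`did_buffer_overflow` protocol itself follows the docstring:

import jax
from jax_md import space

box_size = 10.0
displacement_fn, _shift_fn = space.periodic(box_size)
neighbor_fn = neighbor_list(displacement_fn, box_size,
                            r_cutoff=2.5, dr_threshold=0.5)

R = jax.random.uniform(jax.random.PRNGKey(0), (128, 3), maxval=box_size)

nbrs = neighbor_fn.allocate(R)   # not jit-compatible: capacity inferred here
nbrs = nbrs.update(R)            # jit-compatible: reuses existing capacity
if nbrs.did_buffer_overflow:     # capacity exceeded; fall back to reallocation
    nbrs = neighbor_fn.allocate(R)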
Example #28
0
def solve(problem: NlpProblem,
          x0=None,
          solver=gmres,
          precond_fn=None,
          eps=1e-4,
          max_iters=1000,
          outer_callback=None,
          verbose=True):
    """Solves a nonlinear programming problem by solving the KKT system

    Args:
        problem (NlpProblem): Instantiated NLP problem
        x0 (ndarray): Initial guess for solution of NLP problem (default None)
        solver (callable): Function that solves the KKT system (default gmres)
        precond_fn (callable): Function that generates a preconditioner for the KKT system (default None)
        eps (float): Tolerance for stopping condition concerning norm of solution (default 1e-4)
        max_iters (int): Maximum number of newton method iterations (default 1000)
        outer_callback (callable): Callback on newton method iterations (default None)
        verbose (bool): Verbosity on/off (default True)
    Returns:
        xstar (ndarray): Optimized solution to NLP problem
    """

    m, n = problem.nconstraints, problem.nvars

    if x0 is None:
        x0 = jnp.zeros(n)

    # Init arrays
    x = jnp.array(x0)
    lam = jnp.zeros(m)
    z = jnp.zeros(n + m)

    for it in range(max_iters):

        if verbose:
            print(f"\n--iter: {it+1}")

        gx, cx = problem.g(x), problem.c(x)

        K = problem.KKT(x, lam)

        b = jnp.block([gx, cx])

        if precond_fn:
            M = precond_fn(problem, x, lam)
        else:
            M = None

        z = solver(K, b, stop_gradient(z), M=M)[0]

        p = -z[:n]
        lam = z[n:]

        # TODO - line search
        x += 0.7 * p

        if outer_callback:
            outer_callback(x, lam)

        if jnp.linalg.norm(p) < eps and jnp.linalg.norm(cx) < eps:
            break

    if verbose:
        print(f"Optimized in {it+1} iterations")

    return x
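
A hedged usage sketch: the solver only relies on the problem exposing `nvars`, `nconstraints`, `g`, `c`, and `KKT`, as used above. The toy problem below (minimize 0.5 * ||x||^2 subject to x[0] + x[1] = 1) and the assumed [[H, J^T], [J, 0]] block layout of `KKT` are illustrative; they must match the conventions of the real `NlpProblem` class:

import jax
import jax.numpy as jnp


class ToyProblem:
    """Minimize 0.5 * ||x||^2 subject to x[0] + x[1] - 1 = 0."""
    nvars = 2
    nconstraints = 1

    def objective(self, x):
        return 0.5 * jnp.dot(x, x)

    def g(self, x):                      # gradient of the objective
        return jax.grad(self.objective)(x)

    def c(self, x):                      # equality constraint residual
        return jnp.array([x[0] + x[1] - 1.0])

    def KKT(self, x, lam):               # assumed layout: [[H, J^T], [J, 0]]
        H = jax.hessian(self.objective)(x)
        J = jax.jacobian(self.c)(x)
        Z = jnp.zeros((self.nconstraints, self.nconstraints))
        return jnp.block([[H, J.T], [J, Z]])


xstar = solve(ToyProblem(), verbose=False)   # expected to approach [0.5, 0.5]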
Example #29
0
    self._name = name

  def __enter__(self):
    return self._name

  def __exit__(self, type_arg, value_arg, traceback_arg):
    return False  # False values do not suppress exceptions.


newaxis = np.newaxis

if JAX_MODE:
  from jax import lax  # pylint: disable=g-import-not-at-top
  stop_gradient = utils.copy_docstring(
      'tf.stop_gradient',
      lambda input, name=None: lax.stop_gradient(input))
else:
  stop_gradient = utils.copy_docstring(
      'tf.stop_gradient',
      lambda input, name=None: np.array(input))


def _convert_tensorshape_to_tensor(value, dtype=None):
  """Copied from TF's TensorShape conversion."""
  if not value.is_fully_defined():
    raise ValueError(
        'Cannot convert a partially known TensorShape to a Tensor: {}'.format(
            value))
  value_list = value.as_list()
  int64_value = 0
  for dim in value_list:
Example #30
0
def _clamp_preserve_gradients(x, min, max):
    return x + stop_gradient(np.clip(x, a_min=min, a_max=max) - x)
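
A short usage sketch for the straight-through clamp above (assuming `np` is jax.numpy and `stop_gradient` is jax.lax.stop_gradient, as the snippet implies): the forward value is clipped, but the gradient is the identity even where the clip is active.

import jax

value = _clamp_preserve_gradients(3.0, -1.0, 1.0)           # 1.0: clipped forward value
slope = jax.grad(_clamp_preserve_gradients)(3.0, -1.0, 1.0)  # 1.0: identity gradient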