def _solve(a, b, sym_pos, lower):
  if not sym_pos:
    return np_linalg.solve(a, b)

  a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
  lax_linalg._check_solve_shapes(a, b)

  # With custom_linear_solve, we can reuse the same factorization when
  # computing sensitivities. This is considerably faster.
  factors = cho_factor(lax.stop_gradient(a), lower=lower)
  custom_solve = partial(
      lax.custom_linear_solve,
      lambda x: lax_linalg._matvec_multiply(a, x),
      solve=lambda _, x: cho_solve(factors, x),
      symmetric=True)
  if a.ndim == b.ndim + 1:
    # b.shape == [..., m]
    return custom_solve(b)
  else:
    # b.shape == [..., m, k]
    return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)

def unsigned_int_scale(x, *, prec, axis=None):
  """Computes a scale s, s.t. 0 <= s * x <= 2**prec, where min(x) >= 0.

  Does not propagate gradients.

  Args:
    x: The input to be scaled.
    prec: Unsigned int precision of the scaled result.
    axis: Dimensions of input to consider for scaling.

  Returns:
    The scaling value.
  """
  max_x = jnp.max(x, axis=axis, keepdims=True)
  if not DISABLE_EPSILON_IN_SCALE_FUN_FOR_TESTING:
    max_x += jnp.finfo(jnp.float32).eps  # to avoid div by 0
  scale = unsigned_int_bound(prec) / max_x
  scale = lax.stop_gradient(scale)
  return scale

def signed_int_scale(x, *, prec, axis=None):
  """Computes a scale s, s.t. -2**(prec - 1) + 1 <= s * x <= 2**(prec - 1) - 1.

  Does not propagate gradients.

  Args:
    x: The input to be scaled.
    prec: Signed int precision of the scaled result.
    axis: Dimensions of input to consider for scaling.

  Returns:
    The scaling value.
  """
  abs_max_x = max_abs_weights(x, axis=axis)
  if not DISABLE_EPSILON_IN_SCALE_FUN_FOR_TESTING:
    abs_max_x += jnp.finfo(jnp.float32).eps  # to avoid div by 0
  scale = signed_int_bound(prec) / abs_max_x
  scale = lax.stop_gradient(scale)
  return scale

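# The scales above are typically consumed by a fake-quantization step that
# rounds with a straight-through estimator, which is why the scale itself must
# not carry gradients. A generic sketch of that pattern (the helper below is
# illustrative, not part of the library; 2**(prec - 1) - 1 stands in for
# signed_int_bound(prec)):
import jax.numpy as jnp
from jax import lax


def fake_quant_signed(x, *, prec, axis=None):
  bound = 2.0**(prec - 1) - 1  # assumed signed bound for `prec` bits
  abs_max = jnp.max(jnp.abs(x), axis=axis, keepdims=True)
  scale = lax.stop_gradient(bound / (abs_max + jnp.finfo(jnp.float32).eps))
  scaled = x * scale
  # Straight-through round: forward uses round(), backward sees the identity.
  rounded = scaled + lax.stop_gradient(jnp.round(scaled) - scaled)
  return rounded / scale
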
def SBL(
    X,
    y,
    prior_init=None,
    hyper_prior=((1e-6, 1e-6), (1e-6, 1e-6)),
    tol=1e-5,
    max_iter=1000,
):
    n_samples, n_features = X.shape
    if prior_init is None:
        prior_init = jnp.ones((n_features + 1,))
        prior_init = jax.ops.index_update(
            prior_init, n_features,
            1 / (jnp.var(y) + 1e-6))  # setting initial noise value

    # Adding term for gradient of loss
    prior_init = jnp.concatenate([prior_init, jnp.ones((n_features,))], axis=0)

    gram = jnp.dot(X.T, X)
    XT_y = jnp.dot(X.T, y)

    prior_params, iterations = fixed_point_solver(
        update,
        (X, y, gram, XT_y, hyper_prior),
        prior_init,
        lambda z_prev, z: jnp.linalg.norm(
            z_prev[-n_features:] - z[-n_features:]) > tol,
        max_iter=max_iter,
    )

    prior = stop_gradient(prior_params)
    loss, mn = evidence(X, y, gram, XT_y, prior[:-n_features], hyper_prior)
    metrics = (
        iterations,
        jnp.linalg.norm(mn.squeeze() - prior[-n_features:]),
        prior[-n_features:],
    )
    return loss, mn, prior[:-n_features], metrics

def _compute_loss_and_stats(params, model_out, use_elastic_loss=False):
  rgb_loss = ((model_out['rgb'] - batch['rgb'][..., :3])**2).mean()
  stats = {
      'loss/rgb': rgb_loss,
  }
  loss = rgb_loss
  if use_elastic_loss:
    v_elastic_fn = jax.jit(vmap(vmap(compute_elastic_loss)))
    weights = lax.stop_gradient(model_out['weights'])
    jacobian = model_out['warp_jacobian']
    # Pick the median point Jacobian.
    if elastic_reduce_method == 'median':
      depth_indices = model_utils.compute_depth_index(weights)
      jacobian = jnp.take_along_axis(
          # Unsqueeze axes: sample axis, Jacobian row, Jacobian col.
          jacobian, depth_indices[..., None, None, None], axis=-3)
    # Compute loss using Jacobian.
    elastic_loss, elastic_residual = v_elastic_fn(jacobian)
    # Multiply weight if weighting by density.
    if elastic_reduce_method == 'weight':
      elastic_loss = weights * elastic_loss
    elastic_loss = elastic_loss.sum(axis=-1).mean()
    stats['loss/elastic'] = elastic_loss
    stats['residual/elastic'] = jnp.mean(elastic_residual)
    loss += scalar_params.elastic_loss_weight * elastic_loss

  if 'warp_jacobian' in model_out:
    jacobian = model_out['warp_jacobian']
    jacobian_det = jnp.linalg.det(jacobian)
    jacobian_div = utils.jacobian_to_div(jacobian)
    jacobian_curl = utils.jacobian_to_curl(jacobian)
    stats['metric/jacobian_det'] = jnp.mean(jacobian_det)
    stats['metric/jacobian_div'] = jnp.mean(jacobian_div)
    stats['metric/jacobian_curl'] = jnp.mean(
        jnp.linalg.norm(jacobian_curl, axis=-1))

  stats['loss/total'] = loss
  stats['metric/psnr'] = utils.compute_psnr(rgb_loss)
  return loss, stats

def loss_fn_pinn_bayes_mse_hyperprior(params, state, model, x, y,
                                      prior_params_mse=(0.0, 0.0)):
    variables = {"params": params, **state}
    (prediction, dt, theta, coeffs), updated_state = model.apply(
        variables, x, mutable=list(state.keys()))
    n_samples, n_features = theta.shape

    # Calculating precision of mse
    tau = precision(y, prediction, *prior_params_mse)
    p_mse, MSE = normal_LL(prediction, y, tau)
    p_mse += gamma_LL(tau, *prior_params_mse)  # adding prior

    # Calculating precision of reg
    hyper_prior_nu = (n_samples / 2, n_samples / stop_gradient(tau))
    nu = precision(dt, theta @ coeffs,
                   *hyper_prior_nu)  # calculates nu given gamma prior
    p_reg, reg = normal_LL(dt, theta @ coeffs, nu)
    p_reg += gamma_LL(nu, *hyper_prior_nu)  # adding prior

    loss = -(p_mse + p_reg)
    metrics = {
        "loss": loss,
        "p_mse": p_mse,
        "mse": MSE,
        "p_reg": p_reg,
        "reg": reg,
        "coeff": coeffs,
        "tau": tau,
        "nu": nu,
    }
    return loss, (updated_state, metrics, (prediction, dt, theta, coeffs))

def create_positive(cls, *, bounds, prec):
  """Create QuantOps for positive activations clipped to [0, bounds].

  Args:
    bounds: The upper bound to clip the activations.
    prec: Unsigned int precision for the QuantOps.

  Returns:
    QuantOps for quantizing/dequantizing unsigned activations.
  """
  initial_bounds = bounds
  bounds = jnp.asarray(bounds, SCALE_DTYPE)
  if not DISABLE_EPSILON_IN_SCALE_FUN_FOR_TESTING:
    bounds += jnp.finfo(SCALE_DTYPE).eps  # to avoid div by 0
  scale = primitives.unsigned_int_bound(prec=prec) / bounds
  # NOTE: stop_gradient is needed here to prevent gradient flow through scale
  # when scale is not a constant, but computed as a function of activations.
  scale = lax.stop_gradient(scale)
  return cls(prec=prec, scale=scale, symmetric=False, bounds=initial_bounds)

def _solve(a, b):
  _check_solve_shapes(a, b)

  # Broadcast leading dimensions of b to the shape of a, as is required by
  # custom_linear_solve.
  out_shape = tuple(d_a if d_b == 1 else d_b
                    for d_a, d_b in zip(a.shape[:-1] + (1,), b.shape))
  b = jnp.broadcast_to(b, out_shape)

  # With custom_linear_solve, we can reuse the same factorization when
  # computing sensitivities. This is considerably faster.
  lu_, _, permutation = lu(lax.stop_gradient(a))
  custom_solve = partial(
      lax.custom_linear_solve,
      lambda x: _matvec_multiply(a, x),
      solve=lambda _, x: lu_solve(lu_, permutation, x, trans=0),
      transpose_solve=lambda _, x: lu_solve(lu_, permutation, x, trans=1))
  if a.ndim == b.ndim + 1:
    # b.shape == [..., m]
    return custom_solve(b)
  else:
    # b.shape == [..., m, k]
    return api.vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)

def elbo_samples(samples, stop_grads=False):
    log_prior = np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
        samples,
        0.5 * np.ones((batch.shape[0], posterior_params.shape[-1]))),
        axis=(1, 2))  # SxBxD
    if stop_grads:
        log_posterior = np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
            samples, lax.stop_gradient(posterior_params)), axis=(1, 2))
    else:
        log_posterior = np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
            samples, posterior_params), axis=(1, 2))
    log_likelihood = vmap(bernoulli_log_likelihood,
                          in_axes=(None, 0, None))(decoder_params, samples,
                                                   batch)
    elbo_samples = log_likelihood - log_posterior + log_prior
    return elbo_samples

def Allreduce(x, op, comm=_MPI.COMM_WORLD, token=None):
    """
    Performs the Allreduce operation `op` on the input `x` using the
    communicator `comm`, which defaults to the world communicator.

    An optional token can be passed, which is used to force jax to execute
    MPI operations in the correct order.

    Arguments:
        x: Array or scalar input.
        op: The reduction operation `MPI.Op` (e.g: MPI.SUM)
        comm: The communicator (defaults to MPI.COMM_WORLD)
        token: token to force a sequential order in the operations
            (default=None)

    Returns:
        res: result of the allreduce operation
        new_token: a new, modified token, that depends on this operation.
            This result can be ignored if `res` already forces a data
            dependency.
    """
    if token is None:
        token = create_token(stop_gradient(x))

    op = wrap_as_hashable(op)
    comm = wrap_as_hashable(comm)
    return mpi_allreduce_p.bind(x, token, op=op, comm=comm)

def create_symmetric_fp(
    cls,
    *,
    bounds,
    fp_quant,
):
  """Create QuantOps for symmetric clipping to floating-point bounds.

  Args:
    bounds: The upper (and absolute lower) bound to clip the inputs.
    fp_quant: quantization floating-point specification of the target format.

  Returns:
    QuantOps for quantizing/dequantizing signed activations.
  """
  if bounds is None:
    if fp_quant.is_scaled:
      raise ValueError(
          'bounds can only be None if fp_quant.is_scaled is False.')
    return cls(prec=fp_quant, scale=None, symmetric=True, bounds=None)
  else:
    initial_bounds = bounds
    bounds = jnp.asarray(bounds, SCALE_DTYPE)
    if not DISABLE_EPSILON_IN_SCALE_FUN_FOR_TESTING:
      bounds += jnp.finfo(SCALE_DTYPE).eps  # to avoid log2(0)
    scale = jnp.exp2(-jnp.floor(jnp.log2(bounds)))  # Scale to unit binade.
    # NOTE: stop_gradient is needed here to prevent gradient flow through
    # scale when scale is not a constant, but computed as a function of
    # activations or weights.
    scale = lax.stop_gradient(scale)
    return cls(
        prec=fp_quant, scale=scale, symmetric=True, bounds=initial_bounds)

def score_function_objective(encoder_params, decoder_params, batch, prng_key,
                             num_samples=1):
    """
    Computes the score function objective of a discrete VAE. The gradient of
    this objective matches the score function gradient of the ELBO.

    :param encoder_params: encoder parameters (list)
    :param decoder_params: decoder parameters (list)
    :param batch: batch of data (jax.numpy array)
    :param prng_key: PRNG key
    :param num_samples: number of samples
    """
    encoder1, encoder2 = encoder
    encoder_params1, encoder_params2 = encoder_params
    decoder1, decoder2 = decoder
    decoder_params1, decoder_params2 = decoder_params

    first_layer_key, second_layer_key = random.split(prng_key)

    # Outer layer latents
    first_layer_params = encoder1(encoder_params1, batch)  # BxD
    first_layer_samples = bernoulli.sample(first_layer_params,
                                           first_layer_key, num_samples)

    # Inner layer latents
    second_layer_params = encoder2(encoder_params2, first_layer_samples)
    second_layer_samples = bernoulli.sample(second_layer_params,
                                            second_layer_key,
                                            num_samples=1)[0, ...]

    # Inner layer prior
    log_prior = np.sum(bernoulli.logpmf(
        second_layer_samples, decoder2(decoder_params2, first_layer_samples)),
        axis=(1, 2))  # SxBxD
    # Outer layer prior
    log_prior += np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
        first_layer_samples, 0.5 * np.ones((batch.shape[0], 200))),
        axis=(1, 2))  # SxBxD

    # Inner layer posterior
    log_posterior = np.sum(bernoulli.logpmf(second_layer_samples,
                                            second_layer_params),
                           axis=(1, 2))
    # Outer layer posterior
    log_posterior += np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
        first_layer_samples, first_layer_params), axis=(1, 2))

    # Likelihood
    log_likelihood = vmap(bernoulli_log_likelihood,
                          in_axes=(None, 0, None))(decoder_params[0],
                                                   first_layer_samples, batch)

    elbo_samples = log_likelihood - log_posterior + log_prior
    elbo_samples = lax.stop_gradient(elbo_samples) * log_posterior
    return -np.mean(elbo_samples, axis=0) / batch.shape[0]

def relax_plus_rebar_objective(encoder_params, decoder_params,
                               surrogate_params, log_temperature, batch,
                               prng_key, num_samples=1):
    """
    Computes the REBAR objective function of a discrete VAE.

    :param encoder_params: encoder parameters (list)
    :param decoder_params: decoder parameters (list)
    :param surrogate_params: surrogate parameters (list)
    :param log_temperature: log of inverse temperature
    :param batch: batch of data (jax.numpy array)
    :param prng_key: PRNG key
    :param num_samples: number of samples
    """
    temperature = np.exp(log_temperature)

    def concrete(x):
        return jax.nn.sigmoid(x * temperature)

    encoder1, encoder2 = encoder
    encoder_params1, encoder_params2 = encoder_params
    decoder1, decoder2 = decoder
    decoder_params1, decoder_params2 = decoder_params

    # Sampling
    sampling_key, control_variate_key = random.split(prng_key)
    first_layer_sampling_key, second_layer_sampling_key = random.split(
        sampling_key)
    first_layer_control_variate_key, second_layer_control_variate_key = random.split(
        control_variate_key)

    # Outer Layer
    first_layer_params = encoder1(encoder_params1, batch)  # BxD
    first_layer_relaxed_samples = binary_relaxed.sample(
        first_layer_params, first_layer_sampling_key, num_samples)
    first_layer_posterior_samples = np.heaviside(first_layer_relaxed_samples,
                                                 0)
    first_layer_conditional_relaxed_samples = binary_relaxed.conditional_sample(
        first_layer_params, first_layer_posterior_samples,
        first_layer_control_variate_key)

    # Inner Layer
    second_layer_params = encoder2(encoder_params2,
                                   first_layer_posterior_samples)
    second_layer_relaxed_samples = binary_relaxed.sample(
        second_layer_params, second_layer_sampling_key, num_samples=1)[0, ...]
    second_layer_posterior_samples = np.heaviside(
        second_layer_relaxed_samples, 0)
    second_layer_conditional_relaxed_samples = binary_relaxed.conditional_sample(
        second_layer_params, second_layer_posterior_samples,
        second_layer_control_variate_key)

    # Inner layer posterior
    log_posterior = np.sum(bernoulli.logpmf(second_layer_posterior_samples,
                                            second_layer_params),
                           axis=(1, 2))
    # Outer layer posterior
    log_posterior += np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
        first_layer_posterior_samples, first_layer_params), axis=(1, 2))

    def elbo_samples(first_layer_samples, second_layer_samples,
                     stop_grads=False):
        # Inner layer prior
        log_prior = np.sum(bernoulli.logpmf(
            second_layer_samples,
            decoder2(decoder_params2, first_layer_samples)),
            axis=(1, 2))  # SxBxD
        # Outer layer prior
        log_prior += np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
            first_layer_samples, 0.5 * np.ones((batch.shape[0], 200))),
            axis=(1, 2))  # SxBxD
        if stop_grads:
            # Inner layer posterior
            log_posterior = np.sum(bernoulli.logpmf(
                second_layer_samples, lax.stop_gradient(second_layer_params)),
                axis=(1, 2))
            # Outer layer posterior
            log_posterior += np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
                first_layer_samples, lax.stop_gradient(first_layer_params)),
                axis=(1, 2))
        else:
            # Inner layer posterior
            log_posterior = np.sum(bernoulli.logpmf(second_layer_samples,
                                                    second_layer_params),
                                   axis=(1, 2))
            # Outer layer posterior
            log_posterior += np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
                first_layer_samples, first_layer_params), axis=(1, 2))
        # Likelihood
        log_likelihood = vmap(bernoulli_log_likelihood,
                              in_axes=(None, 0, None))(decoder_params1,
                                                       first_layer_samples,
                                                       batch)
        elbo_samples = log_likelihood - log_posterior + log_prior
        return elbo_samples

    elbo_evaluation = elbo_samples(first_layer_posterior_samples,
                                   second_layer_posterior_samples)
    unconditional_evaluation = elbo_samples(
        concrete(first_layer_relaxed_samples),
        concrete(second_layer_relaxed_samples))
    conditional_evaluation = elbo_samples(
        concrete(first_layer_conditional_relaxed_samples),
        concrete(second_layer_conditional_relaxed_samples))
    frozen_conditional_evaluation = elbo_samples(
        concrete(lax.stop_gradient(first_layer_conditional_relaxed_samples)),
        concrete(lax.stop_gradient(second_layer_conditional_relaxed_samples)),
        stop_grads=True)

    relaxed_surrogate_inputs = np.concatenate(
        [first_layer_relaxed_samples, second_layer_relaxed_samples], axis=-1)
    conditional_relaxed_surrogate_inputs = np.concatenate([
        first_layer_conditional_relaxed_samples,
        second_layer_conditional_relaxed_samples
    ], axis=-1)

    obj = (lax.stop_gradient(elbo_evaluation) - frozen_conditional_evaluation -
           np.sum(surrogate(
               surrogate_params,
               lax.stop_gradient(conditional_relaxed_surrogate_inputs)),
               axis=1).squeeze()) * log_posterior
    obj += unconditional_evaluation + np.sum(surrogate(
        surrogate_params, relaxed_surrogate_inputs), axis=1).squeeze()
    obj -= conditional_evaluation + np.sum(surrogate(
        surrogate_params, conditional_relaxed_surrogate_inputs),
        axis=1).squeeze()

    return -np.mean(obj, axis=0) / batch.shape[0]

def arm_objective(encoder_params, decoder_params, batch, prng_key,
                  num_samples=1):
    """
    Computes the ARM objective of a discrete VAE. The gradient of this
    objective matches the score function gradient of the ELBO.

    :param encoder_params: encoder parameters (list)
    :param decoder_params: decoder parameters (list)
    :param batch: batch of data (jax.numpy array)
    :param prng_key: PRNG key
    :param num_samples: number of samples
    """
    encoder1, encoder2 = encoder
    encoder_params1, encoder_params2 = encoder_params
    decoder1, decoder2 = decoder
    decoder_params1, decoder_params2 = decoder_params

    bernoulli_key, uniform_key1, uniform_key2 = random.split(prng_key, num=3)

    first_layer_params = encoder1(encoder_params1, batch)  # BxD
    first_layer_bernoulli_samples = bernoulli.sample(first_layer_params,
                                                     bernoulli_key,
                                                     num_samples)
    second_layer_params = encoder2(encoder_params2,
                                   first_layer_bernoulli_samples)

    first_layer_uniform_samples = random.uniform(
        key=uniform_key1, shape=(num_samples, *first_layer_params.shape))
    first_layer_antithetic_samples = 1 - first_layer_uniform_samples
    second_layer_uniform_samples = random.uniform(
        key=uniform_key2, shape=second_layer_params.shape)
    second_layer_antithetic_samples = 1 - second_layer_uniform_samples

    first_layer_uniform_condition = first_layer_uniform_samples <= first_layer_params
    first_layer_antithetic_condition = first_layer_antithetic_samples <= first_layer_params
    second_layer_uniform_condition = second_layer_uniform_samples <= second_layer_params
    second_layer_antithetic_condition = second_layer_antithetic_samples <= second_layer_params

    def elbo_samples(first_layer_samples, second_layer_samples):
        log_prior = np.sum(bernoulli.logpmf(
            second_layer_samples,
            decoder2(decoder_params2, first_layer_bernoulli_samples)),
            axis=(1, 2))  # SxBxD
        log_prior += np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
            first_layer_samples, 0.5 * np.ones((batch.shape[0], 200))),
            axis=(1, 2))  # SxBxD
        log_posterior = np.sum(bernoulli.logpmf(second_layer_samples,
                                                second_layer_params),
                               axis=(1, 2))
        log_posterior += np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
            first_layer_samples, first_layer_params), axis=(1, 2))
        log_likelihood = vmap(bernoulli_log_likelihood,
                              in_axes=(None, 0, None))(decoder_params1,
                                                       first_layer_samples,
                                                       batch)
        elbo = log_likelihood - log_posterior + log_prior
        return elbo

    sample_elbo = elbo_samples(first_layer_uniform_condition,
                               second_layer_uniform_condition)
    antithetic_sample_elbo = elbo_samples(first_layer_antithetic_condition,
                                          second_layer_antithetic_condition)

    loss = lax.stop_gradient(antithetic_sample_elbo - sample_elbo)
    loss = loss * (np.sum(first_layer_params *
                          (first_layer_uniform_samples - 0.5)) +
                   np.sum(second_layer_params *
                          (second_layer_uniform_samples - 0.5)))
    return -np.mean(loss, axis=0) / batch.shape[0]

def tunable_rebar_objective(encoder_params, decoder_params, batch, prng_key,
                            cv_coeff=1.0, log_temperature=0.5, num_samples=1):
    """
    Computes the REBAR objective function of a discrete VAE.

    :param encoder_params: encoder parameters (list)
    :param decoder_params: decoder parameters (list)
    :param batch: batch of data (jax.numpy array)
    :param prng_key: PRNG key
    :param cv_coeff: control variate coefficient
    :param log_temperature: log_temperature parameter for rebar control
        variate.
    :param num_samples: number of samples
    """
    temperature = np.exp(log_temperature)

    encoder1, encoder2 = encoder
    encoder_params1, encoder_params2 = encoder_params
    decoder1, decoder2 = decoder
    decoder_params1, decoder_params2 = decoder_params

    # Sampling
    sampling_key, control_variate_key = random.split(prng_key)
    first_layer_sampling_key, second_layer_sampling_key = random.split(
        sampling_key)
    first_layer_control_variate_key, second_layer_control_variate_key = random.split(
        control_variate_key)

    # Outer Layer
    first_layer_params = encoder1(encoder_params1, batch)  # BxD
    first_layer_relaxed_samples = binary_relaxed.sample(
        first_layer_params, first_layer_sampling_key, num_samples)
    first_layer_posterior_samples = np.heaviside(first_layer_relaxed_samples,
                                                 0)
    first_layer_concrete_samples = jax.nn.sigmoid(
        first_layer_relaxed_samples * temperature)
    first_layer_conditional_relaxed_samples = binary_relaxed.conditional_sample(
        first_layer_params, first_layer_posterior_samples,
        first_layer_control_variate_key)
    first_layer_cv_samples = jax.nn.sigmoid(
        first_layer_conditional_relaxed_samples * temperature)
    first_layer_frozen_cv_samples = jax.nn.sigmoid(
        lax.stop_gradient(first_layer_conditional_relaxed_samples) *
        temperature)

    # Inner Layer
    second_layer_params = encoder2(encoder_params2,
                                   first_layer_posterior_samples)
    second_layer_relaxed_samples = binary_relaxed.sample(
        second_layer_params, second_layer_sampling_key, num_samples=1)[0, ...]
    second_layer_posterior_samples = np.heaviside(
        second_layer_relaxed_samples, 0)
    second_layer_concrete_samples = jax.nn.sigmoid(
        second_layer_relaxed_samples * temperature)
    second_layer_conditional_relaxed_samples = binary_relaxed.conditional_sample(
        second_layer_params, second_layer_posterior_samples,
        second_layer_control_variate_key)
    second_layer_cv_samples = jax.nn.sigmoid(
        second_layer_conditional_relaxed_samples * temperature)
    second_layer_frozen_cv_samples = jax.nn.sigmoid(
        lax.stop_gradient(second_layer_conditional_relaxed_samples) *
        temperature)

    # Inner layer posterior
    log_posterior = np.sum(bernoulli.logpmf(second_layer_posterior_samples,
                                            second_layer_params),
                           axis=(1, 2))
    # Outer layer posterior
    log_posterior += np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
        first_layer_posterior_samples, first_layer_params), axis=(1, 2))

    def elbo_samples(first_layer_samples, second_layer_samples,
                     stop_grads=False):
        # Inner layer prior
        log_prior = np.sum(bernoulli.logpmf(
            second_layer_samples,
            decoder2(decoder_params2, first_layer_samples)),
            axis=(1, 2))  # SxBxD
        # Outer layer prior
        log_prior += np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
            first_layer_samples, 0.5 * np.ones((batch.shape[0], 200))),
            axis=(1, 2))  # SxBxD
        if stop_grads:
            # Inner layer posterior
            log_posterior = np.sum(bernoulli.logpmf(
                second_layer_samples, lax.stop_gradient(second_layer_params)),
                axis=(1, 2))
            # Outer layer posterior
            log_posterior += np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
                first_layer_samples, lax.stop_gradient(first_layer_params)),
                axis=(1, 2))
        else:
            # Inner layer posterior
            log_posterior = np.sum(bernoulli.logpmf(second_layer_samples,
                                                    second_layer_params),
                                   axis=(1, 2))
            # Outer layer posterior
            log_posterior += np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
                first_layer_samples, first_layer_params), axis=(1, 2))
        # Likelihood
        log_likelihood = vmap(bernoulli_log_likelihood,
                              in_axes=(None, 0, None))(decoder_params1,
                                                       first_layer_samples,
                                                       batch)
        elbo_samples = log_likelihood - log_posterior + log_prior
        return elbo_samples

    elbo_evaluation = elbo_samples(first_layer_posterior_samples,
                                   second_layer_posterior_samples)
    concrete_evaluation = elbo_samples(first_layer_concrete_samples,
                                       second_layer_concrete_samples)
    cv_evaluation = elbo_samples(first_layer_cv_samples,
                                 second_layer_cv_samples)
    frozen_cv_evaluation = elbo_samples(first_layer_frozen_cv_samples,
                                        second_layer_frozen_cv_samples, True)

    obj = (lax.stop_gradient(elbo_evaluation) -
           cv_coeff * frozen_cv_evaluation) * log_posterior
    obj += cv_coeff * (concrete_evaluation - cv_evaluation)
    return -np.mean(obj, axis=0) / batch.shape[0]

def network(inputs: jnp.ndarray) -> jnp.ndarray:
  """Simple Q-network with randomized prior function."""
  net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
  prior_net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
  x = hk.Flatten()(inputs)
  return net(x) + prior_scale * lax.stop_gradient(prior_net(x))

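# In the Q-network above, the prior network contributes to the forward pass
# but sits behind stop_gradient, so learning only ever updates `net`; the
# prior keeps its random initialization (a randomized prior function). A small
# self-contained sketch of the same pattern (hidden sizes and the 'mlp_1'
# module name are illustrative assumptions about Haiku's default naming, not
# part of the original code):
import haiku as hk
import jax
import jax.numpy as jnp
from jax import lax


def q_network(inputs):
  net = hk.nets.MLP([32, 4])
  prior_net = hk.nets.MLP([32, 4])
  x = hk.Flatten()(inputs)
  return net(x) + 1.0 * lax.stop_gradient(prior_net(x))


forward = hk.without_apply_rng(hk.transform(q_network))
params = forward.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
grads = jax.grad(lambda p: forward.apply(p, jnp.ones((1, 8))).sum())(params)
# The second MLP (the prior) receives zero gradient.
assert all(
    bool(jnp.all(g == 0))
    for name, layer in grads.items() if name.startswith('mlp_1')
    for g in layer.values())
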
def f(x):
  return lax.sin(x) * lax.cos(lax.stop_gradient(x))

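# The function above is the canonical partial-stop_gradient example: gradients
# flow through the sin() factor but not through the cos() factor, so the
# derivative is cos(x) * cos(x) rather than the full product rule. A minimal
# check (the test point is arbitrary):
import jax
import jax.numpy as jnp
from jax import lax


def f_example(x):
  return lax.sin(x) * lax.cos(lax.stop_gradient(x))


x = 0.7
assert jnp.allclose(jax.grad(f_example)(x), jnp.cos(x) * jnp.cos(x))
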
def explicit_jacobian_solve(matvec, b):
  return lax.stop_gradient(np.linalg.solve(api.jacobian(matvec)(b), b))

def dot_product_attention(query,
                          key,
                          value,
                          dtype=jnp.float32,
                          bias=None,
                          axis=None,
                          broadcast_dropout=True,
                          dropout_rng=None,
                          dropout_rate=0.,
                          deterministic=False,
                          precision=None):
  """Computes dot-product attention given query, key, and value.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights. This
  function supports multi-dimensional inputs.

  This version is modified to move the softmax division after the dot product.

  Args:
    query: queries for calculating attention with shape of `[batch_size, dim1,
      dim2, ..., dimN, num_heads, mem_channels]`.
    key: keys for calculating attention with shape of `[batch_size, dim1,
      dim2, ..., dimN, num_heads, mem_channels]`.
    value: values to be used in attention with shape of `[batch_size, dim1,
      dim2,..., dimN, num_heads, value_channels]`.
    dtype: the dtype of the computation (default: float32)
    bias: bias for the attention weights. This can be used for incorporating
      autoregressive mask, padding mask, proximity bias.
    axis: axes over which the attention is applied.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout
    dropout_rate: dropout rate
    deterministic: bool, deterministic or not (to apply dropout)
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.

  Returns:
    Output of shape `[bs, dim1, dim2, ..., dimN, num_heads, value_channels]`.
  """
  assert key.shape[:-1] == value.shape[:-1]
  assert (query.shape[0:1] == key.shape[0:1] and
          query.shape[-1] == key.shape[-1])

  if axis is None:
    axis = tuple(range(1, key.ndim - 2))
  if not isinstance(axis, Iterable):
    axis = (axis,)
  assert key.ndim == query.ndim
  assert key.ndim == value.ndim
  for ax in axis:
    if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
      raise ValueError('Attention axis must be between the batch '
                       'axis and the last-two axes.')
  depth = query.shape[-1]
  n = key.ndim
  # batch_dims is <bs, <non-attention dims>, num_heads>
  batch_dims = tuple(np.delete(range(n), axis + (n - 1,)))
  # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
  qk_perm = batch_dims + axis + (n - 1,)
  key = key.transpose(qk_perm)
  query = query.transpose(qk_perm)
  # v -> (bs, <non-attention dims>, num_heads, channels, <attention dims>)
  v_perm = batch_dims + (n - 1,) + axis
  value = value.transpose(v_perm)

  query = query / jnp.sqrt(depth).astype(dtype)
  batch_dims_t = tuple(range(len(batch_dims)))
  attn_weights = lax.dot_general(query,
                                 key, (((n - 1,), (n - 1,)),
                                       (batch_dims_t, batch_dims_t)),
                                 precision=precision)

  # apply attention bias: masking, dropout, proximity bias, etc.
  if bias is not None:
    attn_weights = attn_weights + bias

  # normalize the attention weights
  norm_dims = tuple(range(attn_weights.ndim - len(axis), attn_weights.ndim))
  decoding = attn_weights.shape[-2] != 256
  if decoding:
    attn_weights = lax.exp(attn_weights - jax.scipy.special.logsumexp(
        attn_weights, axis=norm_dims, keepdims=True))
  else:
    # move the division by the softmax denominator to after the dot product
    attn_weights = jnp.exp(attn_weights - lax.stop_gradient(
        jnp.max(attn_weights, axis=norm_dims, keepdims=True)))
    softmax_denominator = jnp.sum(attn_weights, axis=norm_dims,
                                  keepdims=False)
  attn_weights = attn_weights.astype(dtype)

  # apply dropout
  if not deterministic and dropout_rate > 0.:
    if dropout_rng is None:
      dropout_rng = nn.make_rng()
    keep_prob = jax.lax.tie_in(attn_weights, 1.0 - dropout_rate)
    if broadcast_dropout:
      # dropout is broadcast across the batch+head+non-attention dimension
      dropout_dims = attn_weights.shape[-(2 * len(axis)):]
      dropout_shape = (tuple([1] * len(batch_dims_t)) + dropout_dims)
      keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
    else:
      keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
    multiplier = (keep.astype(attn_weights.dtype) /
                  jnp.asarray(keep_prob, dtype=dtype))
    attn_weights = attn_weights * multiplier

  # compute the new values given the attention weights
  wv_contracting_dims = (norm_dims,
                         range(value.ndim - len(axis), value.ndim))
  y = lax.dot_general(attn_weights,
                      value, (wv_contracting_dims,
                              (batch_dims_t, batch_dims_t)),
                      precision=precision)
  if not decoding:
    # divide by the denominator of the attention softmax now, when the array
    # is O(N*H) rather than O(N^2)
    y = y / jnp.expand_dims(softmax_denominator, -1)

  # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
  perm_inv = _invert_perm(qk_perm)
  y = y.transpose(perm_inv)
  return y

def explicit_jacobian_solve_aux(matvec, b):
  x = lax.stop_gradient(jnp.linalg.solve(jax.jacobian(matvec)(b), b))
  return x, array_aux

def __call__(self, query, train=True, emb_mask=None, padding_mask=None):
  """Quantizes query array using VQ discretization bottleneck."""
  flat_query = jnp.reshape(query, [-1, query.shape[-1]])
  is_initialized = self.has_variable('vqvae', 'counter')
  if not is_initialized:
    self.ema_emb.value = self.vq_emb.value
  embeddings = self.vq_emb.value

  distances = (jnp.sum(flat_query**2, 1, keepdims=True) -
               2 * jnp.dot(flat_query, embeddings) +
               jnp.sum(embeddings**2, 0, keepdims=True))
  if emb_mask is not None:
    # Mask out some embeddings i.e. pad, BOS, EOS.
    # emb_mask shape == [batch_size, num_embeddings,]
    distances += INF * (1 - emb_mask)

  encoding_indices = jnp.argmin(distances, axis=1)
  encodings = jax.nn.one_hot(encoding_indices,
                             self.num_embeddings,
                             dtype=distances.dtype)
  encoding_indices = jnp.reshape(encoding_indices, query.shape[:-1])
  quantized = embeddings.T[encoding_indices]
  e_latent_loss = jnp.mean(
      jnp.square(lax.stop_gradient(quantized) - query),
      axis=-1,
      keepdims=True)

  if train and is_initialized:
    self.counter.value += 1
    dw = jnp.matmul(flat_query.T, encodings)
    decay = lax.convert_element_type(self.decay, dw.dtype)
    # Update ema_cluster_size and ema_emb
    one = jnp.ones([], dw.dtype)
    self.cluster_sizes.value = (self.cluster_sizes.value * self.decay +
                                jnp.sum(encodings, axis=0) * (one - decay))
    self.ema_emb.value = self.ema_emb.value * decay + dw * (one - decay)

    # Assign updated ema_emb to emb
    updated_ema_emb = self.ema_emb.value
    n = jnp.sum(self.cluster_sizes.value)
    updated_ema_cluster_size = ((self.cluster_sizes.value + EPS) /
                                (n + self.num_embeddings * EPS) * n)
    normalised_updated_ema_w = (
        updated_ema_emb / jnp.reshape(updated_ema_cluster_size, [1, -1]))
    self.vq_emb.value = normalised_updated_ema_w

  if padding_mask is not None:
    encoding_indices *= padding_mask
    e_latent_loss *= padding_mask[Ellipsis, None]

  loss = self.commitment_cost * e_latent_loss.sum()
  quantized = query + lax.stop_gradient(quantized - query)
  indices = lax.stop_gradient(encoding_indices).astype(jnp.int32)
  return {
      'latents': quantized,
      'loss': loss,
      'latent_indices': indices,
  }

def tunable_rebar_objective(encoder_params, decoder_params, batch, prng_key,
                            cv_coeff=1.0, log_temperature=0.5, num_samples=1):
    """
    Computes the REBAR objective function of a discrete VAE.

    :param encoder_params: encoder parameters (list)
    :param decoder_params: decoder parameters (list)
    :param batch: batch of data (jax.numpy array)
    :param prng_key: PRNG key
    :param cv_coeff: control variate coefficient
    :param log_temperature: log_temperature parameter for rebar control
        variate.
    :param num_samples: number of samples
    """
    temperature = np.exp(log_temperature)
    posterior_params = encoder(encoder_params, batch)  # BxD

    def concrete(x):
        return jax.nn.sigmoid(x * temperature)

    # Sampling
    sampling_key, control_variate_key = random.split(prng_key)
    # posterior_samples = bernoulli.sample(posterior_params, sampling_key, num_samples)
    relaxed_samples = binary_relaxed.sample(posterior_params, sampling_key,
                                            num_samples)
    posterior_samples = np.heaviside(relaxed_samples, 0)
    # concrete_samples = jax.nn.sigmoid(relaxed_samples * temperature)
    conditional_relaxed_samples = binary_relaxed.conditional_sample(
        posterior_params, posterior_samples, control_variate_key)
    # cv_samples = jax.nn.sigmoid(conditional_relaxed_samples * temperature)
    # frozen_cv_samples = binary_relaxed.conditional_sample(
    #     lax.stop_gradient(posterior_params), posterior_samples,
    #     control_variate_key)
    # frozen_cv_samples = jax.nn.sigmoid(
    #     lax.stop_gradient(conditional_relaxed_samples) * temperature)

    log_posterior = np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
        posterior_samples, posterior_params), axis=(1, 2))

    def elbo_samples(samples, stop_grads=False):
        log_prior = np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
            samples,
            0.5 * np.ones((batch.shape[0], posterior_params.shape[-1]))),
            axis=(1, 2))  # SxBxD
        if stop_grads:
            log_posterior = np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
                samples, lax.stop_gradient(posterior_params)), axis=(1, 2))
        else:
            log_posterior = np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
                samples, posterior_params), axis=(1, 2))
        log_likelihood = vmap(bernoulli_log_likelihood,
                              in_axes=(None, 0, None))(decoder_params,
                                                       samples, batch)
        elbo_samples = log_likelihood - log_posterior + log_prior
        return elbo_samples

    elbo_evaluation = elbo_samples(posterior_samples)
    concrete_evaluation = elbo_samples(concrete(relaxed_samples))
    cv_evaluation = elbo_samples(concrete(conditional_relaxed_samples))
    frozen_cv_evaluation = elbo_samples(
        concrete(lax.stop_gradient(conditional_relaxed_samples)), True)

    obj = (lax.stop_gradient(elbo_evaluation) -
           cv_coeff * frozen_cv_evaluation) * log_posterior
    obj += cv_coeff * (concrete_evaluation - cv_evaluation)
    return -np.mean(obj, axis=0) / batch.shape[0]

def logsumexp(x, axis=0):
    # TODO: remove when https://github.com/google/jax/pull/2260 merged upstream
    x_max = lax.stop_gradient(np.max(x, axis=axis, keepdims=True))
    return np.log(np.sum(np.exp(x - x_max),
                         axis=axis)) + x_max.squeeze(axis=axis)

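# In the logsumexp above, the max-shift is wrapped in stop_gradient so it is
# treated as a constant during differentiation; both the value and the
# gradient (the softmax) still match the library implementation. A small
# sanity check, assuming standard JAX imports:
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.special import logsumexp as ref_logsumexp


def stable_logsumexp(x, axis=0):
  x_max = lax.stop_gradient(jnp.max(x, axis=axis, keepdims=True))
  return jnp.log(jnp.sum(jnp.exp(x - x_max),
                         axis=axis)) + x_max.squeeze(axis=axis)


x = jnp.array([1.0, 2.0, 3.0])
assert jnp.allclose(stable_logsumexp(x), ref_logsumexp(x))
assert jnp.allclose(jax.grad(stable_logsumexp)(x), jax.nn.softmax(x))
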
def piecewise_constant_pdf(key, bins, weights, num_samples, randomized):
  """Piecewise-Constant PDF sampling.

  Args:
    key: jnp.ndarray(float32), [2,], random number generator.
    bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
    weights: jnp.ndarray(float32), [batch_size, num_bins].
    num_samples: int, the number of samples.
    randomized: bool, use randomized samples.

  Returns:
    z_samples: jnp.ndarray(float32), [batch_size, num_samples].
  """
  # Pad each weight vector (only if necessary) to bring its sum to `eps`. This
  # avoids NaNs when the input is zeros or small, but has no effect otherwise.
  eps = 1e-5
  weight_sum = jnp.sum(weights, axis=-1, keepdims=True)
  padding = jnp.maximum(0, eps - weight_sum)
  weights += padding / weights.shape[-1]
  weight_sum += padding

  # Compute the PDF and CDF for each weight vector, while ensuring that the
  # CDF starts with exactly 0 and ends with exactly 1.
  pdf = weights / weight_sum
  cdf = jnp.minimum(1, jnp.cumsum(pdf[Ellipsis, :-1], axis=-1))
  cdf = jnp.concatenate([
      jnp.zeros(list(cdf.shape[:-1]) + [1]), cdf,
      jnp.ones(list(cdf.shape[:-1]) + [1])
  ], axis=-1)

  # Draw uniform samples.
  if randomized:
    # Note that `u` is in [0, 1) --- it can be zero, but it can never be 1.
    u = random.uniform(key, list(cdf.shape[:-1]) + [num_samples])
  else:
    # Match the behavior of random.uniform() by spanning [0, 1-eps].
    u = jnp.linspace(0., 1. - jnp.finfo('float32').eps, num_samples)
    u = jnp.broadcast_to(u, list(cdf.shape[:-1]) + [num_samples])

  # Identify the location in `cdf` that corresponds to a random sample.
  # The final `True` index in `mask` will be the start of the sampled interval.
  mask = u[Ellipsis, None, :] >= cdf[Ellipsis, :, None]

  def find_interval(x):
    # Grab the value where `mask` switches from True to False, and vice versa.
    # This approach takes advantage of the fact that `x` is sorted.
    x0 = jnp.max(jnp.where(mask, x[Ellipsis, None], x[Ellipsis, :1, None]), -2)
    x1 = jnp.min(jnp.where(~mask, x[Ellipsis, None], x[Ellipsis, -1:, None]),
                 -2)
    return x0, x1

  bins_g0, bins_g1 = find_interval(bins)
  cdf_g0, cdf_g1 = find_interval(cdf)

  t = jnp.clip(jnp.nan_to_num((u - cdf_g0) / (cdf_g1 - cdf_g0), 0), 0, 1)
  samples = bins_g0 + t * (bins_g1 - bins_g0)

  # Prevent gradient from backprop-ing through `samples`.
  return lax.stop_gradient(samples)

def relax_objective(encoder_params, decoder_params, surrogate_params, batch,
                    prng_key, num_samples=1):
    """
    Computes the RELAX objective function of a discrete VAE.

    :param encoder_params: encoder parameters (list)
    :param decoder_params: decoder parameters (list)
    :param surrogate_params: surrogate parameters (list)
    :param batch: batch of data (jax.numpy array)
    :param prng_key: PRNG key
    :param num_samples: number of samples
    """
    posterior_params = encoder(encoder_params, batch)  # BxD

    # Sampling
    sampling_key, conditional_sampling_key = random.split(prng_key)
    # posterior_samples = bernoulli.sample(posterior_params, sampling_key, num_samples)
    unconditional_samples = binary_relaxed.sample(posterior_params,
                                                  sampling_key, num_samples)
    posterior_samples = np.heaviside(unconditional_samples, 0)
    conditional_samples = binary_relaxed.conditional_sample(
        posterior_params, posterior_samples, conditional_sampling_key)

    log_posterior = np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
        posterior_samples, posterior_params), axis=(1, 2))

    def elbo_samples(samples):
        log_prior = np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
            samples,
            0.5 * np.ones((batch.shape[0], posterior_params.shape[-1]))),
            axis=(1, 2))  # SxBxD
        log_posterior = np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
            samples, posterior_params), axis=(1, 2))
        log_likelihood = vmap(bernoulli_log_likelihood,
                              in_axes=(None, 0, None))(decoder_params,
                                                       samples, batch)
        elbo_samples = log_likelihood - log_posterior + log_prior
        return elbo_samples

    elbo_evaluation = elbo_samples(posterior_samples)

    obj = (lax.stop_gradient(elbo_evaluation) -
           np.sum(surrogate(surrogate_params,
                            lax.stop_gradient(conditional_samples)),
                  axis=1).squeeze()) * log_posterior
    obj += np.sum(surrogate(surrogate_params, unconditional_samples) -
                  surrogate(surrogate_params, conditional_samples),
                  axis=1).squeeze()
    return -np.mean(obj, axis=0) / batch.shape[0]

def relax_plus_rebar_objective(encoder_params, decoder_params,
                               surrogate_params, log_temperature, batch,
                               prng_key, num_samples=1):
    """
    Computes the REBAR objective function of a discrete VAE.

    :param encoder_params: encoder parameters (list)
    :param decoder_params: decoder parameters (list)
    :param surrogate_params: surrogate parameters (list)
    :param log_temperature: log of inverse temperature
    :param batch: batch of data (jax.numpy array)
    :param prng_key: PRNG key
    :param num_samples: number of samples
    """
    temperature = np.exp(log_temperature)

    def concrete(x):
        return jax.nn.sigmoid(x * temperature)

    posterior_params = encoder(encoder_params, batch)  # BxD

    # Sampling
    sampling_key, conditional_sampling_key = random.split(prng_key)
    # posterior_samples = bernoulli.sample(posterior_params, sampling_key, num_samples)
    unconditional_samples = binary_relaxed.sample(posterior_params,
                                                  sampling_key, num_samples)
    posterior_samples = np.heaviside(unconditional_samples, 0)
    conditional_samples = binary_relaxed.conditional_sample(
        posterior_params, posterior_samples, conditional_sampling_key)

    log_posterior = np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
        posterior_samples, posterior_params), axis=(1, 2))

    def elbo_samples(samples, stop_grads=False):
        log_prior = np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
            samples,
            0.5 * np.ones((batch.shape[0], posterior_params.shape[-1]))),
            axis=(1, 2))  # SxBxD
        if stop_grads:
            log_posterior = np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
                samples, lax.stop_gradient(posterior_params)), axis=(1, 2))
        else:
            log_posterior = np.sum(vmap(bernoulli.logpmf, in_axes=(0, None))(
                samples, posterior_params), axis=(1, 2))
        log_likelihood = vmap(bernoulli_log_likelihood,
                              in_axes=(None, 0, None))(decoder_params,
                                                       samples, batch)
        elbo_samples = log_likelihood - log_posterior + log_prior
        return elbo_samples

    elbo_evaluation = elbo_samples(posterior_samples)
    unconditional_evaluation = elbo_samples(concrete(unconditional_samples))
    conditional_evaluation = elbo_samples(concrete(conditional_samples))
    frozen_conditional_evaluation = elbo_samples(
        concrete(lax.stop_gradient(conditional_samples)), True)

    obj = (lax.stop_gradient(elbo_evaluation) - frozen_conditional_evaluation -
           np.sum(surrogate(surrogate_params,
                            lax.stop_gradient(conditional_samples)),
                  axis=1).squeeze()) * log_posterior
    obj += unconditional_evaluation + np.sum(surrogate(
        surrogate_params, unconditional_samples), axis=1).squeeze()
    obj -= conditional_evaluation + np.sum(surrogate(
        surrogate_params, conditional_samples), axis=1).squeeze()
    return -np.mean(obj, axis=0) / batch.shape[0]

def neighbor_list(displacement_or_metric: DisplacementOrMetricFn,
                  box_size: Box,
                  r_cutoff: float,
                  dr_threshold: float,
                  capacity_multiplier: float = 1.25,
                  disable_cell_list: bool = False,
                  mask_self: bool = True,
                  custom_mask_function: Optional[MaskFn] = None,
                  fractional_coordinates: bool = False,
                  format: NeighborListFormat = NeighborListFormat.Dense,
                  **static_kwargs) -> NeighborFn:
  """Returns a function that builds a list of neighbors for collections of points.

  Neighbor lists must balance the need to be jit compatible with the fact that
  under a jit the maximum number of neighbors cannot change (owing to static
  shape requirements). To deal with this, our `neighbor_list` returns a
  `NeighborListFns` object that contains two functions: 1)
  `neighbor_fn.allocate` creates a new neighbor list and 2)
  `neighbor_fn.update` updates an existing neighbor list. Neighbor lists
  themselves additionally have a convenience `update` member function.

  Note that allocation of a new neighbor list cannot be jit compiled since it
  uses the positions to infer the maximum number of neighbors (along with
  additional space specified by the `capacity_multiplier`). Updating the
  neighbor list can be jit compiled; if the neighbor list capacity is not
  sufficient to store all the neighbors, the `did_buffer_overflow` bit will be
  set to `True` and a new neighbor list will need to be reallocated.

  Here is a typical example of a simulation loop with neighbor lists:

  >>> init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)
  >>> exact_init_fn, exact_apply_fn = simulate.nve(exact_energy_fn, shift, 1e-3)
  >>>
  >>> nbrs = neighbor_fn.allocate(R)
  >>> state = init_fn(random.PRNGKey(0), R, neighbor_idx=nbrs.idx)
  >>>
  >>> def body_fn(i, state):
  >>>   state, nbrs = state
  >>>   nbrs = nbrs.update(state.position)
  >>>   state = apply_fn(state, neighbor_idx=nbrs.idx)
  >>>   return state, nbrs
  >>>
  >>> step = 0
  >>> for _ in range(20):
  >>>   new_state, nbrs = lax.fori_loop(0, 100, body_fn, (state, nbrs))
  >>>   if nbrs.did_buffer_overflow:
  >>>     nbrs = neighbor_fn.allocate(state.position)
  >>>   else:
  >>>     state = new_state
  >>>     step += 1

  Args:
    displacement: A function `d(R_a, R_b)` that computes the displacement
      between pairs of points.
    box_size: Either a float specifying the size of the box or an array of
      shape [spatial_dim] specifying the box size in each spatial dimension.
    r_cutoff: A scalar specifying the neighborhood radius.
    dr_threshold: A scalar specifying the maximum distance particles can move
      before rebuilding the neighbor list.
    capacity_multiplier: A floating point scalar specifying the fractional
      increase in maximum neighborhood occupancy we allocate compared with the
      maximum in the example positions.
    disable_cell_list: An optional boolean. If set to True then the neighbor
      list is constructed using only distances. This can be useful for
      debugging but should generally be left as False.
    mask_self: An optional boolean. Determines whether points can consider
      themselves to be their own neighbors.
    custom_mask_function: An optional function. Takes the neighbor array and
      masks selected elements. Note: The input array to the function is
      (n_particles, m) where the index of particle 1 is given by the position
      in the first dimension of the array and the index of particle 2 is
      given by the value in the array.
    fractional_coordinates: An optional boolean. Specifies whether positions
      will be supplied in fractional coordinates in the unit cube, [0, 1]^d.
      If this is set to True then the box_size will be set to 1.0 and the
      cell size used in the cell list will be set to cutoff / box_size.
    format: The format of the neighbor list; see the NeighborListFormat enum
      for details about the different choices for formats. Defaults to
      `Dense`.
    **static_kwargs: kwargs that get threaded through the calculation of
      example positions.

  Returns:
    A NeighborListFns object that contains a method to allocate a new neighbor
    list and a method to update an existing neighbor list.
  """
  is_format_valid(format)
  box_size = lax.stop_gradient(box_size)
  r_cutoff = lax.stop_gradient(r_cutoff)
  dr_threshold = lax.stop_gradient(dr_threshold)
  box_size = f32(box_size)

  cutoff = r_cutoff + dr_threshold
  cutoff_sq = cutoff**2
  threshold_sq = (dr_threshold / f32(2))**2
  metric_sq = _displacement_or_metric_to_metric_sq(displacement_or_metric)

  cell_size = cutoff
  if fractional_coordinates:
    cell_size = cutoff / box_size
    box_size = f32(1)

  use_cell_list = jnp.all(cell_size < box_size / 3.) and not disable_cell_list
  if use_cell_list:
    cl_fn = cell_list(box_size, cell_size, capacity_multiplier)

  @jit
  def candidate_fn(position: Array) -> Array:
    candidates = jnp.arange(position.shape[0])
    return jnp.broadcast_to(candidates[None, :],
                            (position.shape[0], position.shape[0]))

  @jit
  def cell_list_candidate_fn(cl: CellList, position: Array) -> Array:
    N, dim = position.shape
    idx = cl.id_buffer

    cell_idx = [idx]
    for dindex in _neighboring_cells(dim):
      if onp.all(dindex == 0):
        continue
      cell_idx += [_shift_array(idx, dindex)]

    cell_idx = jnp.concatenate(cell_idx, axis=-2)
    cell_idx = cell_idx[..., jnp.newaxis, :, :]
    cell_idx = jnp.broadcast_to(cell_idx,
                                idx.shape[:-1] + cell_idx.shape[-2:])

    def copy_values_from_cell(value, cell_value, cell_id):
      scatter_indices = jnp.reshape(cell_id, (-1,))
      cell_value = jnp.reshape(cell_value, (-1,) + cell_value.shape[-2:])
      return value.at[scatter_indices].set(cell_value)

    neighbor_idx = jnp.zeros((N + 1,) + cell_idx.shape[-2:], i32)
    neighbor_idx = copy_values_from_cell(neighbor_idx, cell_idx, idx)
    return neighbor_idx[:-1, :, 0]

  @jit
  def mask_self_fn(idx: Array) -> Array:
    self_mask = idx == jnp.reshape(jnp.arange(idx.shape[0], dtype=i32),
                                   (idx.shape[0], 1))
    return jnp.where(self_mask, idx.shape[0], idx)

  @jit
  def prune_neighbor_list_dense(position: Array, idx: Array,
                                **kwargs) -> Array:
    d = partial(metric_sq, **kwargs)
    d = space.map_neighbor(d)

    N = position.shape[0]
    neigh_position = position[idx]
    dR = d(position, neigh_position)

    mask = (dR < cutoff_sq) & (idx < N)
    out_idx = N * jnp.ones(idx.shape, i32)

    cumsum = jnp.cumsum(mask, axis=1)
    index = jnp.where(mask, cumsum - 1, idx.shape[1] - 1)
    p_index = jnp.arange(idx.shape[0])[:, None]
    out_idx = out_idx.at[p_index, index].set(idx)
    max_occupancy = jnp.max(cumsum[:, -1])

    return out_idx[:, :-1], max_occupancy

  @jit
  def prune_neighbor_list_sparse(position: Array, idx: Array,
                                 **kwargs) -> Array:
    d = partial(metric_sq, **kwargs)
    d = space.map_bond(d)

    N = position.shape[0]
    sender_idx = jnp.broadcast_to(jnp.arange(N)[:, None], idx.shape)

    sender_idx = jnp.reshape(sender_idx, (-1,))
    receiver_idx = jnp.reshape(idx, (-1,))
    dR = d(position[sender_idx], position[receiver_idx])

    mask = (dR < cutoff_sq) & (receiver_idx < N)
    if format is NeighborListFormat.OrderedSparse:
      mask = mask & (receiver_idx < sender_idx)

    out_idx = N * jnp.ones(receiver_idx.shape, i32)

    cumsum = jnp.cumsum(mask)
    index = jnp.where(mask, cumsum - 1, len(receiver_idx) - 1)
    receiver_idx = out_idx.at[index].set(receiver_idx)
    sender_idx = out_idx.at[index].set(sender_idx)
    max_occupancy = cumsum[-1]

    return jnp.stack((receiver_idx[:-1], sender_idx[:-1])), max_occupancy

  def neighbor_list_fn(position: Array,
                       neighbors: Optional[NeighborList] = None,
                       extra_capacity: int = 0,
                       **kwargs) -> NeighborList:
    nbrs = neighbors

    def neighbor_fn(position_and_overflow, max_occupancy=None):
      position, overflow = position_and_overflow
      N = position.shape[0]

      if use_cell_list:
        if neighbors is None:
          cl = cl_fn.allocate(position, extra_capacity=extra_capacity)
        else:
          cl = cl_fn.update(position, neighbors.cell_list_capacity)
        overflow = overflow | cl.did_buffer_overflow
        idx = cell_list_candidate_fn(cl, position)
        cl_capacity = cl.cell_capacity
      else:
        cl_capacity = None
        idx = candidate_fn(position)

      if mask_self:
        idx = mask_self_fn(idx)
      if custom_mask_function is not None:
        idx = custom_mask_function(idx)

      if is_sparse(format):
        idx, occupancy = prune_neighbor_list_sparse(position, idx, **kwargs)
      else:
        idx, occupancy = prune_neighbor_list_dense(position, idx, **kwargs)

      if max_occupancy is None:
        _extra_capacity = (extra_capacity if not is_sparse(format)
                           else N * extra_capacity)
        max_occupancy = int(occupancy * capacity_multiplier + _extra_capacity)
        if max_occupancy > position.shape[0] and not is_sparse(format):
          max_occupancy = position.shape[0]
      if max_occupancy > occupancy:
        padding = max_occupancy - occupancy
        pad = N * jnp.ones((idx.shape[0], padding), dtype=idx.dtype)
        idx = jnp.concatenate([idx, pad], axis=1)
      idx = idx[:, :max_occupancy]
      update_fn = (neighbor_list_fn if neighbors is None else
                   neighbors.update_fn)
      return NeighborList(
          idx,
          position,
          overflow | (occupancy >= max_occupancy),
          cl_capacity,
          max_occupancy,
          format,
          update_fn)  # pytype: disable=wrong-arg-count

    if nbrs is None:
      return neighbor_fn((position, False))

    neighbor_fn = partial(neighbor_fn, max_occupancy=nbrs.max_occupancy)

    d = partial(metric_sq, **kwargs)
    d = vmap(d)
    return lax.cond(
        jnp.any(d(position, nbrs.reference_position) > threshold_sq),
        (position, nbrs.did_buffer_overflow), neighbor_fn,
        nbrs, lambda x: x)

  def allocate_fn(position: Array, extra_capacity: int = 0,
                  **kwargs) -> NeighborList:
    return neighbor_list_fn(position, extra_capacity=extra_capacity, **kwargs)

  def update_fn(position: Array, neighbors: NeighborList,
                **kwargs) -> NeighborList:
    return neighbor_list_fn(position, neighbors, **kwargs)

  return NeighborListFns(allocate_fn, update_fn)  # pytype: disable=wrong-arg-count

def solve(problem: NlpProblem, x0=None, solver=gmres, precond_fn=None,
          eps=1e-4, max_iters=1000, outer_callback=None, verbose=True):
    """Solves a nonlinear programming problem by solving the KKT system.

    Args:
        problem (NlpProblem): Instantiated NLP problem
        x0 (ndarray): Initial guess for solution of NLP problem (default None)
        solver (callable): Function that solves the KKT system (default gmres)
        precond_fn (callable): Function that generates a preconditioner for
            the KKT system (default None)
        eps (float): Tolerance for stopping condition concerning norm of
            solution (default 1e-4)
        max_iters (int): Maximum number of newton method iterations
            (default 1000)
        outer_callback (callable): Callback on newton method iterations
            (default None)
        verbose (bool): Verbosity on/off (default True)

    Returns:
        xstar (ndarray): Optimized solution to NLP problem
    """
    m, n = problem.nconstraints, problem.nvars
    if x0 is None:
        x0 = jnp.zeros(n)

    # Init arrays
    x = jnp.array(x0)
    lam = jnp.zeros(m)
    z = jnp.zeros(n + m)

    for it in range(max_iters):
        if verbose:
            print(f"\n--iter: {it+1}")

        gx, cx = problem.g(x), problem.c(x)
        K = problem.KKT(x, lam)
        b = jnp.block([gx, cx])

        if precond_fn:
            M = precond_fn(problem, x, lam)
        else:
            M = None

        z = solver(K, b, stop_gradient(z), M=M)[0]
        p = -z[:n]
        lam = z[n:]

        # TODO - line search
        x += 0.7 * p

        if outer_callback:
            outer_callback(x, lam)

        if jnp.linalg.norm(p) < eps and jnp.linalg.norm(cx) < eps:
            break

    if verbose:
        print(f"Optimized in {it+1} iterations")
    return x

    self._name = name

  def __enter__(self):
    return self._name

  def __exit__(self, type_arg, value_arg, traceback_arg):
    return False  # False values do not suppress exceptions.


newaxis = np.newaxis

if JAX_MODE:
  from jax import lax  # pylint: disable=g-import-not-at-top
  stop_gradient = utils.copy_docstring(
      'tf.stop_gradient',
      lambda input, name=None: lax.stop_gradient(input))
else:
  stop_gradient = utils.copy_docstring(
      'tf.stop_gradient',
      lambda input, name=None: np.array(input))


def _convert_tensorshape_to_tensor(value, dtype=None):
  """Copied from TF's TensorShape conversion."""
  if not value.is_fully_defined():
    raise ValueError(
        'Cannot convert a partially known TensorShape to a Tensor: {}'.format(
            value))
  value_list = value.as_list()
  int64_value = 0
  for dim in value_list:

def _clamp_preserve_gradients(x, min, max):
    return x + stop_gradient(np.clip(x, a_min=min, a_max=max) - x)

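# _clamp_preserve_gradients is a straight-through clamp: the forward value is
# clipped to [min, max], but the backward pass sees the identity, so gradients
# are not zeroed outside the interval. A minimal illustration using jnp.clip
# directly (the original file aliases a numpy-like module as `np`):
import jax
import jax.numpy as jnp
from jax import lax


def clamp_preserve_gradients(x, low, high):
  return x + lax.stop_gradient(jnp.clip(x, low, high) - x)


value = clamp_preserve_gradients(5.0, -1.0, 1.0)
grad = jax.grad(clamp_preserve_gradients)(5.0, -1.0, 1.0)
assert value == 1.0 and grad == 1.0  # clipped forward, pass-through backward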