def _get_cached_fn(self):
    if self._jit_model_args:
        args, kwargs = (None,), (None,)
    else:
        args = tree_map(lambda x: _hashable(x), self._args)
        kwargs = tree_map(lambda x: _hashable(x), tuple(sorted(self._kwargs.items())))
    key = args + kwargs
    try:
        fn = self._cache.get(key, None)
    # If unhashable arguments are provided, proceed normally
    # without caching
    except TypeError:
        fn, key = None, None
    if fn is None:
        if self._jit_model_args:
            fn = partial(_sample_fn_jit_args, sampler=self.sampler)
        else:
            fn = partial(_sample_fn_nojit_args, sampler=self.sampler,
                         args=self._args, kwargs=self._kwargs)
        if key is not None:
            self._cache[key] = fn
    return fn
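# --- Standalone sketch (not from the source above) of the caching pattern in
# `_get_cached_fn`: bind a function with `partial`, key it on hashable
# arguments, and fall back to no caching when hashing raises `TypeError`.
# `run_inference` and `get_cached_fn` are hypothetical names; `functools.partial`
# plays the role of the `partial` used above.
from functools import partial

def run_inference(sampler, args):
    return sampler, args          # stand-in for the real sampling call

_cache = {}

def get_cached_fn(sampler, *args):
    try:
        key = (sampler, args)
        fn = _cache.get(key)      # hashing happens here; unhashable args raise TypeError
    except TypeError:
        key, fn = None, None
    if fn is None:
        fn = partial(run_inference, sampler=sampler, args=args)
        if key is not None:
            _cache[key] = fn
    return fn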
def get_potential_fn(rng_key, model, dynamic_args=False, model_args=(), model_kwargs=None):
    """
    (EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns a
    function which, given unconstrained parameters, evaluates the potential
    energy (negative log joint density). In addition, this returns a
    function to transform unconstrained values at sample sites to constrained
    values within their respective support.

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param bool dynamic_args: if `True`, the `potential_fn` and `constraints_fn`
        are themselves dependent on model arguments. When provided a
        `*model_args, **model_kwargs`, they return `potential_fn` and
        `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :return: tuple of (`potential_fn`, `constrain_fn`). The latter is used
        to constrain unconstrained samples (e.g. those returned by HMC)
        to values that lie within the site's support.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    seeded_model = seed(model, rng_key if rng_key.ndim == 1 else rng_key[0])
    model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
    inv_transforms = {}
    has_transformed_dist = False
    for k, v in model_trace.items():
        if v['type'] == 'sample' and not v['is_observed']:
            if v['intermediates']:
                inv_transforms[k] = biject_to(v['fn'].base_dist.support)
                has_transformed_dist = True
            else:
                inv_transforms[k] = biject_to(v['fn'].support)
        elif v['type'] == 'param':
            constraint = v['kwargs'].pop('constraint', real)
            transform = biject_to(constraint)
            if isinstance(transform, ComposeTransform):
                inv_transforms[k] = transform.parts[0]
                has_transformed_dist = True
            else:
                inv_transforms[k] = transform

    if dynamic_args:
        def potential_fn(*args, **kwargs):
            return jax.partial(potential_energy, model, inv_transforms, args, kwargs)

        if has_transformed_dist:
            def constrain_fun(*args, **kwargs):
                return jax.partial(constrain_fn, model, inv_transforms, args, kwargs)
        else:
            def constrain_fun(*args, **kwargs):
                return jax.partial(transform_fn, inv_transforms)
    else:
        potential_fn = jax.partial(potential_energy, model, inv_transforms, model_args, model_kwargs)
        if has_transformed_dist:
            constrain_fun = jax.partial(constrain_fn, model, inv_transforms, model_args, model_kwargs)
        else:
            constrain_fun = jax.partial(transform_fn, inv_transforms)

    return potential_fn, constrain_fun
def constrain_fun(*args, **kwargs):
    inv_transforms, has_transformed_dist = get_model_transforms(
        rng_key, model, args, kwargs)
    if has_transformed_dist:
        return jax.partial(constrain_fn, model, inv_transforms, args, kwargs)
    else:
        return jax.partial(transform_fn, inv_transforms)
def run(self, rng, *args, collect_fields=('z',), collect_warmup=False, init_params=None, **kwargs):
    """
    Run the MCMC samplers and collect samples.

    :param random.PRNGKey rng: Random number generator key to be used for the sampling.
    :param args: Arguments to be provided to the :meth:`numpyro.mcmc.MCMCKernel.init` method.
        These are typically the arguments needed by the `model`.
    :param collect_fields: Fields from :data:`numpyro.mcmc.HMCState` to collect during
        the MCMC run. By default, only the latent sample sites `z` is collected.
    :type collect_fields: tuple or list
    :param bool collect_warmup: Whether to collect samples from the warmup phase.
        Defaults to `False`.
    :param init_params: Initial parameters to begin sampling. The type must be consistent
        with the input type to `potential_fn`.
    :param kwargs: Keyword arguments to be provided to the :meth:`numpyro.mcmc.MCMCKernel.init`
        method. These are typically the keyword arguments needed by the `model`.
    """
    chain_method = self.chain_method
    if chain_method == 'parallel' and xla_bridge.device_count() < self.num_chains:
        chain_method = 'sequential'
        warnings.warn('There are not enough devices to run parallel chains: expected {} but got {}.'
                      ' Chains will be drawn sequentially. If you are running `mcmc` in CPU,'
                      ' consider to disable XLA intra-op parallelism by setting the environment'
                      ' flag "XLA_FLAGS=--xla_force_host_platform_device_count={}".'
                      .format(self.num_chains, xla_bridge.device_count(), self.num_chains))
    if init_params is not None and self.num_chains > 1:
        prototype_init_val = tree_flatten(init_params)[0][0]
        if np.shape(prototype_init_val)[0] != self.num_chains:
            raise ValueError('`init_params` must have the same leading dimension'
                             ' as `num_chains`.')
    assert isinstance(collect_fields, (tuple, list))
    self._collect_fields = collect_fields
    if self.num_chains == 1:
        samples_flat = self._single_chain_mcmc((rng, init_params), collect_fields,
                                               collect_warmup, args, kwargs)
        samples = tree_map(lambda x: x[np.newaxis, ...], samples_flat)
    else:
        rngs = random.split(rng, self.num_chains)
        partial_map_fn = partial(self._single_chain_mcmc,
                                 collect_fields=collect_fields,
                                 collect_warmup=collect_warmup,
                                 args=args,
                                 kwargs=kwargs)
        if chain_method == 'sequential':
            map_fn = partial(lax.map, partial_map_fn)
        elif chain_method == 'parallel':
            map_fn = pmap(partial_map_fn)
        elif chain_method == 'vectorized':
            map_fn = vmap(partial_map_fn)
        else:
            raise ValueError('Only supporting the following methods to draw chains:'
                             ' "sequential", "parallel", or "vectorized"')
        samples = map_fn((rngs, init_params))
        samples_flat = tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]), samples)
    self._samples = samples
    self._samples_flat = samples_flat
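# --- Minimal sketch of the chain-dispatch idea in `run`: one per-chain
# function is bound with partial and then mapped sequentially (`lax.map`) or
# vectorized (`vmap`) over a batch of PRNG keys. `single_chain` is a toy
# stand-in for `_single_chain_mcmc`; it is not the library's method.
from functools import partial

import jax.numpy as jnp
from jax import lax, random, vmap

def single_chain(rng, num_samples):
    # stand-in for one MCMC chain: just draw `num_samples` normals
    return random.normal(rng, (num_samples,))

num_chains, num_samples = 4, 100
rngs = random.split(random.PRNGKey(0), num_chains)
chain_fn = partial(single_chain, num_samples=num_samples)

seq_samples = lax.map(chain_fn, rngs)            # chains drawn one after another
vec_samples = vmap(chain_fn)(rngs)               # chains drawn in one batched pass
samples_flat = jnp.reshape(seq_samples, (-1,))   # (num_chains * num_samples,)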
def InitRandomRotation(rng: PRNGKey, jitted=False):
    # create marginal functions
    key = rng

    f = jax.partial(get_random_rotation, return_params=True)
    f_slim = jax.partial(get_random_rotation, return_params=False)

    if jitted:
        f = jax.jit(f)
        f_slim = jax.jit(f_slim)

    def init_params(inputs, **kwargs):
        key_, rng = kwargs.get("rng", jax.random.split(key, 2))
        _, params = f(rng, inputs)
        return params

    def params_and_transform(inputs, **kwargs):
        key_, rng = kwargs.get("rng", jax.random.split(key, 2))
        outputs, params = f(rng, inputs)
        return outputs, params

    def transform(inputs, **kwargs):
        key_, rng = kwargs.get("rng", jax.random.split(key, 2))
        outputs = f_slim(rng, inputs)
        return outputs

    def bijector(inputs, **kwargs):
        params = init_params(inputs, **kwargs)
        bijector = Rotation(rotation=params.rotation)
        return bijector

    def bijector_and_transform(inputs, **kwargs):
        print(inputs.shape)
        outputs, params = params_and_transform(inputs, **kwargs)
        bijector = Rotation(rotation=params.rotation)
        return outputs, bijector

    return InitLayersFunctions(
        bijector=bijector,
        bijector_and_transform=bijector_and_transform,
        transform=transform,
        params=init_params,
        params_and_transform=params_and_transform,
    )
def get_optimizer(model, hyperparams: ServerHyperParams):
    ffgb = FedAlgorithm(
        sampler=sampler,
        server_init=jax.partial(server_init, model, hyperparams),
        client_init=jax.jit(jax.partial(client_init, model, hyperparams)),
        client_step=jax.partial(client_step, model, hyperparams),
        client_end=None,
        server_step=jax.partial(server_step, model, hyperparams))
    return ffgb
def get_potential_fn(rng_key, model, dynamic_args=False, model_args=(), model_kwargs=None):
    """
    (EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns a
    function which, given unconstrained parameters, evaluates the potential
    energy (negative log joint density). In addition, this returns a
    function to transform unconstrained values at sample sites to constrained
    values within their respective support.

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param bool dynamic_args: if `True`, the `potential_fn` and `constraints_fn`
        are themselves dependent on model arguments. When provided a
        `*model_args, **model_kwargs`, they return `potential_fn` and
        `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :return: tuple of (`potential_fn`, `constrain_fn`). The latter is used
        to constrain unconstrained samples (e.g. those returned by HMC)
        to values that lie within the site's support.
    """
    if dynamic_args:
        def potential_fn(*args, **kwargs):
            inv_transforms, has_transformed_dist = get_model_transforms(
                rng_key, model, args, kwargs)
            return jax.partial(potential_energy, model, inv_transforms, args, kwargs)

        def constrain_fun(*args, **kwargs):
            inv_transforms, has_transformed_dist = get_model_transforms(
                rng_key, model, args, kwargs)
            if has_transformed_dist:
                return jax.partial(constrain_fn, model, inv_transforms, args, kwargs)
            else:
                return jax.partial(transform_fn, inv_transforms)
    else:
        inv_transforms, has_transformed_dist = get_model_transforms(
            rng_key, model, model_args, model_kwargs)
        potential_fn = jax.partial(potential_energy, model, inv_transforms, model_args, model_kwargs)
        if has_transformed_dist:
            constrain_fun = jax.partial(constrain_fn, model, inv_transforms, model_args, model_kwargs)
        else:
            constrain_fun = jax.partial(transform_fn, inv_transforms)

    return potential_fn, constrain_fun
def postprocess_fn(*args, **kwargs):
    inv_transforms, replay_model = get_model_transforms(
        rng_key, model, args, kwargs)
    if replay_model:
        return jax.partial(constrain_fn, model, inv_transforms, args, kwargs,
                           return_deterministic=True)
    else:
        return jax.partial(transform_fn, inv_transforms)
def main():
    data = get_classic(100_000)

    # plot data
    plot_joint(
        data[:1_000],
        "blue",
        "Original Data",
        kind="kde",
        save_name=str(Path(FIG_PATH).joinpath("joint_data.png")),
    )

    # define marginal entropy function
    entropy_f = jax.partial(histogram_entropy, nbins=1_000, base=2)

    # define marginal uniformization function
    hist_transform_f = jax.partial(histogram_transform, nbins=1_000)

    n_iterations = 100
    X_trans, loss = total_corr_f(
        np.array(data).block_until_ready(),
        marginal_uni=hist_transform_f,
        marginal_entropy=entropy_f,
        n_iterations=n_iterations,
    )
    total_corr = np.sum(loss) * np.log(2)

    plot_info_loss(
        loss,
        n_layers=len(loss),
        save_name=str(Path(FIG_PATH).joinpath("info_loss.png")),
    )

    print(f"Total Correlation: {total_corr}")

    X_plot = onp.array(X_trans)
    plot_joint(
        X_plot[:10_000],
        "blue",
        "Latent Space",
        kind="kde",
        save_name=str(Path(FIG_PATH).joinpath("joint_latent.png")),
    )
def e_step(
    rng: PRNGSequence,
    actor_target_params: FrozenDict,
    critic_target_params: FrozenDict,
    max_action: float,
    action_dim: int,
    temp: float,
    eps_eta: float,
    state: jnp.ndarray,
    batch_size: int,
    action_sample_size: int,
) -> Tuple[optim.Optimizer, jnp.ndarray, jnp.ndarray]:
    """
    The 'E-step' from the MPO paper. We calculate our weights, which correspond
    to the relative likelihood of obtaining the maximum reward for each of the
    sampled actions. We also take steps on our temperature parameter, which
    induces diversity in the weights.
    """
    Q1, sampled_actions = sample_actions_and_evaluate(
        rng,
        actor_target_params,
        critic_target_params,
        max_action,
        action_dim,
        state,
        batch_size,
        action_sample_size,
    )

    jac = jax.grad(dual, argnums=2)
    jac = jax.partial(jac, Q1, eps_eta)

    # use nonconvex optimizer to minimize the dual of the temperature parameter;
    # we have direct access to the jacobian function with jax, so we can take
    # advantage of it here
    this_dual = jax.partial(dual, Q1, eps_eta)
    bounds = [(1e-6, None)]
    res = minimize(this_dual, temp, jac=jac, method="SLSQP", bounds=bounds)
    temp = jax.lax.stop_gradient(res.x)

    # calculate the sample-based q distribution. we can think of these weights
    # as the relative likelihood of each of the sampled actions giving us the
    # maximum score when taken at the corresponding state.
    weights = jax.nn.softmax(Q1 / temp, axis=1)
    weights = jax.lax.stop_gradient(weights)
    weights = jnp.expand_dims(weights, axis=-1)

    return temp, weights, sampled_actions
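# --- Hedged sketch of the temperature-dual step above. `dual_fn` is a
# hypothetical stand-in for `dual`, written here from the MPO dual
#   g(eta) = eps * eta + eta * E_s[ log(1/N * sum_a exp(Q(s, a) / eta)) ],
# and minimized with scipy's SLSQP using a jacobian supplied by jax.grad.
# `functools.partial` plays the role of `jax.partial`.
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from scipy.optimize import minimize

def dual_fn(q_values, eps_eta, temp):
    temp = jnp.squeeze(temp)                      # scipy passes a shape-(1,) array
    n_actions = q_values.shape[1]
    lse = logsumexp(q_values / temp, axis=1) - jnp.log(n_actions)
    return eps_eta * temp + temp * jnp.mean(lse)

q_values = jax.random.normal(jax.random.PRNGKey(0), (32, 16))   # (batch, sampled actions)
this_dual = partial(dual_fn, q_values, 0.1)
this_jac = partial(jax.grad(dual_fn, argnums=2), q_values, 0.1)

res = minimize(this_dual, x0=1.0, jac=this_jac, method="SLSQP", bounds=[(1e-6, None)])
temp = res.x.item()
weights = jax.nn.softmax(q_values / temp, axis=1)  # relative likelihood of each sampled action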
def init_kernel(init_samples,
                num_warmup_steps,
                step_size=1.0,
                num_steps=None,
                adapt_step_size=True,
                adapt_mass_matrix=True,
                diag_mass=True,
                target_accept_prob=0.8,
                run_warmup=True,
                rng=PRNGKey(0)):
    step_size = float(step_size)
    nonlocal trajectory_length, momentum_generator, wa_update

    if num_steps is None:
        trajectory_length = 2 * math.pi
    else:
        trajectory_length = num_steps * step_size

    z = init_samples
    z_flat, unravel_fn = ravel_pytree(z)
    momentum_generator = partial(_sample_momentum, unravel_fn)

    find_reasonable_ss = partial(find_reasonable_step_size,
                                 potential_fn, kinetic_fn, momentum_generator)

    wa_init, wa_update = warmup_adapter(num_warmup_steps,
                                        find_reasonable_step_size=find_reasonable_ss,
                                        adapt_step_size=adapt_step_size,
                                        adapt_mass_matrix=adapt_mass_matrix,
                                        diag_mass=diag_mass,
                                        target_accept_prob=target_accept_prob)

    rng_hmc, rng_wa = random.split(rng)
    wa_state = wa_init(z, rng_wa, mass_matrix_size=np.size(z_flat))
    r = momentum_generator(wa_state.inverse_mass_matrix, rng)
    vv_state = vv_init(z, r)
    hmc_state = HMCState(vv_state.z, vv_state.z_grad, vv_state.potential_energy,
                         0, 0., wa_state.step_size, wa_state.inverse_mass_matrix, rng_hmc)

    if run_warmup:
        hmc_state, _ = fori_loop(0, num_warmup_steps, warmup_update,
                                 (hmc_state, wa_state))
        return hmc_state
    else:
        return hmc_state, wa_state, warmup_update
def mixture_gaussian_invcdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """
    Args:
        x (JaxArray): input vector (D,)
        prior_logits (JaxArray): prior logits to weight the components (D, K)
        means (JaxArray): means per component per feature (D, K)
        scales (JaxArray): scales per component per feature (D, K)

    Returns:
        x_invcdf (JaxArray): inverse CDF for the mixture distribution
    """
    # INITIALIZE BOUNDS
    init_lb = np.ones_like(means).max(axis=1) - 1_000.0
    init_ub = np.ones_like(means).max(axis=1) + 1_000.0

    # INITIALIZE FUNCTION
    f = jax.partial(
        mixture_gaussian_cdf,
        prior_logits=prior_logits,
        means=means,
        scales=scales,
    )

    return bisection_search(f, x, init_lb, init_ub)
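# --- Minimal sketch of the bisection pattern above, with a standard-normal CDF
# standing in for `mixture_gaussian_cdf` and a hypothetical fixed-iteration
# `bisection_search`: bind the distribution parameters once with partial, then
# bisect a wide bracketing interval.
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def bisection_search(f, target, lower, upper, n_iters=50):
    def body(_, bounds):
        lower, upper = bounds
        mid = 0.5 * (lower + upper)
        too_low = f(mid) < target
        lower = jnp.where(too_low, mid, lower)
        upper = jnp.where(too_low, upper, mid)
        return lower, upper
    lower, upper = jax.lax.fori_loop(0, n_iters, body, (lower, upper))
    return 0.5 * (lower + upper)

f = partial(norm.cdf, loc=0.0, scale=2.0)     # parameters bound once, as with jax.partial
x = jnp.array([0.1, 0.5, 0.9])
x_invcdf = bisection_search(f, x, -1000.0 * jnp.ones_like(x), 1000.0 * jnp.ones_like(x))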
def _call_init(primitive, rng, submodule_params, params,
               jaxpr, consts, freevar_vals, in_vals, **kwargs):
    jaxpr, = jaxpr
    consts, = consts
    freevar_vals, = freevar_vals
    f = lu.wrap_init(partial(jc.eval_jaxpr, jaxpr, consts, freevar_vals))
    return primitive.bind(f, *in_vals, **params), submodule_params
def body_fn(state):
    i, key, _, _ = state
    key, subkey = random.split(key)

    # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
    # Use `block` to not record sample primitives in `init_loc_fn`.
    seeded_model = substitute(model, substitute_fn=block(seed(init_strategy, subkey)))
    model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
    constrained_values, inv_transforms = {}, {}
    for k, v in model_trace.items():
        if v['type'] == 'sample' and not v['is_observed']:
            if v['intermediates']:
                constrained_values[k] = v['intermediates'][0][0]
                inv_transforms[k] = biject_to(v['fn'].base_dist.support)
            else:
                constrained_values[k] = v['value']
                inv_transforms[k] = biject_to(v['fn'].support)
        elif v['type'] == 'param' and param_as_improper:
            constraint = v['kwargs'].pop('constraint', real)
            transform = biject_to(constraint)
            if isinstance(transform, ComposeTransform):
                base_transform = transform.parts[0]
                inv_transforms[k] = base_transform
                constrained_values[k] = base_transform(transform.inv(v['value']))
            else:
                inv_transforms[k] = transform
                constrained_values[k] = v['value']
    params = transform_fn(inv_transforms,
                          {k: v for k, v in constrained_values.items()},
                          invert=True)
    potential_fn = jax.partial(potential_energy, model, inv_transforms, model_args, model_kwargs)
    pe, param_grads = value_and_grad(potential_fn)(params)
    z_grad = ravel_pytree(param_grads)[0]
    is_valid = np.isfinite(pe) & np.all(np.isfinite(z_grad))
    return i + 1, key, params, is_valid
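# --- Sketch of the validity check at the end of `body_fn`, with a toy potential
# in place of `potential_energy`; `toy_potential` and its parameters are made up
# for illustration, and `functools.partial` stands in for `jax.partial`.
from functools import partial

import jax.numpy as jnp
from jax import value_and_grad
from jax.flatten_util import ravel_pytree

def toy_potential(scale, params):
    # negative log density of independent normals with a fixed scale
    return 0.5 * sum(jnp.sum((v / scale) ** 2) for v in params.values())

params = {"loc": jnp.zeros(3), "coefs": jnp.ones(2)}
potential_fn = partial(toy_potential, 2.0)          # model arguments bound up front

pe, param_grads = value_and_grad(potential_fn)(params)
z_grad = ravel_pytree(param_grads)[0]
is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad))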
def process_call(self, call_primitive, f, tracers, params):
    flat_inputs, submodule_params_iter = ApplyTrace.Tracer.merge(tracers)
    f = ApplyTrace._apply_subtrace(f, self.master, WrapHashably(submodule_params_iter))
    flat_outs = call_primitive.bind(f, *flat_inputs, **params)
    return map(partial(ApplyTrace.Tracer, self, submodule_params_iter), flat_outs)
def build_loglikelihoods(model, args, observations):
    """Function to compute the loglikelihood contribution of each variable."""
    loglikelihoods = jax.partial(mcx.log_prob_contributions(model), **observations, **args)
    return loglikelihoods
def init(self, rng_key, *args, **kwargs):
    """
    :param jax.random.PRNGKey rng_key: random number generator seed.
    :param args: arguments to the model / guide (these can possibly vary during
        the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :return: tuple containing initial :data:`SVIState`, and `get_params`, a callable
        that transforms unconstrained parameter values from the optimizer to the
        specified constrained domain.
    """
    rng_key, model_seed, guide_seed = random.split(rng_key, 3)
    model_init = seed(self.model, model_seed)
    guide_init = seed(self.guide, guide_seed)
    guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
    model_trace = trace(model_init).get_trace(*args, **kwargs, **self.static_kwargs)
    params = {}
    inv_transforms = {}
    # NB: params in model_trace will be overwritten by params in guide_trace
    for site in list(model_trace.values()) + list(guide_trace.values()):
        if site['type'] == 'param':
            constraint = site['kwargs'].pop('constraint', constraints.real)
            transform = biject_to(constraint)
            inv_transforms[site['name']] = transform
            params[site['name']] = transform.inv(site['value'])

    self.constrain_fn = jax.partial(transform_fn, inv_transforms)
    return SVIState(self.optim.init(params), rng_key)
def initialize_model(rng, model, *model_args, init_strategy='uniform', **model_kwargs):
    """
    Given a model with Pyro primitives, returns a function which, given
    unconstrained parameters, evaluates the potential energy (negative
    joint density). In addition, this also returns initial parameters
    sampled from the prior to initiate MCMC sampling and functions to
    transform unconstrained values at sample sites to constrained values
    within their respective support.

    :param jax.random.PRNGKey rng: random number generator seed to
        sample from the prior.
    :param model: Python callable containing Pyro primitives.
    :param `*model_args`: args provided to the model.
    :param str init_strategy: initialization strategy - `uniform`
        initializes the unconstrained parameters by drawing from
        a `Uniform(-2, 2)` distribution (as used by Stan), whereas
        `prior` initializes the parameters by sampling from the prior
        for each of the sample sites.
    :param `**model_kwargs`: kwargs provided to the model.
    :return: tuple of (`init_params`, `potential_fn`, `constrain_fn`).
        `init_params` are values from the prior used to initiate MCMC.
        `constrain_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support.
    """
    model = seed(model, rng)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    sample_sites = {k: v for k, v in model_trace.items()
                    if v['type'] == 'sample' and not v['is_observed']}
    inv_transforms = {k: biject_to(v['fn'].support) for k, v in sample_sites.items()}
    prior_params = constrain_fn(inv_transforms,
                                {k: v['value'] for k, v in sample_sites.items()},
                                invert=True)
    if init_strategy == 'uniform':
        init_params = {}
        for k, v in prior_params.items():
            rng, = random.split(rng, 1)
            init_params[k] = random.uniform(rng, shape=np.shape(v), minval=-2, maxval=2)
    elif init_strategy == 'prior':
        init_params = prior_params
    else:
        raise ValueError('initialize={} is not a valid initialization strategy.'.format(init_strategy))
    return init_params, potential_energy(model, model_args, model_kwargs, inv_transforms), \
        jax.partial(constrain_fn, inv_transforms)
def module(name, nn, input_shape=None):
    """
    Declare a :mod:`~jax.experimental.stax` style neural network inside a
    model so that its parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    :param str name: name of the module to be registered.
    :param tuple nn: a tuple of `(init_fn, apply_fn)` obtained by a
        :mod:`~jax.experimental.stax` constructor function.
    :param tuple input_shape: shape of the input taken by the neural network.
    :return: an `apply_fn` with bound parameters that takes an array as an input
        and returns the neural network transformed output array.
    """
    module_key = name + '$params'
    nn_init, nn_apply = nn
    nn_params = param(module_key)
    if nn_params is None:
        if input_shape is None:
            raise ValueError('Valid value for `input_shape` needed to initialize.')
        rng = numpyro.sample(name + '$rng', PRNGIdentity())
        _, nn_params = nn_init(rng, input_shape)
        param(module_key, nn_params)
    return jax.partial(nn_apply, nn_params)
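# --- Usage sketch of the pattern `module` returns, using a plain stax network
# outside of numpyro (module path `jax.example_libraries.stax` in current JAX;
# it was `jax.experimental.stax` when the code above was written). The network
# architecture here is arbitrary.
from functools import partial

import jax
from jax.example_libraries import stax

nn_init, nn_apply = stax.serial(stax.Dense(8), stax.Relu, stax.Dense(1))
rng = jax.random.PRNGKey(0)
out_shape, nn_params = nn_init(rng, (-1, 3))

net = partial(nn_apply, nn_params)        # parameters bound, as in `module`'s return value
y = net(jax.random.normal(rng, (5, 3)))   # output shape (5, 1)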
def __init__(self, nin, nclass, scales, filters, filters_max,
             pooling=objax.functional.max_pool_2d, **kwargs):
    """Creates ConvNet instance.

    Args:
        nin: number of channels in the input image.
        nclass: number of output classes.
        scales: number of pooling layers, each of which reduces spatial dimension by 2.
        filters: base number of convolution filters.
            Number of convolution filters is doubled every scale until it reaches filters_max.
        filters_max: maximum number of filters.
        pooling: type of pooling layer.
    """
    del kwargs

    def nf(scale):
        return min(filters_max, filters << scale)

    ops = [objax.nn.Conv2D(nin, nf(0), 3), objax.functional.leaky_relu]
    for i in range(scales):
        ops.extend([objax.nn.Conv2D(nf(i), nf(i), 3), objax.functional.leaky_relu,
                    objax.nn.Conv2D(nf(i), nf(i + 1), 3), objax.functional.leaky_relu,
                    jax.partial(pooling, size=2, strides=2)])
    ops.extend([objax.nn.Conv2D(nf(scales), nclass, 3), self._mean_reduce])
    super().__init__(ops)
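# --- Sketch of the pooling line above: `jax.partial(pooling, size=2, strides=2)`
# pre-binds keyword arguments so the pool behaves like a single-argument layer.
# `avg_pool_2d` is a crude stand-in written here (it assumes size == strides and
# divisible spatial dimensions); it is not objax's pooling implementation.
from functools import partial

import jax.numpy as jnp

def avg_pool_2d(x, size, strides):
    # x: (N, C, H, W); non-overlapping average pooling, strides kept only for
    # signature compatibility and assumed equal to size
    n, c, h, w = x.shape
    x = x.reshape(n, c, h // size, size, w // size, size)
    return x.mean(axis=(3, 5))

pool = partial(avg_pool_2d, size=2, strides=2)   # now a unary op, usable in a layer list
y = pool(jnp.ones((1, 3, 8, 8)))                 # (1, 3, 4, 4)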
def solve(
    self,
    graph: "StackedFactorGraph",
    initial_assignments: VariableAssignments,
) -> VariableAssignments:
    """Run MAP inference on a factor graph."""

    # Initialize
    assignments = initial_assignments
    cost, residual_vector = graph.compute_cost(assignments)
    state = self._initialize_state(graph, initial_assignments)

    # Optimization
    state = jax.lax.while_loop(
        cond_fun=lambda state: jnp.logical_and(
            jnp.logical_not(state.done),
            state.iterations < self.max_iterations),
        body_fun=jax.partial(self._step, graph),
        init_val=state,
    )

    self._hcb_print(
        lambda i, max_i, cost: f"Terminated @ iteration #{i}/{max_i}: cost={str(cost).ljust(15)}",
        i=state.iterations,
        max_i=self.max_iterations,
        cost=state.cost,
    )

    return state.assignments
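# --- Sketch of the while-loop pattern above with a toy gradient-descent step
# whose extra argument (the step size) is bound via partial. The `State` tuple
# and quadratic objective are made up for illustration.
from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp

class State(NamedTuple):
    x: jnp.ndarray
    iterations: jnp.ndarray

def step(step_size, state):                      # extra arguments bound with partial
    grad = 2.0 * state.x                         # gradient of f(x) = x^2
    return State(state.x - step_size * grad, state.iterations + 1)

state = jax.lax.while_loop(
    cond_fun=lambda s: jnp.logical_and(jnp.abs(s.x) > 1e-3, s.iterations < 100),
    body_fun=partial(step, 0.1),
    init_val=State(jnp.asarray(5.0), jnp.asarray(0)),
)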
def batched(*batched_args):
    args = tree_map(lambda x: x[0], batched_args)
    params = Parameter(lambda key: unbatched_model.init_parameters(*args, key=key), 'model')()
    batched_apply = vmap(partial(unbatched_model.apply, params), batch_dim)
    return batched_apply(*batched_args)
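# --- Sketch of the batching trick above: bind fixed parameters with partial,
# then vmap only over the data's leading batch axis. `apply` and its parameters
# are made up for illustration.
from functools import partial

import jax
import jax.numpy as jnp
from jax import vmap

def apply(params, x):
    return jnp.tanh(params["w"] @ x + params["b"])

params = {"w": jnp.ones((4, 3)), "b": jnp.zeros(4)}
batch = jax.random.normal(jax.random.PRNGKey(0), (16, 3))

batched_apply = vmap(partial(apply, params))     # params closed over, only x is mapped
out = batched_apply(batch)                       # (16, 4)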
def log_likelihood_grad_bias(self, data, reward_model, bias_params):
    om, empirical_om, _ = self._ll_compute_oms(data, reward_model, bias_params)

    def blind_reward(biases, obs_matrix):
        """Compute blind reward for all states in such a way that Jax can
        differentiate with respect to the bias/masking vector. This is
        trivial for linear rewards, but harder for more general
        RewardModels."""
        blind_obs_mat = obs_matrix * biases
        assert blind_obs_mat.shape == obs_matrix.shape
        return reward_model.out(blind_obs_mat)

    # compute gradient of reward in each state w.r.t. biases
    # (we do this separately for each input)
    blind_rew_grad_fn = jax.grad(blind_reward)
    lifted_blind_rew_grad_fn = jax.vmap(jax.partial(blind_rew_grad_fn, bias_params))
    lifted_grads = lifted_blind_rew_grad_fn(self.env.observation_matrix)
    empirical_grad_term = jnp.sum(empirical_om[:, None] * lifted_grads, axis=0)
    pi_grad_term = jnp.sum(om[:, None] * lifted_grads, axis=0)
    grads = empirical_grad_term - pi_grad_term
    return grads
def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier, batch_size):
    """Return differentially private gradients for params, evaluated on batch."""

    def _clipped_grad(params, single_example_batch):
        """Evaluate gradient for a single-example batch and clip its grad norm."""
        grads = grad_loss(params, single_example_batch)

        nonempty_grads, tree_def = tree_util.tree_flatten(grads)
        total_grad_norm = np.linalg.norm(
            [np.linalg.norm(neg.ravel()) for neg in nonempty_grads])
        divisor = stop_gradient(np.amax((total_grad_norm / l2_norm_clip, 1.)))
        normalized_nonempty_grads = [g / divisor for g in nonempty_grads]
        return tree_util.tree_unflatten(tree_def, normalized_nonempty_grads)

    px_clipped_grad_fn = vmap(partial(_clipped_grad, params))
    std_dev = l2_norm_clip * noise_multiplier
    noise_ = lambda n: n + std_dev * random.normal(rng, n.shape)
    normalize_ = lambda n: n / float(batch_size)
    tree_map = tree_util.tree_map
    sum_ = lambda n: np.sum(n, 0)  # aggregate
    aggregated_clipped_grads = tree_map(sum_, px_clipped_grad_fn(batch))
    noised_aggregated_clipped_grads = tree_map(noise_, aggregated_clipped_grads)
    normalized_noised_aggregated_clipped_grads = tree_map(
        normalize_, noised_aggregated_clipped_grads)
    return normalized_noised_aggregated_clipped_grads
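# --- Condensed sketch of the same DP-SGD recipe on a toy linear regression:
# per-example gradients via vmap, per-example clipping by global norm, then
# Gaussian noise and averaging. The loss, shapes, and hyperparameters here are
# illustrative only, not the code above.
import jax
import jax.numpy as jnp

def loss(params, x, y):
    return 0.5 * (jnp.dot(params, x) - y) ** 2

def toy_private_grad(params, xs, ys, rng, l2_norm_clip, noise_multiplier):
    per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(params, xs, ys)
    norms = jnp.linalg.norm(per_example_grads, axis=1)
    divisors = jnp.maximum(norms / l2_norm_clip, 1.0)   # clip each example's gradient
    summed = (per_example_grads / divisors[:, None]).sum(axis=0)
    noise = l2_norm_clip * noise_multiplier * jax.random.normal(rng, summed.shape)
    return (summed + noise) / xs.shape[0]

params = jnp.zeros(3)
xs = jax.random.normal(jax.random.PRNGKey(0), (8, 3))
ys = jnp.ones(8)
g = toy_private_grad(params, xs, ys, jax.random.PRNGKey(1), 1.0, 1.1)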
def fit_transform(self, X):
    self.n_features = X.shape[1]

    # initialize parameter storage
    params = []
    losses = []
    i_layer = 0

    # loop through
    while i_layer < self.max_layers:
        # fix info criteria
        loss = jax.partial(self.loss_f, X=X)

        X, block_params = self.block_forward(X)

        info_red = loss(Y=X)

        # append parameters
        params.append(block_params)
        losses.append(info_red)

        i_layer += 1

    self.n_layers = i_layer
    self.params = params
    self.info_loss = np.array(losses)
    return X
def grads(self, inputs):
    in_grad_partial = jax.partial(self._net_grads, self._net_params)
    grad_vmap = jax.vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    flat_grads = np.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads
def _parse_nn_pepo_obc(C, D, vL, vR, vB, vT, lx, ly):
    assert (lx > 2) and (ly > 2)  # (otherwise there is no bulk to put the Ds in)
    x_d = lx // 2
    y_d = ly // 2

    # HORIZONTAL
    vL_C = np.tensordot(vL, C, [0, 2])  # (p,p*,r)
    C_vR = np.tensordot(vR, C, [0, 3])  # (p,p*,l)
    vB_D = np.tensordot(vB, D, [0, 4])  # (p,p*,l,r,u)
    D_vT = np.tensordot(vT, D, [0, 5])  # (p,p*,l,r,d)

    left_col = [vL_C[:, :, :, None]] \
        + [vL_C[:, :, None, :, None]] * (ly - 2) \
        + [vL_C[:, :, None, :]]

    # bottom C: (p,p*,i,j) = (p,p*,l,r) -> (p,p*,r,l) -> (p,p*,r,u,l)
    # bulk C:   (p,p*,i,j) = (p,p*,l,r) -> (p,p*,u,l,d,r)
    # top C:    (p,p*,i,j) = (p,p*,l,r) -> (p,p*,l,d,r)
    mid_col = [np.transpose(C, [0, 1, 3, 2])[:, :, :, None, :]] \
        + [C[:, :, None, :, None, :]] * (ly - 2) \
        + [C[:, :, :, None, :]]

    # vB_D: (p,p*,ijl) = (p,p*,lru) -> (p,p*,rul)
    # D:    (p,p*,ijkl) -> (p,p*,likj) = (p,p*,uldr)
    # D_vT: (p,p*,ijk) = (p,p*,lrd) -> (p,p*,ldr)
    d_col = [np.transpose(vB_D, [0, 1, 3, 4, 2])] \
        + [np.transpose(D, [0, 1, 5, 2, 4, 3])] * (ly - 2) \
        + [np.transpose(D_vT, [0, 1, 2, 4, 3])]

    right_col = [C_vR[:, :, None, :]] \
        + [C_vR[:, :, None, :, None]] * (ly - 2) \
        + [C_vR[:, :, :, None]]

    tensors = [left_col] \
        + [mid_col] * (x_d - 1) \
        + [d_col] \
        + [mid_col] * (lx - x_d - 2) \
        + [right_col]

    # even if the NnPepo is hermitian, the two separate Pepos could be not
    pepo_hor = Pepo(tensors, OBC, False)

    # VERTICAL
    # rotate tensors clockwise
    # (p,p*,u,l,d,r) -> (p,p*,l,d,r,u)
    _rotate90 = partial(np.transpose, axes=[0, 1, 3, 4, 5, 2])

    # tensor at new location (x,y) was at (-y-1,x) before
    tensors = [[tensors[-y - 1][0] for y in range(ly)]] \
        + [[tensors[-1][x]]
           + [_rotate90(tensors[-y - 1][x]) for y in range(1, ly - 1)]
           + [tensors[0][x]]
           for x in range(1, lx - 1)] \
        + [[tensors[-y - 1][-1] for y in range(ly)]]

    # even if the NnPepo is hermitian, the two separate Pepos could be not
    pepo_vert = Pepo(tensors, OBC, False)

    return pepo_hor, pepo_vert
def build_loglikelihoods(model, **kwargs):
    """Function to compute the loglikelihood contribution of each variable."""
    artifact = compile_to_loglikelihoods(model.graph, model.namespace)
    logpdf = artifact.compiled_fn
    loglikelihoods = jax.partial(logpdf, **kwargs)
    return loglikelihoods
def _potential_energy(params):
    params_constrained = constrain_fn(inv_transforms, params)
    log_joint = jax.partial(log_density, model, model_args, model_kwargs)(params_constrained)[0]
    for name, t in inv_transforms.items():
        log_joint = log_joint + np.sum(
            t.log_abs_det_jacobian(params[name], params_constrained[name]))
    return -log_joint
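# --- Worked sketch of the formula above for a single positive parameter:
# work in unconstrained space u, constrain with scale = exp(u), and add the
# log|d scale / d u| = u correction so the density is right after the change
# of variables. The model (Exponential(1) prior, Normal likelihood) is made up.
import jax.numpy as jnp
from jax.scipy.stats import norm

data = jnp.array([0.3, -1.2, 0.8])

def potential_energy(u):
    scale = jnp.exp(u)                                   # constrain to scale > 0
    log_prior = -scale                                   # Exponential(1) log density
    log_lik = jnp.sum(norm.logpdf(data, loc=0.0, scale=scale))
    log_abs_det_jacobian = u                             # log|d exp(u) / du|
    return -(log_prior + log_lik + log_abs_det_jacobian)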
def _apply_subtrace(master, submodule_params, *vals):
    submodule_params = submodule_params.val
    trace = ApplyTrace(master, jc.cur_sublevel())
    outs = yield map(partial(ApplyTrace.Tracer, trace,
                             ApplyTrace.SubmoduleParamsIterator(submodule_params)),
                     vals), {}
    out_tracers = map(trace.full_raise, outs)
    yield [t.val for t in out_tracers]