def sinkhorn_dual_solver(a, b, cost, epsilon=1e-2, num_iters=1000):
  """Runs the Sinkhorn algorithm in log space for stability at small epsilon.

  Uses a fixed number of iterations to stay compatible with ``jax.jit``. The
  iterations run inside ``jax.lax.fori_loop`` so that tracing does not unroll
  the loop ``num_iters`` times.
  TODO(riannevdberg): implement dynamic num_iterations with jax control flows.

  The returned cost is the dual objective ``<f, a> + <g, b>`` (eq. 4.48 of
  https://arxiv.org/abs/1803.00567). The regularized dual cost of eq. 4.30 is
  obtained by subtracting ``epsilon`` (see the note next to the computation
  below). For potentials f and g obtained after a finite number of
  iterations, this value is a lower bound on the regularized transport cost
  with converged potentials (see proposition 4.8 of the arXiv paper).

  If you need the coupling P^L obtained after a finite number of iterations L,
  put it through a rounding operation (``round_coupling``) so that its
  marginals are exactly ``a`` and ``b``. Otherwise <C, P^L> is not a valid
  approximation of L_C(a, b). After P^L <- round_coupling(P^L), <C, P^L> can
  be used as a valid approximation.

  Args:
    a: np.ndarray<float>[n]: discrete probability distribution. The rows of
      the coupling matrix P must sum up to this vector. If a represents an
      empirical distribution with n samples, all entries should be equal to
      1/n.
    b: np.ndarray<float>[m]: discrete probability distribution. The columns
      of the coupling matrix P must sum up to this vector. If b represents an
      empirical distribution with m samples, all entries should be equal to
      1/m.
    cost: np.ndarray<float>[n, m]: the cost matrix cost[i,j] = c(x_i, y_j)
      where x_i and y_j are samples from distributions a and b respectively.
    epsilon: (float) the level of entropic regularization wanted.
    num_iters: (int32) the number of Sinkhorn iterations.

  Returns:
    A tuple ``(transport_cost, coupling, err)``:
      transport_cost: the dual objective ``<f, a> + <g, b>`` (eq. 4.48).
      coupling: np.ndarray<float>[n, m] equal to
        ``exp((f + g - cost) / epsilon)``. After at least one iteration its
        rows sum to ``a`` (f is updated last), but its columns only
        approximately sum to ``b``. Round it with ``round_coupling`` before
        using it to compute a loss or whenever exact marginals are required.
      err: the maximum relative error in the column marginal,
        ``max_j |sum_i coupling[i, j] - b[j]| / b[j]``.
  """
  loga = jnp.expand_dims(jnp.log(a), axis=1)  # [n, 1]
  logb = jnp.expand_dims(jnp.log(b), axis=0)  # [1, m]
  # Initialize the potentials in the dtype the updates produce, so the
  # fori_loop carry keeps a fixed type across iterations.
  dtype = jnp.result_type(loga, logb, cost, epsilon)
  f = jnp.zeros(loga.shape, dtype)  # epsilon * log_u, shape [n, 1]
  g = jnp.zeros(logb.shape, dtype)  # epsilon * log_v, shape [1, m]

  def _sinkhorn_iteration(_, potentials):
    """Performs one log-space Sinkhorn sweep: update g, then f."""
    f, g = potentials
    # Note: If the update order is g before f, then check the error in b,
    # as this will be the largest error. If using the reverse order, then
    # check the error in a.
    # To carry out the logsumexp in a stable way we use the fact that
    # the matrix f + g - cost has all negative entries. We therefore use this
    # to add and subtract f and g in the respective updates in and outside the
    # logsumexp.
    g = epsilon * logb - epsilon * jax.scipy.special.logsumexp(
        (f + g - cost) / epsilon, axis=0, keepdims=True) + g
    f = epsilon * loga - epsilon * jax.scipy.special.logsumexp(
        (f + g - cost) / epsilon, axis=1, keepdims=True) + f
    return f, g

  f, g = jax.lax.fori_loop(0, num_iters, _sinkhorn_iteration, (f, g))

  # Compute the error in the column marginal (the rows are exact because f is
  # updated last).
  coupling = jnp.exp((f + g - cost) / epsilon)
  b_target = jnp.sum(coupling, axis=0)
  err = jnp.max(jnp.abs(b_target - b) / b, axis=None)

  # Compute the unregularized cost according to eq. 4.48 of the paper.
  # Note that if you want to compute the regularized cost of eq. 4.30
  # this only requires subtracting epsilon, as the double sum
  # < e^f/eps, K e^g/eps > = 1 for updates like in this sinkhorn algorithm.
  # f is [n, 1] and g is [1, m]: drop the singleton axes before weighting so
  # that f * a is not broadcast into an [n, n] outer product.
  transport_cost = (jnp.sum(jnp.squeeze(f, axis=1) * a) +
                    jnp.sum(jnp.squeeze(g, axis=0) * b))
  return transport_cost, coupling, err
def main():
  """Trains a neural-network policy on an LQR problem and compares it to LQR.

  Builds a fixed linear system (A = B = Q = R = I, N = 0), computes the
  closed-form LQR gain K with `control.lqr`, and evaluates that controller
  from 1000 random initial states. It then trains a stax MLP policy with Adam
  on gradients of `policy_integrate_cost` and prints the excess cost per
  episode. Finally it plots the learned policy against the LQR solution:
  per-start-state costs, the training curve, example rollouts and the
  closed-loop vector field.

  NOTE(review): `policy_integrate_cost(dynamics_fn, cost_fn, policy, gamma)`
  is used as a function of (params, x0, total_secs); presumably it returns
  the gamma-discounted cost of a rollout over `total_secs` -- confirm against
  its definition.
  """
  total_secs = 10.0
  gamma = 0.9
  rng = random.PRNGKey(0)

  ### Set up the problem/environment
  # xdot = Ax + Bu
  # u = - Kx
  # cost = xQx + uRu + 2xNu
  A = jp.eye(2)
  B = jp.eye(2)
  Q = jp.eye(2)
  R = jp.eye(2)
  N = jp.zeros((2, 2))

  # Alternative random problem instance (disabled):
  # rngA, rngB, rngQ, rngR, rng = random.split(rng, 5)
  # # A = random.normal(rngA, (2, 2))
  # A = -1 * random_psd(rngA, 2)
  # B = random.normal(rngB, (2, 2))
  # Q = random_psd(rngQ, 2) + 0.1 * jp.eye(2)
  # R = random_psd(rngR, 2) + 0.1 * jp.eye(2)
  # N = jp.zeros((2, 2))

  # x_dim, u_dim = B.shape

  dynamics_fn = lambda x, u: A @ x + B @ u
  cost_fn = lambda x, u: x.T @ Q @ x + u.T @ R @ u + 2 * x.T @ N @ u

  ### Solve the Riccati equation to get the infinite-horizon optimal solution.
  K, _, _ = control.lqr(A, B, Q, R, N)
  K = jp.array(K)

  # Evaluate the optimal linear policy u = -K x from 1000 random start states.
  # These same start states are reused later to evaluate the learned policy.
  t0 = time.time()
  rng_eval, rng = random.split(rng)
  x0_eval = random.normal(rng_eval, (1000, 2))
  opt_all_costs = vmap(lambda x0: policy_integrate_cost(
      dynamics_fn, cost_fn, lambda _, x: -K @ x, gamma)
                       (None, x0, total_secs))(x0_eval)
  opt_cost = jp.mean(opt_all_costs)
  print(f"opt_cost = {opt_cost} in {time.time() - t0}s")

  ### Set up the learned policy model.
  policy_init, policy = stax.serial(
      Dense(64),
      Relu,
      Dense(64),
      Relu,
      Dense(2),
  )
  # policy_init, policy = DenseNoBias(2)

  rng_init_params, rng = random.split(rng)
  _, init_policy_params = policy_init(rng_init_params, (2, ))

  # Jitted (cost, gradient w.r.t. policy params) of a single rollout.
  cost_and_grad = jit(
      value_and_grad(
          policy_integrate_cost(dynamics_fn, cost_fn, policy, gamma)))
  opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)

  def multiple_steps(num_steps):
    """Return a jit-able function that runs `num_steps` iterations."""

    def body(_, stuff):
      # Carry is (rng, last cost, optimizer); each step draws a fresh start
      # state, computes the rollout cost gradient and applies one update.
      rng, _, opt = stuff
      rng_x0, rng = random.split(rng)
      x0 = random.normal(rng_x0, (2, ))
      cost, g = cost_and_grad(opt.value, x0, total_secs)

      # Gradient clipping
      # g = tree_map(lambda x: jp.clip(x, -10, 10), g)
      # g = optimizers.clip_grads(g, 64)

      return rng, cost, opt.update(g)

    return lambda rng, opt: lax.fori_loop(0, num_steps, body,
                                          (rng, jp.zeros(()), opt))

  # Number of optimizer steps fused into each jitted `run` call.
  multi_steps = 1
  run = jit(multiple_steps(multi_steps))

  ### Main optimization loop.
  costs = []
  for i in range(25000):
    t0 = time.time()
    rng, cost, opt = run(rng, opt)
    print(f"Episode {(i + 1) * multi_steps}:")
    # `cost` is the cost of the last step in this call, from a fresh random
    # start state, so this difference is a noisy single-sample estimate.
    print(f" excess cost = {cost - opt_cost}")
    print(f" elapsed = {time.time() - t0}")
    costs.append(float(cost))

  print(f"Opt solution cost from starting point: {opt_cost}")
  # print(f"Gradient at opt solution: {opt_g}")

  # Print the identified and optimal policy. Note that layers multiply
  # on the right instead of the left so we need a transpose.
  print(f"Est solution parameters: {opt.value}")
  print(f"Opt solution parameters: {K.T}")

  # Evaluate the learned policy on the same start states used for LQR.
  est_all_costs = vmap(
      lambda x0: policy_integrate_cost(dynamics_fn, cost_fn, policy, gamma)
      (opt.value, x0, total_secs))(x0_eval)

  ### Scatter plot of learned policy performance vs optimal policy performance.
  plt.figure()
  plt.scatter(est_all_costs, opt_all_costs)
  # y = x reference line: points below it are start states where the learned
  # policy did worse than LQR.
  plt.plot([-100, 100], [-100, 100], color="gray")
  plt.xlim(0, jp.max(est_all_costs))
  plt.ylim(0, jp.max(opt_all_costs))
  plt.xlabel("Learned policy cost")
  plt.ylabel("Optimal cost")
  plt.title("Performance relative to the direct LQR solution")

  ### Plot performance per iteration, incl. average optimal policy performance.
  plt.figure()
  plt.plot(costs)
  plt.axhline(opt_cost, linestyle="--", color="gray")
  plt.yscale("log")
  plt.xlabel("Iteration")
  plt.ylabel(f"Cost (T = {total_secs}s)")
  plt.legend(["Learned policy", "Direct LQR solution"])
  plt.title("ODE control of LQR problem")

  ### Example rollout plots (learned policy vs optimal policy).
  x0 = jp.array([1.0, 2.0])
  framerate = 30
  timesteps = jp.linspace(0, total_secs, num=int(total_secs * framerate))
  est_policy_rollout_states = ode.odeint(
      lambda x, _: dynamics_fn(x, policy(opt.value, x)), y0=x0, t=timesteps)
  est_policy_rollout_controls = vmap(lambda x: policy(opt.value, x))(
      est_policy_rollout_states)

  opt_policy_rollout_states = ode.odeint(lambda x, _: dynamics_fn(x, -K @ x),
                                         y0=x0,
                                         t=timesteps)
  opt_policy_rollout_controls = vmap(lambda x: -K @ x)(
      opt_policy_rollout_states)

  plt.figure()
  plt.plot(est_policy_rollout_states[:, 0],
           est_policy_rollout_states[:, 1],
           marker='.')
  plt.plot(opt_policy_rollout_states[:, 0],
           opt_policy_rollout_states[:, 1],
           marker='.')
  plt.xlabel("x_1")
  plt.ylabel("x_2")
  plt.legend(["Learned policy", "Direct LQR solution"])
  plt.title("Phase space trajectory")

  plt.figure()
  plt.plot(timesteps, jp.sqrt(jp.sum(est_policy_rollout_controls**2, axis=-1)))
  plt.plot(timesteps, jp.sqrt(jp.sum(opt_policy_rollout_controls**2, axis=-1)))
  plt.xlabel("time")
  plt.ylabel("control input (L2 norm)")
  plt.legend(["Learned policy", "Direct LQR solution"])
  plt.title("Policy control over time")

  ### Plot quiver field showing dynamics under learned policy.
  plot_policy_dynamics(dynamics_fn, cost_fn, lambda x: policy(opt.value, x))

  plt.show()
def multinomial(key, p, n, shape=()):
  """Draws multinomial samples by delegating to ``_multinomial``.

  The largest entry of ``n`` is converted to a Python ``int`` and passed to
  ``_multinomial`` as an extra upper-bound argument. Because of that
  ``int(...)`` conversion, ``n`` must be a concrete (non-traced) value.

  Args:
    key: PRNG key forwarded to ``_multinomial``.
    p: probabilities forwarded to ``_multinomial``.
    n: trial counts; its maximum becomes the static upper bound.
    shape: output shape forwarded to ``_multinomial``.
  """
  return _multinomial(key, p, n, int(jnp.max(n)), shape)
def train_epoch(self, evaluate=True):
  """Train one PPO epoch.

  One epoch:
    1. optionally evaluates the policy (every ``self._eval_every_n`` epochs),
    2. collects training trajectories and logs reward/length statistics,
    3. pads the trajectories and recomputes log-probabilities and value
       predictions (only the final time-step's recomputed values are used),
    4. runs up to ``self._n_optimizer_steps`` policy-and-value optimizer
       steps, stopping early once the approximate KL exceeds
       ``1.5 * self._target_kl``,
    5. writes summaries, bumps ``self._epoch``, possibly saves a checkpoint,
       and records timings.

  Args:
    evaluate: whether evaluation may run this epoch. When True and
      ``self._separate_eval`` is False, the training rewards are also written
      as eval reward summaries.

  NOTE(review): assumes ``self._n_optimizer_steps >= 1``. With zero steps,
  ``combined_loss`` and ``approx_kl`` are never assigned and the logging after
  the optimization loop raises ``UnboundLocalError``.
  """
  epoch_start_time = time.time()

  # Evaluate the policy.
  policy_eval_start_time = time.time()
  if evaluate and (self._epoch + 1) % self._eval_every_n == 0:
    # `key` is unused, but the split still advances self._rng.
    self._rng, key = jax_random.split(self._rng, num=2)
    self.evaluate()

  policy_eval_time = ppo.get_time(policy_eval_start_time)

  trajectory_collection_start_time = time.time()
  logging.vlog(1, "PPO epoch [% 6d]: collecting trajectories.", self._epoch)
  # `key` is unused, but the split still advances self._rng.
  self._rng, key = jax_random.split(self._rng)
  trajs, _, timing_info, self._model_state = self.collect_trajectories(
      train=True, temperature=1.0)
  # Keep elements 0, 1, 2 and 4 of each trajectory tuple. Below, element 0's
  # length is used as the trajectory length and element 2 is summed as the
  # reward.
  trajs = [(t[0], t[1], t[2], t[4]) for t in trajs]
  self._should_reset = False
  trajectory_collection_time = ppo.get_time(
      trajectory_collection_start_time)

  logging.vlog(1, "Collecting trajectories took %0.2f msec.",
               trajectory_collection_time)

  rewards = np.array([np.sum(traj[2]) for traj in trajs])
  avg_reward = np.mean(rewards)
  std_reward = np.std(rewards)
  max_reward = np.max(rewards)
  min_reward = np.min(rewards)

  self._train_sw.scalar("train/reward_mean_truncated", avg_reward,
                        step=self._epoch)
  if evaluate and not self._separate_eval:
    metrics = {"raw": {1.0: {"mean": avg_reward, "std": std_reward}}}
    ppo.write_eval_reward_summaries(metrics, self._eval_sw, self._epoch)

  logging.vlog(1, "Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s",
               avg_reward, max_reward, min_reward,
               [float(np.sum(traj[2])) for traj in trajs])

  logging.vlog(
      1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
      float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
      max(len(traj[0]) for traj in trajs),
      min(len(traj[0]) for traj in trajs))
  logging.vlog(2, "Trajectory Lengths: %s", [len(traj[0]) for traj in trajs])

  padding_start_time = time.time()
  (_, reward_mask, padded_observations, padded_actions, padded_rewards,
   padded_infos) = ppo.pad_trajectories(
       trajs, boundary=self._boundary)
  padding_time = ppo.get_time(padding_start_time)

  # Log the same measurement that is recorded in the timing summaries.
  logging.vlog(1, "Padding trajectories took %0.2f msec.", padding_time)
  logging.vlog(1, "Padded Observations' shape [%s]",
               str(padded_observations.shape))
  logging.vlog(1, "Padded Actions' shape [%s]", str(padded_actions.shape))
  logging.vlog(1, "Padded Rewards' shape [%s]", str(padded_rewards.shape))

  if padded_actions.ndim == 2:
    # Add control axis.
    padded_actions = np.expand_dims(padded_actions, axis=-1)

  # Some assertions.
  B, T, C = padded_actions.shape  # pylint: disable=invalid-name
  assert (B, T) == padded_rewards.shape
  assert (B, T) == reward_mask.shape
  assert (B, T + 1) == padded_observations.shape[:2]
  assert ((B, T + 1) + self.train_env.observation_space.shape ==
          padded_observations.shape)

  log_prob_recompute_start_time = time.time()
  assert ("log_prob_actions" in padded_infos and
          "value_predictions" in padded_infos)
  # These are the actual log-probabs and value predictions seen while picking
  # the actions.
  actual_log_probabs_traj = padded_infos["log_prob_actions"]
  actual_value_predictions_traj = padded_infos["value_predictions"]

  assert (B, T, C) == actual_log_probabs_traj.shape[:3]
  A = actual_log_probabs_traj.shape[3]  # pylint: disable=invalid-name
  assert (B, T, 1) == actual_value_predictions_traj.shape

  # TODO(afrozm): log-probabs doesn't need to be (B, T+1, C, A) it can do with
  # (B, T, C, A), so make that change throughout.

  # NOTE: We don't have the log-probabs and value-predictions for the last
  # observation, so we re-calculate for everything, but use the original ones
  # for all but the last time-step.
  self._rng, key = jax_random.split(self._rng)
  log_probabs_traj, value_predictions_traj, self._model_state, _ = (
      self._get_predictions(padded_observations, self._model_state, rng=key))

  assert (B, T + 1, C, A) == log_probabs_traj.shape
  assert (B, T + 1, 1) == value_predictions_traj.shape

  # Concatenate the last time-step's log-probabs and value predictions to the
  # actual log-probabs and value predictions and use those going forward.
  log_probabs_traj = np.concatenate(
      (actual_log_probabs_traj, log_probabs_traj[:, -1:, :]), axis=1)
  value_predictions_traj = np.concatenate(
      (actual_value_predictions_traj, value_predictions_traj[:, -1:, :]),
      axis=1)

  log_prob_recompute_time = ppo.get_time(log_prob_recompute_start_time)

  # Compute value and ppo losses.
  self._rng, key1 = jax_random.split(self._rng, num=2)
  logging.vlog(2, "Starting to compute P&V loss.")
  loss_compute_start_time = time.time()
  (cur_combined_loss, component_losses, summaries, self._model_state) = (
      ppo.combined_loss(
          self._policy_and_value_net_params,
          log_probabs_traj,
          value_predictions_traj,
          self._policy_and_value_net_apply,
          padded_observations,
          padded_actions,
          padded_rewards,
          reward_mask,
          gamma=self._gamma,
          lambda_=self._lambda_,
          c1=self._c1,
          c2=self._c2,
          state=self._model_state,
          rng=key1))
  loss_compute_time = ppo.get_time(loss_compute_start_time)
  (cur_ppo_loss, cur_value_loss, cur_entropy_bonus) = component_losses
  logging.vlog(
      1,
      "Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.",
      cur_combined_loss, cur_ppo_loss, cur_value_loss, cur_entropy_bonus,
      loss_compute_time)

  self._rng, key1 = jax_random.split(self._rng, num=2)
  logging.vlog(1, "Policy and Value Optimization")
  optimization_start_time = time.time()
  keys = jax_random.split(key1, num=self._n_optimizer_steps)
  opt_step = 0
  for key in keys:
    # k1: optimizer step, k2: KL estimate, k3: loss logging.
    k1, k2, k3 = jax_random.split(key, num=3)
    t = time.time()
    # Update the optimizer state.
    self._policy_and_value_opt_state, self._model_state = (
        ppo.policy_and_value_opt_step(
            # We pass the optimizer slots between PPO epochs, so we need to
            # pass the optimization step as well, so for example the
            # bias-correction in Adam is calculated properly. Alternatively we
            # could reset the slots and the step in every PPO epoch, but then
            # the moment estimates in adaptive optimizers would never have
            # enough time to warm up. So it makes sense to reuse the slots,
            # even though we're optimizing a different loss in every new
            # epoch.
            self._total_opt_step,
            self._policy_and_value_opt_state,
            self._policy_and_value_opt_update,
            self._policy_and_value_get_params,
            self._policy_and_value_net_apply,
            log_probabs_traj,
            value_predictions_traj,
            padded_observations,
            padded_actions,
            padded_rewards,
            reward_mask,
            c1=self._c1,
            c2=self._c2,
            gamma=self._gamma,
            lambda_=self._lambda_,
            state=self._model_state,
            rng=k1))
    opt_step += 1
    self._total_opt_step += 1

    # Compute the approx KL for early stopping.
    (log_probab_actions_new, _), self._model_state = (
        self._policy_and_value_net_apply(
            padded_observations,
            self._policy_and_value_net_params,
            self._model_state,
            rng=k2))

    approx_kl = ppo.approximate_kl(log_probab_actions_new, log_probabs_traj,
                                   reward_mask)

    early_stopping = approx_kl > 1.5 * self._target_kl
    if early_stopping:
      logging.vlog(
          1, "Early stopping policy and value optimization after %d steps, "
          "with approx_kl: %0.2f", opt_step, approx_kl)
      # We don't return right-away, we want the below to execute on the last
      # iteration.

    t2 = time.time()
    # The loss is always computed on the last iteration (the final step or an
    # early stop), so `combined_loss` is set before it is used after the loop.
    if (opt_step % self._print_every_optimizer_steps == 0 or
        opt_step == self._n_optimizer_steps or early_stopping):
      # Compute and log the loss.
      (combined_loss, component_losses, _, self._model_state) = (
          ppo.combined_loss(
              self._policy_and_value_net_params,
              log_probabs_traj,
              value_predictions_traj,
              self._policy_and_value_net_apply,
              padded_observations,
              padded_actions,
              padded_rewards,
              reward_mask,
              gamma=self._gamma,
              lambda_=self._lambda_,
              c1=self._c1,
              c2=self._c2,
              state=self._model_state,
              rng=k3))
      logging.vlog(
          1, "One Policy and Value grad desc took: %0.2f msec",
          ppo.get_time(t, t2))
      (ppo_loss, value_loss, entropy_bonus) = component_losses
      # Label order matches the argument order (ppo, value, entropy_bonus).
      logging.vlog(
          1, "Combined Loss(ppo, value, entropy_bonus) [%10.2f] ->"
          " [%10.2f(%10.2f,%10.2f,%10.2f)]", cur_combined_loss, combined_loss,
          ppo_loss, value_loss, entropy_bonus)

    if early_stopping:
      break

  optimization_time = ppo.get_time(optimization_start_time)

  logging.vlog(
      1, "Total Combined Loss reduction [%0.2f]%%",
      (100 * (cur_combined_loss - combined_loss) /
       np.abs(cur_combined_loss)))

  summaries.update({
      "n_optimizer_steps": opt_step,
      "approx_kl": approx_kl,
  })
  for (name, value) in summaries.items():
    self._train_sw.scalar("train/{}".format(name), value, step=self._epoch)

  logging.info(
      "PPO epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined"
      " Loss(ppo, value, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]", self._epoch,
      min_reward, max_reward, avg_reward, combined_loss, ppo_loss, value_loss,
      entropy_bonus)

  # Bump the epoch counter before saving a checkpoint, so that a call to
  # save() after the training loop is a no-op if a checkpoint was saved last
  # epoch - otherwise it would bump the epoch counter on the checkpoint.
  last_epoch = self._epoch
  self._epoch += 1

  # Save parameters every time we see the end of at least a fraction of batch
  # number of trajectories that are done (not completed -- completed includes
  # truncated and done).
  # Also don't save too frequently, enforce a minimum gap.
  policy_save_start_time = time.time()
  # TODO(afrozm): Refactor to trax.save_state.
  if (self._n_trajectories_done >=
      self._done_frac_for_policy_save * self.train_env.batch_size and
      self._epoch % self._save_every_n == 0) or self._async_mode:
    self.save()
  policy_save_time = ppo.get_time(policy_save_start_time)

  epoch_time = ppo.get_time(epoch_start_time)

  timing_dict = {
      "epoch": epoch_time,
      "policy_eval": policy_eval_time,
      "trajectory_collection": trajectory_collection_time,
      "padding": padding_time,
      "log_prob_recompute": log_prob_recompute_time,
      "loss_compute": loss_compute_time,
      "optimization": optimization_time,
      "policy_save": policy_save_time,
  }

  timing_dict.update(timing_info)

  for k, v in timing_dict.items():
    self._timing_sw.scalar("timing/%s" % k, v, step=last_epoch)

  max_key_len = max(len(k) for k in timing_dict)
  timing_info_list = [
      "%s : % 10.2f" % (k.rjust(max_key_len + 1), v)
      for k, v in sorted(timing_dict.items())
  ]
  logging.info("PPO epoch [% 6d], Timings: \n%s", last_epoch,
               "\n".join(timing_info_list))

  # Flush summary writers once in a while.
  if self._epoch % 1000 == 0:
    self.flush_summaries()
def _estimate_cell_capacity(R, box_size, cell_size, buffer_size_multiplier=1.1):
  """Estimates a fixed per-cell buffer size from example positions.

  Takes the occupancy of the fullest cell reported by ``count_cell_filling``
  and scales it by ``buffer_size_multiplier``. Presumably the multiplier
  leaves headroom for particles moving between cells.

  Args:
    R: example positions; presumably of shape
      [particle_count, spatial_dimension] -- TODO confirm.
    box_size: system size, forwarded to ``count_cell_filling``.
    cell_size: cell side length, forwarded to ``count_cell_filling``.
    buffer_size_multiplier: factor applied to the observed maximum occupancy.
      It defaults to 1.1, so callers that pass only the first three arguments
      keep working.

  Returns:
    The estimated capacity as a Python ``int`` (the scaled value truncated
    toward zero).
  """
  # TODO(schsam): We might want to do something more sophisticated here.
  cell_capacity = np.max(count_cell_filling(R, box_size, cell_size))
  return int(cell_capacity * buffer_size_multiplier)
def sample(self, rng_key, sample_shape=()):
  """Draws binomial samples of shape ``sample_shape + batch_shape + event_shape``.

  The maximum of ``self.n`` is pulled to the host with ``.item()`` and handed
  to ``_random_binomial`` as a scalar, so ``self.n`` must be concrete.
  """
  max_trials = jnp.max(self.n).item()
  full_shape = sample_shape + self.batch_shape + self.event_shape
  return _random_binomial(rng_key, self.p, self.n, max_trials, full_shape)
def tree_max(tree):
  """Returns the largest element across all leaf arrays of a pytree.

  Each leaf is reduced to its own maximum first, then the overall maximum of
  those per-leaf values is taken.
  """
  leaves, _ = tree_flatten(tree)
  per_leaf_max = [np.max(leaf) for leaf in leaves]
  return np.max(per_leaf_max)
def pairwise(fn, metric, species=None, reduce_axis=None, keepdims=False,
             **kwargs):
  """Promotes a function that acts on a pair to one that acts on a set.

  Args:
    fn: A function that takes an ndarray of pairwise distances or
      displacements of shape [n, m] or [n, m, d_in] respectively as well as
      kwargs specifying parameters for the function. fn returns an ndarray of
      evaluations of shape [n, m, d_out].
    metric: A function that takes two ndarray of positions of shape
      [n, spatial_dimension] and [m, spatial_dimension] respectively and
      returns an ndarray of distances or displacements of shape [n, m, d_in].
      The metric can optionally take a floating point time as a third
      argument.
    species: A list of species for the different particles. This should
      either be None (in which case it is assumed that all the particles have
      the same species), an integer ndarray of shape [n] with species data, or
      Dynamic, in which case the species data will be specified dynamically.
      Note that dynamic species specification is less efficient, because we
      cannot specialize shape information.
    reduce_axis: A list of axes to reduce over. This is passed as ``axis`` to
      the sum and so the np.sum convention is used. Not supported when species
      is a static ndarray.
    keepdims: A boolean specifying whether the empty dimensions should be kept
      upon reduction. Uses the np.sum convention. Must be False when species
      is a static ndarray.
    kwargs: Arguments providing parameters to the mapped function. In cases
      where no species information is provided these should be either 1) a
      scalar, 2) an ndarray of shape [n], 3) an ndarray of shape [n, n]. If
      species information is provided then the parameters should be specified
      as either 1) a scalar or 2) an ndarray of shape
      [max_species, max_species].

  Returns:
    A function fn_mapped.

    If species is None or statically specified then fn_mapped takes as
    arguments an ndarray of positions of shape [n, spatial_dimension].

    If species is Dynamic then fn_mapped takes as input an ndarray of shape
    [n, spatial_dimension], an integer ndarray of species of shape [n], and an
    integer species_count giving the number of species (the largest species
    index plus one).

    The mapped function can also optionally take keyword arguments that get
    threaded through the metric.

  Raises:
    ValueError: if species is not None, an ndarray, or Dynamic, or if
      reduce_axis / keepdims are requested together with static species.
  """
  if species is None:
    kwargs = _kwargs_to_parameters(species, **kwargs)

    def fn_mapped(R, **dynamic_kwargs):
      dr = metric(R, R, **dynamic_kwargs)
      # NOTE(schsam): Currently we place a diagonal mask no matter what function
      # we are mapping. Should this be an option?
      # Each unordered pair appears twice in the [n, n] matrix, hence the 0.5.
      return _high_precision_sum(_diagonal_mask(fn(dr, **kwargs)),
                                 axis=reduce_axis,
                                 keepdims=keepdims) * f32(0.5)
  elif isinstance(species, np.ndarray):
    _check_species_dtype(species)
    # Largest species index present (the loops below run to it inclusive).
    species_count = int(np.max(species))
    if reduce_axis is not None or keepdims:
      # TODO(schsam): Support reduce_axis with static species.
      raise ValueError(
          'reduce_axis and keepdims are not supported with static species '
          '(got reduce_axis={}, keepdims={}).'.format(reduce_axis, keepdims))

    def fn_mapped(R, **dynamic_kwargs):
      U = f32(0.0)
      # Visit each unordered species pair (i <= j) exactly once.
      for i in range(species_count + 1):
        for j in range(i, species_count + 1):
          s_kwargs = _kwargs_to_parameters((i, j), **kwargs)
          Ra = R[species == i]
          Rb = R[species == j]
          dr = metric(Ra, Rb, **dynamic_kwargs)
          if j == i:
            # Same-species block: pairs counted twice, so halve the sum.
            dU = _high_precision_sum(_diagonal_mask(fn(dr, **s_kwargs)))
            U = U + f32(0.5) * dU
          else:
            dU = _high_precision_sum(fn(dr, **s_kwargs))
            U = U + dU
      return U
  elif species is quantity.Dynamic:
    def fn_mapped(R, species, species_count, **dynamic_kwargs):
      _check_species_dtype(species)
      U = f32(0.0)
      N = R.shape[0]
      dr = metric(R, R, **dynamic_kwargs)
      # Visit every ordered species pair; each unordered pair is therefore
      # counted twice, which the final division by 2 compensates for.
      for i in range(species_count):
        for j in range(species_count):
          s_kwargs = _kwargs_to_parameters((i, j), **kwargs)
          mask_a = np.array(np.reshape(species == i, (N, )), dtype=R.dtype)
          mask_b = np.array(np.reshape(species == j, (N, )), dtype=R.dtype)
          mask = mask_a[:, np.newaxis] * mask_b[np.newaxis, :]
          if i == j:
            mask = mask * _diagonal_mask(mask)
          dU = mask * fn(dr, **s_kwargs)
          U = U + _high_precision_sum(dU, axis=reduce_axis, keepdims=keepdims)
      return U / f32(2.0)
  else:
    raise ValueError(
        'Species must be None, an ndarray, or Dynamic. Found {}.'.format(
            species))
  return fn_mapped
def grid(fn,
         box_size,
         minimum_cell_size,
         cell_capacity_or_example_positions,
         species=None,
         separate_build_and_apply=False,
         cells_per_iter=-1,
         buffer_size_multiplier=1.1):
  r"""Returns a function that evaluates a function sparsely on a grid.

  Suppose f is a function of positions, f:R^{N\times D}\to R^{N\times M} such
  that f does not depend on particle pairs that are separated by at least a
  cutoff \sigma. It is efficient to compute f by evaluating it separately over
  cells of a spatial partition of the system into a grid whose side-length is
  given by \sigma. This function does this spatially partitioned evaluation
  for a wide range of functions fn.

  This is accomplished by composing two functions. First a function
  build_cells creates the spatial partition, then a function compute applies
  fn to each cell using JAX's autobatching and then copies the result back to
  an ndarray of shape [particle_count, output_dimension]. We also support the
  option to return the two functions separately.

  The grid is constructed so that each cell contains not only those particles
  in the given cell, but also those particles in a "halo" around the cell.
  Currently, we let the halo size be the same as the grid size so that each
  grid cell contains particles from neighboring cells. This is for easy grid
  construction, but future optimization might be to allow for different halo
  sizes.

  Since XLA requires that shapes be statically specified, we allocate a fixed
  sized buffer for each cell. The size of this buffer can either be specified
  manually or it can be estimated automatically from a set of positions.
  Note, if the structure of a system is changing significantly during the
  course of dynamics (e.g. during minimization) it is probably worth adjusting
  the buffer size over the course of the dynamics.

  Currently, the function must have a signature fn(R, species,
  species_count). It would be nice in the future to support functions with
  more general signature. TODO.

  This partitioning will likely form the groundwork for parallelizing
  simulations over different accelerators.

  Args:
    fn: A function that we would like to compute over the partition. Should
      take arguments R (an ndarray of floating point positions of shape
      [particle_count, spatial_dimension]), species (an ndarray of integer
      species of shape [particle_count]) and species_count (an integer
      specifying the number of species). fn should return an ndarray of shape
      [particle_count, output_dimension] (where output_dimension can be 1,
      but should be present).
    box_size: A float or an ndarray of shape [spatial_dimension] specifying
      the size of the system. Note, this code is written for the case where
      the boundaries are periodic. If this is not the case, then the current
      code will be slightly less efficient.
    minimum_cell_size: A float; passed to ``_cell_dimensions``, which derives
      the actual cell size and the number of cells per side from it.
    cell_capacity_or_example_positions: Either an integer specifying the
      number of particles that can be stored in each cell or an ndarray of
      positions of shape [particle_count, spatial_dimension] that is used to
      estimate the cell_capacity.
    species: Either an ndarray of integers of shape [particle_count] with the
      species type of each particle or None, in which case it is assumed that
      all particles have the same species.
    separate_build_and_apply: A boolean specifying whether or not we would
      like to compose the build_cells and compute functions.
    cells_per_iter: Depending on the size of the system, it might be necessary
      to apply fn over batches of cells. cells_per_iter is an integer
      specifying the number of cells per batch. If cells_per_iter is -1 then
      all cells are computed together. Must be -1 or a positive integer.
    buffer_size_multiplier: Factor applied to the maximum observed cell
      occupancy when the cell capacity is estimated from example positions.
      Ignored when an integer capacity is given.

  Returns:
    If separate_build_and_apply is False then returns a single function
    fn_mapped that takes an ndarray of shape [particle_count,
    spatial_dimension] as well as optional kwargs and returns an ndarray of
    shape [particle_count, output_dimension].

    If separate_build_and_apply is True then returns two functions. A
    build_cells function that takes an ndarray of positions of shape
    [particle_count, spatial_dimension] and returns a Grid. It also returns a
    function compute that takes a Grid and computes fn over the grid.

  Raises:
    ValueError: if cell_capacity_or_example_positions is neither an integer
      nor a set of positions, or if cells_per_iter is 0 or smaller than -1.
  """
  if cells_per_iter == 0 or cells_per_iter < -1:
    # 0 would silently compute nothing (all-zero output); values below -1
    # would produce negative slice sizes.
    raise ValueError(
        'cells_per_iter must be -1 or a positive integer. Found {}.'.format(
            cells_per_iter))

  fn = vmap(fn, (0, 0), 0)

  if species is None:
    species_count = 1
  else:
    species_count = int(np.max(species) + 1)

  cell_capacity = cell_capacity_or_example_positions
  if _is_variable_compatible_with_positions(cell_capacity):
    cell_capacity = _estimate_cell_capacity(cell_capacity, box_size,
                                            minimum_cell_size,
                                            buffer_size_multiplier)
  elif not isinstance(cell_capacity, int):
    msg = (
        'cell_capacity_or_example_positions must either be an integer '
        'specifying the cell capacity or a set of positions that will be used '
        'to estimate a cell capacity. Found {}.'.format(type(cell_capacity)))
    raise ValueError(msg)

  def build_cells(R):
    """Scatters particles (plus their halo copies) into fixed-size cells."""
    N = R.shape[0]
    dim = R.shape[1]

    neighborhood_tile_count = 3**dim

    _, cell_size, cells_per_side, cell_count = \
        _cell_dimensions(dim, box_size, minimum_cell_size)

    if species is None:
      _species = np.zeros((N, ), dtype=i32)
    else:
      _species = species

    hash_multipliers = _compute_hash_constants(dim, cells_per_side)

    # Create grid data.
    particle_id = lax.iota(np.int64, N)
    # NOTE(schsam): We use the convention that particles that come from the
    # center cell have their true id copied, whereas particles that come from
    # the halo have an id = N. Then when we copy data back from the grid,
    # we copy it to an array of shape [N + 1, output_dimension] and then
    # truncate it to an array of shape [N, output_dimension] which ignores the
    # halo particles.
    mask_id = np.ones((N, ), np.int64) * N
    cell_R = np.zeros((cell_count * cell_capacity, dim), dtype=R.dtype)
    # NOTE(schsam): empty_species_index is just supposed to be large enough that
    # we will never run into it. However, there might be a more robust way to
    # do this.
    empty_species_index = i16(1000)
    cell_species = empty_species_index * np.ones(
        (cell_count * cell_capacity, 1), dtype=_species.dtype)
    cell_id = N * np.ones((cell_count * cell_capacity, 1), dtype=i32)

    indices = np.array(R / cell_size, dtype=i32)

    # Create a copy of particle data for each neighboring cell shifting the hash
    # appropriately.
    # TODO(schsam): Replace with np.tile() when it gets implemented.
    tiled_R = R
    tiled_species = _species
    for _ in range(neighborhood_tile_count - 1):
      tiled_R = np.concatenate((tiled_R, R), axis=0)
      tiled_species = np.concatenate((tiled_species, _species), axis=0)

    tiled_hash = np.array([], dtype=i32)
    tiled_id = np.array([], dtype=i32)

    for dindex in _neighboring_cells(dim):
      tiled_indices = np.mod(indices + dindex, cells_per_side)
      tiled_hash = np.concatenate(
          (tiled_hash, np.sum(tiled_indices * hash_multipliers, axis=1)),
          axis=0)

      if np.all(dindex == 0):
        tiled_id = np.concatenate((tiled_id, particle_id), axis=0)
      else:
        tiled_id = np.concatenate((tiled_id, mask_id), axis=0)

    # Copy the particle data into the grid. Here we use a trick to allow us to
    # copy into all cells simultaneously using a single lax.scatter call. To do
    # this we first sort particles by their cell hash. We then assign each
    # particle to have a cell id = hash * cell_capacity + grid_id where grid_id
    # is a flat list that repeats 0, .., cell_capacity. So long as there are
    # fewer than cell_capacity particles per cell, each particle is guaranteed
    # to get a cell id that is unique.
    sort_map = np.argsort(tiled_hash)

    sorted_R = tiled_R[sort_map]
    sorted_species = tiled_species[sort_map]
    sorted_hash = tiled_hash[sort_map]
    sorted_id = tiled_id[sort_map]

    tiled_size = neighborhood_tile_count * N
    sorted_cell_id = np.mod(lax.iota(np.int64, tiled_size), cell_capacity)
    sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

    def copy_values_to_cell(cell_value, value, ids):
      """Scatters rows of `value` into `cell_value` at flat slot `ids`."""
      scatter_indices = np.reshape(ids, (tiled_size, 1))
      dnums = lax.ScatterDimensionNumbers(
          update_window_dims=tuple([1]),
          inserted_window_dims=tuple([0]),
          scatter_dims_to_operand_dims=tuple([0]),
      )
      return lax.scatter(cell_value, scatter_indices, value, dnums)

    cell_R = copy_values_to_cell(cell_R, sorted_R, sorted_cell_id)
    sorted_species = np.reshape(sorted_species, (tiled_size, 1))
    cell_species = copy_values_to_cell(cell_species, sorted_species,
                                       sorted_cell_id)
    sorted_id = np.reshape(sorted_id, (tiled_size, 1))
    cell_id = copy_values_to_cell(cell_id, sorted_id, sorted_cell_id)

    cell_R = np.reshape(cell_R, (cell_count, cell_capacity, dim))
    cell_species = np.reshape(cell_species, (cell_count, cell_capacity))
    cell_id = np.reshape(cell_id, (cell_count, cell_capacity))

    return Grid(N, dim, cell_count, cell_R, cell_species, cell_id)

  def compute(cell_data, **kwargs):
    """Applies the vmapped fn per cell and gathers outputs back per particle."""
    N, dim, cell_count, cell_R, cell_species, cell_id, = cell_data

    cell_output_shape = _grid_trace_shape(fn,
                                          cell_R,
                                          cell_species,
                                          species_count=species_count)
    output_dimension = cell_output_shape[-1]

    _cells_per_iter = cells_per_iter
    if cells_per_iter == -1:
      _cells_per_iter = cell_count

    def copy_values_from_cell(value, cell_value, cell_id):
      """Scatters per-slot outputs into `value`; halo slots land in row N."""
      scatter_indices = np.reshape(cell_id,
                                   (_cells_per_iter * cell_capacity, 1))
      cell_value = np.reshape(
          cell_value, (_cells_per_iter * cell_capacity, output_dimension))
      dnums = lax.ScatterDimensionNumbers(
          update_window_dims=tuple([1]),
          inserted_window_dims=tuple([0]),
          scatter_dims_to_operand_dims=tuple([0]),
      )
      return lax.scatter(value, scatter_indices, cell_value, dnums)

    def compute_cell_block(start, value):
      """Evaluates fn on the `start`-th batch of cells."""
      start = _cells_per_iter * start
      compute_R = lax.dynamic_slice(
          cell_R, (start, 0, 0), (_cells_per_iter, cell_capacity, dim))
      compute_species = lax.dynamic_slice(
          cell_species, (start, 0), (_cells_per_iter, cell_capacity))
      compute_id = lax.dynamic_slice(cell_id, (start, 0),
                                     (_cells_per_iter, cell_capacity))
      cell_value = fn(compute_R,
                      compute_species,
                      species_count=species_count,
                      **kwargs)
      return copy_values_from_cell(value, cell_value, compute_id)

    # One extra row (index N) absorbs the writes from halo copies and is
    # dropped by the final [:N].
    value = np.zeros((N + 1, output_dimension), dtype=cell_R.dtype)
    if cells_per_iter > 0:
      return lax.fori_loop(0, int(math.ceil(float(cell_count) / cells_per_iter)),
                           compute_cell_block, value)[:N]
    else:
      return compute_cell_block(0, value)[:N]

  if separate_build_and_apply:
    return build_cells, compute
  else:
    return lambda R, **kwargs: compute(build_cells(R), **kwargs)
## NONLINEAR LEAST-SQUARES CLASS *****************************************************************************
nlls = NllsClass(xi, L, IC, tol=tol, maxIter=maxIter, timer=True)

# Load the test initial conditions; `with` ensures the file handle is closed.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open('data/EOL_IC.pickle', 'rb') as ic_file:
  data = pickle.load(ic_file)

# Per-case results: max absolute residual, iteration count and run time.
sol = {
    'loss': onp.zeros((data['R0'].shape[0])),
    'it': onp.zeros((data['R0'].shape[0])),
    'time': onp.zeros((data['R0'].shape[0]))
}

## RUN TEST *************************************************************************************************
for i in tqdm.trange(data['R0'].shape[0]):
  R0 = data['R0'][i, :]
  V0 = data['V0'][i, :]

  ## scale initial conditions
  pscale = np.max(np.abs(R0))
  tscale = pscale / np.max(np.abs(V0))

  xi = TFCDictRobust({'xis': onp.zeros((Hs(z).shape[1], 3)),
                      'xic': onp.array([0.5 * (V0 - R0), -0.5 * (V0 + R0)]),
                      'b': np.sqrt(10.) * onp.ones(1)})

  # IC is mutated in place; nlls was constructed with this same dict.
  IC['R0'] = R0 / pscale
  IC['V0'] = V0 * tscale / pscale
  IC['ag'] = np.array([0., 0., -1.62]) * tscale**2 / pscale

  # Use `run_time` rather than `time` so the `time` module name used
  # elsewhere in this file is not rebound at module level.
  xi, it, run_time = nlls.run(xi, IC)

  sol['loss'][i] = np.max(np.abs(L(xi, IC)))
  sol['it'][i] = it
  sol['time'][i] = run_time
def loss_fn(
    output_logits,
    targets,
    valid_mask,
    num_nodes,
    captured,
    negative_example_weight=1,
    focal_loss_gamma=0.0,
):
    """Compute loss and single-batch metrics for some outputs.

    Args:
      output_logits: Binary logits produced by the model.
      targets: Model targets, treated as a mask (nonzero = positive). The
        precision/recall computation casts them to bool, so non-bool 0/1
        arrays also work there.
      valid_mask: Mask determining which outputs are valid.
      num_nodes: How many nodes there are in each example.
      captured: Ignored
      negative_example_weight: Weight to assign to a negative example when
        computing the loss. Positive examples always get weight 1.
      focal_loss_gamma: Focusing parameter for the focal loss, as described in
        Lin et al. (2018). If zero, uses standard cross-entropy loss.

    Returns:
      Tuple (loss, metrics_dict). Ratio metrics (precision, recall, F1 and the
      per-target averages) are NaN when their denominator is zero.
    """
    del captured
    num_targets = jnp.count_nonzero(targets)
    # Boolean view of the targets for the `&` below: `&` is undefined for
    # float arrays and is bitwise (not logical) for integer arrays.
    target_mask = jnp.asarray(targets, dtype=bool)

    # Compute cross entropy.
    unmasked_nll = model_util.binary_logit_cross_entropy(
        output_logits, targets)

    if focal_loss_gamma:
        # (1-p_correct)**gamma = (-(p-1))**gamma = (-expm1(log(p)))**gamma
        focus_term = jnp.power(-jnp.expm1(-unmasked_nll), focal_loss_gamma)
        unmasked_nll = unmasked_nll * focus_term

    # Mask the results so that they only count nodes that exist.
    masked_nll = unmasked_nll * valid_mask

    # Primary loss: Sum of nll over all nodes. We use sum because most of the
    # edges are easy negatives.
    positive_nll = jnp.sum(
        jnp.where(targets, masked_nll, jnp.zeros_like(masked_nll)))
    negative_nll = jnp.sum(
        jnp.where(targets, jnp.zeros_like(masked_nll), masked_nll))
    reweighted_nll = positive_nll + negative_example_weight * negative_nll
    binary_nll = jnp.sum(reweighted_nll)

    # Compute additional metrics to track learning progress.
    # Average NLL of target edges:
    avg_nll_per_target = positive_nll / num_targets
    # Average NLL of non-target edges:
    num_non_targets = num_nodes**2 - num_targets
    avg_nll_per_non_target = negative_nll / num_non_targets
    # Max error for any edge prediction:
    worst_nll = jnp.max(masked_nll)

    loss = binary_nll

    # Ratio of positive to negative targets. If this is equal to
    # negative_example_weight, the positive and negative examples will have the
    # same total weight.
    positive_per_negative = num_targets / num_non_targets

    # Precision and recall at 0.1 threshold
    thresholded_preds = output_logits > jax.scipy.special.logit(0.1)
    count_target_pred = jnp.count_nonzero(thresholded_preds & target_mask)
    count_pred = jnp.count_nonzero(thresholded_preds & valid_mask.astype(bool))
    precision = count_target_pred / count_pred
    recall = count_target_pred / num_targets

    return loss, {
        "avg_per_target":
            avg_nll_per_target,
        "avg_per_non_target":
            avg_nll_per_non_target,
        "worst":
            worst_nll,
        "positive_per_negative":
            positive_per_negative,
        "effective_p_model_given_target":
            jnp.exp(-avg_nll_per_target),
        "effective_p_model_given_nontarget":
            1 - jnp.exp(-avg_nll_per_non_target),
        "batch_clf_thresh_at_0.1/precision":
            precision,
        "batch_clf_thresh_at_0.1/recall":
            recall,
        "batch_clf_thresh_at_0.1/f1":
            2 * (precision * recall) / (precision + recall),
    }
def pytree_func(params, x):
    """Affine map ``x @ params["w"] + params["b"]`` reduced by max over axis 0.

    The second positional argument of ``jnp.max`` is the reduction axis, so
    this collapses the leading axis. It is not an elementwise ReLU, which
    would be ``jnp.maximum(..., 0)``.
    """
    affine = jnp.matmul(x, params["w"]) + params["b"]
    return jnp.max(affine, axis=0)
def maxlike(y, x, model, params0, batch_size=4092, epochs=3, learning_rate=0.5,
            step=1e-4, output=None):
    """Fit parameters by minibatch gradient descent and estimate their covariance.

    Args:
      y, x: data handed to ``DataLoader(y, x, batch_size)``; ``len(y)`` is used
        as the number of observations N.
      model: callable ``model(params, y_bat, x_bat)`` returning a scalar loss.
        NOTE(review): ``inv(hess) / N`` below is the usual MLE covariance only
        if this is the mean negative log-likelihood per observation; confirm.
      params0: initial parameter vector of length K (copied, not modified).
      batch_size: minibatch size passed to ``DataLoader``.
      epochs: number of passes over the data.
      learning_rate: gradient-descent step multiplier.
      step: finite-difference step size, used only if the autodiff Hessian
        raises.
      output: if ``'beta'``, skip the covariance and return ``(params, None)``.

    Returns:
      ``(params, sigma)``: fitted parameters and ``inv(hess) / N``, where
      ``hess`` is the batch-averaged Hessian of ``model`` at ``params``.
    """
    # Derivative functions, jit-compiled once up front.
    fg_fun = jax.jit(jax.value_and_grad(model))
    g_fun = jax.jit(jax.grad(model))
    h_fun = jax.jit(jax.hessian(model))

    N, K = len(y), len(params0)
    data = DataLoader(y, x, batch_size)
    params = params0.copy()

    for ep in range(epochs):
        agg_loss, agg_batch = 0.0, 0
        for y_bat, x_bat in data:
            loss, grad = fg_fun(params, y_bat, x_bat)
            # Named `update` (not `step`) so the finite-difference `step`
            # argument used below is not clobbered by an array.
            update = -learning_rate * grad
            params += update
            agg_loss += loss
            agg_batch += 1
        avg_loss = agg_loss / agg_batch
        print(f'{ep:3}: loss = {avg_loss}')

    if output == 'beta':
        return params.copy(), None

    try:
        # Average the per-batch Hessians (the weight assumes full batches).
        hess = np.zeros((K, K))
        for y_bat, x_bat in data:
            hess += h_fun(params, y_bat, x_bat)
        hess *= batch_size / N
    except Exception as e:
        # Deliberate best-effort fallback: report why autodiff failed, then use
        # forward differences of the gradient to build the Hessian row by row.
        print(e)
        print('Falling back to finite difference for hessian')
        hess_rows = [np.zeros(K) for _ in range(K)]
        perturbations = step * np.eye(K)
        for y_bat, x_bat in data:
            # Kept 1-D so each row difference has shape (K,), like hess_rows[i].
            g0_batch = g_fun(params, y_bat, x_bat)
            for i in range(K):
                params1 = params + perturbations[i, :]
                hess_rows[i] += g_fun(params1, y_bat, x_bat) - g0_batch
        hess = np.vstack(hess_rows) * (batch_size / N) / step

    sigma = np.linalg.inv(hess) / N
    return params.copy(), sigma.copy()
def makeInputs(OMap, r_cent, contrasts, X, Y, gridsizedeg=4, gridperdeg=5, AngWidth=32):
    '''Build the input drive for each stimulus condition.

    Conditions:
      * all contrasts at the largest radius (contrast effect),
      * all radii at the highest contrast (surround suppression),
      * highest contrast and radius with a Gabor patch (Ray-Maunsell effect);
        only added when more than one contrast is given.

    Args:
      OMap: orientation preference across the cortex (2-D grid).
      r_cent: array of stimulus radii.
      contrasts: array of stimulus contrasts.
      X, Y: matrices of X and Y positions in degrees, same grid as OMap.
      gridsizedeg, gridperdeg: unused; kept for call compatibility.
      AngWidth: width of the Gaussian orientation tuning, in degrees.

    Returns:
      StimConds.T: array of shape (Ne, n_conditions), Ne = number of grid points.
      stimulus_condition: (2, n_conditions); row 0 = contrast, row 1 = radius
        (the Gabor column uses max(rads)).
      InSr: (len(rads), Ne) grating spatial input before contrast/orientation
        scaling.
    '''
    # One fewer max-radius entry so the (max contrast, max radius) pair is not
    # duplicated; NOTE(review): that assumes the sort order of contrasts/r_cent.
    rads = np.hstack((np.max(r_cent) * np.ones(len(contrasts) - 1), r_cent))
    # One contrast per grating condition plus one extra for the Gabor condition.
    Contrasts = np.hstack((contrasts, np.ones(len(r_cent)) * np.max(contrasts)))

    gridsize = OMap.shape
    Mid1 = int(np.floor(gridsize[0] / 2))
    Mid2 = int(np.floor(gridsize[1] / 2))

    # Orientation tuning relative to the preferred orientation at the grid centre,
    # with orientation differences wrapped into [0, 90].
    Orientation = OMap[Mid1, Mid2]
    dOri = np.abs(OMap - Orientation)
    dOri = np.where(dOri > 90, 180 - dOri, dOri)
    In0 = np.ravel(np.exp(-dOri**2 / (2 * AngWidth**2)))

    # Width of the sigmoidal grating edge; hand-tuned, it strongly shapes the
    # surround-suppression curve.
    RFdecay = 0.04
    GaborSigma = 0.5

    # Distance of every grid point from the centre.
    x0 = X[Mid1, Mid2]
    y0 = Y[Mid1, Mid2]
    r_space = np.ravel(np.sqrt((X - x0)**2 + (Y - y0)**2))

    # Spatial input for a constant grating of each radius (one row per radius).
    InSr = 1 - (1 / (1 + np.exp(-(r_space - rads[:, None]) / RFdecay)))
    # Spatial input for a Gabor patch.
    InGabor = np.exp(-r_space**2 / 2 / GaborSigma**2)

    if len(contrasts) > 1:
        StimConds = Contrasts[:, None] * np.vstack((InSr, InGabor))
        cond_contrasts = Contrasts
        cond_radii = np.hstack((rads, np.max(rads)))
    else:
        # No Gabor condition: drop its trailing contrast entry so the contrast
        # vector has one entry per grating row (previously a broadcast error).
        cond_contrasts = Contrasts[:len(rads)]
        StimConds = cond_contrasts[:, None] * InSr
        cond_radii = rads
    StimConds = StimConds * In0

    # Lookup table: contrast and radius of each stimulus condition.
    stimulus_condition = np.vstack((cond_contrasts, cond_radii))

    return StimConds.T, stimulus_condition, InSr
def cond(val):
    """While-loop predicate: continue only while all three checks hold.

    - the max-abs residual ``L(z, val['xi'], val['xp'])`` is above ``tol``,
    - fewer than 30 iterations have run (``val['it'] < 30``),
    - the last update ``val['dxi']`` is still above ``tol`` in max-abs norm.

    ``L``, ``z`` and ``tol`` are read from the enclosing scope.
    """
    residual_too_big = np.max(np.abs(L(z, val['xi'], val['xp']))) > tol
    below_iteration_cap = val['it'] < 30
    still_updating = np.max(np.abs(val['dxi'])) > tol
    checks = np.array([residual_too_big, below_iteration_cap, still_updating])
    return np.all(checks)
def train_epoch(self, evaluate=True):
    """Train one PPO epoch.

    Steps: optionally evaluate, collect trajectories, recompute log-probs and
    value predictions, then run up to ``self._n_optimizer_steps`` minibatch
    updates with KL-based early stopping. Afterwards it logs summaries, bumps
    ``self._epoch``, possibly saves, and logs timings.
    """
    epoch_start_time = time.time()

    # Evaluate the policy.
    policy_eval_start_time = time.time()
    if evaluate and (self._epoch + 1) % self._eval_every_n == 0:
        # `key` is unused, but the split still advances self._rng.
        self._rng, key = jax_random.split(self._rng, num=2)
        self.evaluate()

    policy_eval_time = ppo.get_time(policy_eval_start_time)

    trajectory_collection_start_time = time.time()
    logging.vlog(1, 'PPO epoch [% 6d]: collecting trajectories.', self._epoch)
    # As above: advances self._rng even though `key` is not consumed here.
    self._rng, key = jax_random.split(self._rng)
    trajs, _, timing_info, self._model_state = self.collect_trajectories(
        train=True, temperature=1.0)
    trajs = [(t[0], t[1], t[2], t[4]) for t in trajs]
    self._should_reset = False
    trajectory_collection_time = ppo.get_time(trajectory_collection_start_time)

    logging.vlog(1, 'Collecting trajectories took %0.2f msec.',
                 trajectory_collection_time)

    rewards = np.array([np.sum(traj[2]) for traj in trajs])
    avg_reward = np.mean(rewards)
    std_reward = np.std(rewards)
    max_reward = np.max(rewards)
    min_reward = np.min(rewards)

    self._log('train', 'train/reward_mean_truncated', avg_reward)
    if evaluate and not self._separate_eval:
        metrics = {'raw': {1.0: {'mean': avg_reward, 'std': std_reward}}}
        ppo.write_eval_reward_summaries(metrics, self._log, self._epoch)

    logging.vlog(1, 'Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s',
                 avg_reward, max_reward, min_reward,
                 [float(np.sum(traj[2])) for traj in trajs])

    logging.vlog(1,
                 'Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]',
                 float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
                 max(len(traj[0]) for traj in trajs),
                 min(len(traj[0]) for traj in trajs))
    logging.vlog(2, 'Trajectory Lengths: %s', [len(traj[0]) for traj in trajs])

    preprocessing_start_time = time.time()
    (padded_observations, padded_actions, padded_rewards, reward_mask,
     padded_infos) = self._preprocess_trajectories(trajs)
    preprocessing_time = ppo.get_time(preprocessing_start_time)

    logging.vlog(1, 'Preprocessing trajectories took %0.2f msec.',
                 preprocessing_time)
    logging.vlog(1, 'Padded Observations\' shape [%s]',
                 str(padded_observations.shape))
    logging.vlog(1, 'Padded Actions\' shape [%s]', str(padded_actions.shape))
    logging.vlog(1, 'Padded Rewards\' shape [%s]', str(padded_rewards.shape))

    # Some assertions.
    B, RT = padded_rewards.shape  # pylint: disable=invalid-name
    B, AT = padded_actions.shape  # pylint: disable=invalid-name
    assert (B, RT) == reward_mask.shape
    assert B == padded_observations.shape[0]

    log_prob_recompute_start_time = time.time()
    # TODO(pkozakowski): Reusing the predictions made while stepping the
    # environment (padded_infos['log_prob_actions'/'value_predictions']) would
    # allow non-deterministic networks (e.g. dropout), but does not work well
    # with serialization, so all network predictions are recomputed instead.
    del padded_infos

    # TODO(afrozm): log-probabs doesn't need to be (B, T+1, C, A) it can do with
    # (B, T, C, A), so make that change throughout.

    # NOTE: We don't have the log-probabs and value-predictions for the last
    # observation, so we re-calculate for everything.
    self._rng, key = jax_random.split(self._rng)

    (log_probabs_traj, value_predictions_traj) = (
        self._policy_and_value_net_apply(
            padded_observations,
            weights=self._policy_and_value_net_weights,
            state=self._model_state,
            rng=key,
        ))

    assert (B, AT) == log_probabs_traj.shape[:2]
    assert (B, AT) == value_predictions_traj.shape

    log_prob_recompute_time = ppo.get_time(log_prob_recompute_start_time)

    # Compute value and ppo losses.
    self._rng, key1 = jax_random.split(self._rng, num=2)
    logging.vlog(2, 'Starting to compute P&V loss.')
    loss_compute_start_time = time.time()
    (cur_combined_loss, component_losses, summaries, self._model_state) = (
        ppo.combined_loss(
            self._policy_and_value_net_weights,
            log_probabs_traj,
            value_predictions_traj,
            self._policy_and_value_net_apply,
            padded_observations,
            padded_actions,
            self._rewards_to_actions,
            padded_rewards,
            reward_mask,
            nontrainable_params=self._nontrainable_params,
            state=self._model_state,
            rng=key1))
    loss_compute_time = ppo.get_time(loss_compute_start_time)
    (cur_ppo_loss, cur_value_loss, cur_entropy_bonus) = component_losses
    logging.vlog(
        1,
        'Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.',
        cur_combined_loss, cur_ppo_loss, cur_value_loss, cur_entropy_bonus,
        loss_compute_time)

    self._rng, key1 = jax_random.split(self._rng, num=2)
    logging.vlog(1, 'Policy and Value Optimization')
    optimization_start_time = time.time()
    keys = jax_random.split(key1, num=self._n_optimizer_steps)
    opt_step = 0
    opt_batch_size = min(self._optimizer_batch_size, B)
    index_batches = ppo.shuffled_index_batches(
        dataset_size=B, batch_size=opt_batch_size)

    # Defaults used if the loop below runs zero times or never reaches a
    # logging step; otherwise these names would be unbound when logged. No
    # update means the policy is unchanged, hence zero KL.
    combined_loss = cur_combined_loss
    ppo_loss, value_loss, entropy_bonus = (cur_ppo_loss, cur_value_loss,
                                           cur_entropy_bonus)
    approx_kl = 0.0

    for (index_batch, key) in zip(index_batches, keys):
        k1, k2, k3 = jax_random.split(key, num=3)
        t = time.time()
        # Update the optimizer state on the sampled minibatch.
        self._policy_and_value_opt_state, self._model_state = (
            ppo.policy_and_value_opt_step(
                # The optimizer slots are carried across PPO epochs, so the
                # global step is passed too (e.g. for Adam's bias correction);
                # resetting them every epoch would never let adaptive moment
                # estimates warm up.
                self._total_opt_step,
                self._policy_and_value_opt_state,
                self._policy_and_value_opt_update,
                self._policy_and_value_get_params,
                self._policy_and_value_net_apply,
                log_probabs_traj[index_batch],
                value_predictions_traj[index_batch],
                padded_observations[index_batch],
                padded_actions[index_batch],
                self._rewards_to_actions,
                padded_rewards[index_batch],
                reward_mask[index_batch],
                nontrainable_params=self._nontrainable_params,
                state=self._model_state,
                rng=k1))
        opt_step += 1
        self._total_opt_step += 1

        # Compute the approx KL for early stopping. Use the whole dataset - as
        # we only do inference, it should fit in the memory.
        (log_probab_actions_new, _) = (
            self._policy_and_value_net_apply(
                padded_observations,
                weights=self._policy_and_value_net_weights,
                state=self._model_state,
                rng=k2))

        action_mask = np.dot(
            np.pad(reward_mask, ((0, 0), (0, 1))), self._rewards_to_actions)
        approx_kl = ppo.approximate_kl(log_probab_actions_new, log_probabs_traj,
                                       action_mask)

        early_stopping = approx_kl > 1.5 * self._target_kl
        if early_stopping:
            logging.vlog(
                1, 'Early stopping policy and value optimization after %d steps, '
                'with approx_kl: %0.2f', opt_step, approx_kl)
            # We don't return right-away, we want the below to execute on the
            # last iteration.

        t2 = time.time()
        if (opt_step % self._print_every_optimizer_steps == 0 or
                opt_step == self._n_optimizer_steps or early_stopping):
            # Compute and log the loss.
            (combined_loss, component_losses, _, self._model_state) = (
                ppo.combined_loss(
                    self._policy_and_value_net_weights,
                    log_probabs_traj,
                    value_predictions_traj,
                    self._policy_and_value_net_apply,
                    padded_observations,
                    padded_actions,
                    self._rewards_to_actions,
                    padded_rewards,
                    reward_mask,
                    nontrainable_params=self._nontrainable_params,
                    state=self._model_state,
                    rng=k3))
            logging.vlog(1, 'One Policy and Value grad desc took: %0.2f msec',
                         ppo.get_time(t, t2))
            (ppo_loss, value_loss, entropy_bonus) = component_losses
            logging.vlog(
                1, 'Combined Loss(value, ppo, entropy_bonus) [%10.2f] ->'
                ' [%10.2f(%10.2f,%10.2f,%10.2f)]', cur_combined_loss, combined_loss,
                ppo_loss, value_loss, entropy_bonus)

        if early_stopping:
            break

    optimization_time = ppo.get_time(optimization_start_time)

    logging.vlog(
        1, 'Total Combined Loss reduction [%0.2f]%%',
        (100 * (cur_combined_loss - combined_loss) / np.abs(cur_combined_loss)))

    summaries.update({
        'n_optimizer_steps': opt_step,
        'approx_kl': approx_kl,
    })
    for (name, value) in summaries.items():
        self._log('train', 'train/{}'.format(name), value)

    logging.info(
        'PPO epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined'
        ' Loss(ppo, value, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]', self._epoch,
        min_reward, max_reward, avg_reward, combined_loss, ppo_loss, value_loss,
        entropy_bonus)

    # Bump the epoch counter before saving a checkpoint, so that a call to
    # save() after the training loop is a no-op if a checkpoint was saved last
    # epoch - otherwise it would bump the epoch counter on the checkpoint.
    last_epoch = self._epoch
    self._epoch += 1

    # Save parameters every time we see the end of at least a fraction of batch
    # number of trajectories that are done (not completed -- completed includes
    # truncated and done).
    # Also don't save too frequently, enforce a minimum gap.
    policy_save_start_time = time.time()
    # TODO(afrozm): Refactor to trax.save_trainer_state.
    if (self._n_trajectories_done >=
            self._done_frac_for_policy_save * self.train_env.batch_size and
            self._epoch % self._save_every_n == 0) or self._async_mode:
        self.save()
    policy_save_time = ppo.get_time(policy_save_start_time)

    epoch_time = ppo.get_time(epoch_start_time)

    timing_dict = {
        'epoch': epoch_time,
        'policy_eval': policy_eval_time,
        'trajectory_collection': trajectory_collection_time,
        'preprocessing': preprocessing_time,
        'log_prob_recompute': log_prob_recompute_time,
        'loss_compute': loss_compute_time,
        'optimization': optimization_time,
        'policy_save': policy_save_time,
    }

    timing_dict.update(timing_info)

    if self._should_write_summaries:
        for k, v in timing_dict.items():
            self._timing_sw.scalar('timing/%s' % k, v, step=last_epoch)

    max_key_len = max(len(k) for k in timing_dict)
    timing_info_list = [
        '%s : % 10.2f' % (k.rjust(max_key_len + 1), v)
        for k, v in sorted(timing_dict.items())
    ]
    logging.info('PPO epoch [% 6d], Timings: \n%s', last_epoch,
                 '\n'.join(timing_info_list))

    # Flush summary writers once in a while.
    if self._epoch % 1000 == 0:
        self.flush_summaries()
def max(a, axis=None, keepdims=None, initial=None, where=None):
    """Maximum of ``a`` (optionally along ``axis``), delegating to ``jnp.max``.

    A ``JaxArray`` input is unwrapped to its ``.value`` first. The raw
    ``jnp.max`` result is returned when ``axis is None``; otherwise it is
    wrapped in ``JaxArray``.
    """
    raw = a.value if isinstance(a, JaxArray) else a
    result = jnp.max(raw, axis=axis, keepdims=keepdims, initial=initial, where=where)
    if axis is None:
        return result
    return JaxArray(result)
def tei_array(geom, basis):
    """Build the two-electron integral (TEI) array for a basis set.

    Args:
        geom: jax.numpy array of Cartesian atom positions in Bohr, shape
            (natom, 3) (it is broadcast to (natom, natom, 3) below).
        basis: basis dictionary as defined by basis_utils.build_basis_set.

    Returns:
        Array G of shape (nbf, nbf, nbf, nbf), nbf = get_nbf(basis). Every
        primitive-quartet contribution is accumulated into G[i, j, k, l].

    The loop runs over primitives rather than shells because JAX needs
    intermediates to have consistent sizes in order to compile.

    NOTE(review): uses jax.experimental.loops and jax.ops.index_add, which
    newer JAX releases no longer provide; this needs an older, pinned JAX.
    """
    # Smush primitive data together into vectors (indexed by primitive).
    coeffs, exps, atoms, ams, indices, dims = flatten_basis_data(basis)
    nbf = get_nbf(basis)
    max_am = jnp.max(ams)
    # Length of the Boys-function and (Q-P)-power tables (l1+l2+l3+l4 at most).
    max_am_idx = max_am * 4 + 1
    # TODO: raise an exception if the angular momentum is too high.
    # Zero buffer of length 4*max_am+1 passed to every B_array call below.
    B_vals = jnp.zeros(4*max_am+1)
    nprim = coeffs.shape[0]
    # Obtain all possible primitive quartet index combinations
    primitive_quartets = cartesian_product(jnp.arange(nprim), jnp.arange(nprim), jnp.arange(nprim), jnp.arange(nprim))

    # Precompute pairwise quantities once and look them up inside the loop.
    # aa_plus_bb[i, j] = exps[i] + exps[j]
    aa_plus_bb = jnp.broadcast_to(exps, (nprim,nprim)) + jnp.transpose(jnp.broadcast_to(exps, (nprim,nprim)), (1,0))
    # aa_times_A[i] = exps[i] * (center of primitive i)
    aa_times_A = jnp.einsum('i,ij->ij', exps, geom[atoms])
    aaxA_plus_bbxB = aa_times_A[:,None,:] + aa_times_A[None,:,:]
    # Gaussian product centers: P[i, j] = (a_i*A_i + a_j*A_j) / (a_i + a_j)
    gaussian_products = jnp.einsum('ijk,ij->ijk', aaxA_plus_bbxB, 1/aa_plus_bb)

    # Squared interatomic distances AmBdot[m, n] = |R_n - R_m|^2, i.e. every
    # possible jnp.dot(A-B, A-B); used for rab2 and rcd2 below.
    natom = geom.shape[0]
    tmpA = jnp.broadcast_to(geom, (natom,natom,3))
    AminusB = (tmpA - jnp.transpose(tmpA, (1,0,2)))
    AmBdot = jnp.einsum('ijk,ijk->ij', AminusB, AminusB) # shape: (natom,natom)

    # PminusA[p1, p2, n] = P[p1, p2] - R_n: every product center minus every atom center.
    tmpP = jnp.tile(gaussian_products, natom).reshape(nprim,nprim,natom,3)
    PminusA = tmpP - jnp.broadcast_to(geom, tmpP.shape)

    # All powers 0..max_am of those differences.
    # Shape: (nprim, nprim, natom, 3, max_am+1). In the loop, PA_pow is
    # PminusA_pow[p1, p2, atoms[p1], :, :] (one row per Cartesian axis).
    PminusA_pow = jnp.power(jnp.transpose(jnp.broadcast_to(PminusA, (max_am+1,nprim,nprim,natom,3)), (1,2,3,4,0)), jnp.arange(max_am+1))

    with loops.Scope() as s:
        s.G = jnp.zeros((nbf,nbf,nbf,nbf))
        s.a = 0  # component iterator for the basis function on center A
        s.b = 0  # component iterator for center B
        s.c = 0  # component iterator for center C
        s.d = 0  # component iterator for center D

        # Loop over primitive quartets, compute integral, add to appropriate index in G
        for prim_quar in s.range(primitive_quartets.shape[0]):
            # Load primitive indices, coefficients, exponents, leading indices into
            # angular_momentum_combinations, and leading placement indices in G.
            p1,p2,p3,p4 = primitive_quartets[prim_quar]
            coef = coeffs[p1] * coeffs[p2] * coeffs[p3] * coeffs[p4]
            aa, bb, cc, dd = exps[p1], exps[p2], exps[p3], exps[p4]
            ld1, ld2, ld3, ld4 = am_leading_indices[ams[p1]],am_leading_indices[ams[p2]],am_leading_indices[ams[p3]],am_leading_indices[ams[p4]]
            idx1, idx2, idx3, idx4 = indices[p1],indices[p2],indices[p3],indices[p4],

            # Common intermediates, computed once before looping over the angular
            # momentum components. Centers A..D are geom[atoms[p1..p4]].
            gamma1 = aa + bb
            gamma2 = cc + dd

            # Looked up from the precomputed tables:
            # P = (aa*A + bb*B)/gamma1, Q = (cc*C + dd*D)/gamma2,
            # rab2 = |A-B|^2, rcd2 = |C-D|^2.
            P = gaussian_products[p1,p2]
            Q = gaussian_products[p3,p4]
            rab2 = AmBdot[atoms[p1],atoms[p2]]
            rcd2 = AmBdot[atoms[p3],atoms[p4]]

            PQ = P - Q
            rpq2 = jnp.dot(PQ,PQ)
            delta = 0.25*(1/gamma1+1/gamma2)

            # Boys function F_n(rpq2 / (4*delta)) for n = 0 .. max_am_idx-1.
            boys_arg = 0.25 * rpq2 / delta
            boys_eval = boys(jnp.arange(max_am_idx), boys_arg)

            # All powers of Pi-Ai, Pi-Bi, Qi-Ci, Qi-Di (i = x, y, z) up to max_am,
            # and of Qi-Pi up to max_am_idx-1; shape (3, n_powers).
            # note: this computes unnecessary quantities for lower angular momentum,
            # but avoids repeated computation of the same quantities in loops for
            # higher angular momentum.
            PA_pow = PminusA_pow[p1,p2,atoms[p1],:,:]
            PB_pow = PminusA_pow[p1,p2,atoms[p2],:,:]
            QC_pow = PminusA_pow[p3,p4,atoms[p3],:,:]
            QD_pow = PminusA_pow[p3,p4,atoms[p4],:,:]
            QP_pow = jnp.power(jnp.broadcast_to(Q-P, (max_am_idx,3)).T, jnp.arange(max_am_idx))
            # Gamma powers are negative, up to -(l1+l2).
            # The array is built so that indexing with -k returns (4*gamma)**(-k).
            g1_pow = jnp.power(4*gamma1, -jnp.roll(jnp.flip(jnp.arange(2*max_am+1)),1))
            g2_pow = jnp.power(4*gamma2, -jnp.roll(jnp.flip(jnp.arange(2*max_am+1)),1))
            oodelta_pow = jnp.power(1 / delta, jnp.arange(max_am_idx))  # up to l1 + l2 + l3 + l4

            # 34.986836655249726 == 2 * pi**(5/2).
            prefactor = 34.986836655249726 / (gamma1*gamma2*jnp.sqrt(gamma1+gamma2)) \
                        * jnp.exp(-aa*bb*rab2/gamma1 + -cc*dd*rcd2/gamma2) * coef

            # TODO is there symmetry here?
            s.a = 0
            for _ in s.while_range(lambda: s.a < dims[p1]):
                s.b = 0
                for _ in s.while_range(lambda: s.b < dims[p2]):
                    s.c = 0
                    for _ in s.while_range(lambda: s.c < dims[p3]):
                        s.d = 0
                        for _ in s.while_range(lambda: s.d < dims[p4]):
                            # Per-axis exponents of each component: l* pair with
                            # coordinate 0, m* with 1, n* with 2 (see Bx/By/Bz).
                            la, ma, na = angular_momentum_combinations[s.a + ld1]
                            lb, mb, nb = angular_momentum_combinations[s.b + ld2]
                            lc, mc, nc = angular_momentum_combinations[s.c + ld3]
                            ld, md, nd = angular_momentum_combinations[s.d + ld4]
                            # Destination index in G.
                            i = idx1 + s.a
                            j = idx2 + s.b
                            k = idx3 + s.c
                            l = idx4 + s.d
                            # Compute the primitive quartet tei and add to appropriate index in G
                            Bx = B_array(la,lb,lc,ld,PA_pow[0],PB_pow[0],QC_pow[0],QD_pow[0],QP_pow[0],g1_pow,g2_pow,oodelta_pow,B_vals)
                            By = B_array(ma,mb,mc,md,PA_pow[1],PB_pow[1],QC_pow[1],QD_pow[1],QP_pow[1],g1_pow,g2_pow,oodelta_pow,B_vals)
                            Bz = B_array(na,nb,nc,nd,PA_pow[2],PB_pow[2],QC_pow[2],QD_pow[2],QP_pow[2],g1_pow,g2_pow,oodelta_pow,B_vals)

                            # Triple sum over Bx[I] * By[J] * Bz[K] * F_{I+J+K}.
                            # NOTE(review): `tmp` is a plain Python local, not a
                            # Scope attribute, so it is not loop-carried. Under
                            # loops.Scope each while_range body is traced once,
                            # so every K step appears to use Bx[I]*By[J]. Executed
                            # eagerly, the `*=` would accumulate across iterations
                            # instead; confirm before refactoring.
                            with loops.Scope() as S:
                                S.primitive = 0.
                                S.I = 0
                                S.J = 0
                                S.K = 0
                                for _ in S.while_range(lambda: S.I < la + lb + lc + ld + 1):
                                    S.J = 0
                                    tmp = Bx[S.I]
                                    for _ in S.while_range(lambda: S.J < ma + mb + mc + md + 1):
                                        S.K = 0
                                        tmp *= By[S.J]
                                        for _ in S.while_range(lambda: S.K < na + nb + nc + nd + 1):
                                            tmp *= Bz[S.K] * boys_eval[S.I + S.J + S.K]
                                            S.primitive += tmp
                                            S.K += 1
                                        S.J += 1
                                    S.I += 1
                            tei = prefactor * S.primitive
                            s.G = jax.ops.index_add(s.G, jax.ops.index[i,j,k,l], tei)

                            s.d += 1
                        s.c += 1
                    s.b += 1
                s.a += 1
    return s.G
def _softmax(X): """Compute the softmax of matrix X in a numerically stable way.""" shiftx = X - np.max(X, axis=1).reshape(-1, 1) exps = np.exp(shiftx) return exps / np.sum(exps, axis=1).reshape(-1, 1)
#A = np.identity(5) B = np.array([[1., 0., 0.], [0.5, 0.5, 0.], [0.5, 0.0, 0.5], [0., 1., 0.], [0., 0.5, 0.5]]) A = 0.9 * A / np.linalg.norm(A, ord='nuc') B = 0.9 * B / np.linalg.norm(B, ord='nuc') #W = np.zeros(T_0 + T) W = (np.sin(np.arange((T_0 + T) * m) / (20 * np.pi)).reshape( (T_0 + T), m) @ np.ones((m, n))).reshape((T_0 + T), n, 1) sysid = SystemID() sysid.initialize(n, m) x = x0 for t in range(T_0): u = sysid.get_action(x) x = A @ x + B @ u + W[t] A_id, B_id = sysid.system_id() print("A versus A_id") print(A) print(A_id) print("B versus B_id") print(B) print(B_id) print("max diff for A, B: ", np.max(np.abs(A - A_id)), np.max(np.abs(B - B_id))) print("norm diff for A, B: ", np.linalg.norm(A - A_id), np.linalg.norm(B - B_id))
def tree_max_abs(tree):
    """Return the largest absolute value found in any array leaf of ``tree``."""
    per_leaf_peak = tree_map(lambda leaf: np.max(np.abs(leaf)), tree)
    leaves, _ = tree_flatten(per_leaf_peak)
    return np.max(leaves)
# Residuals of the two segments: ydd - Pe * yd evaluated at the collocation
# points z. NOTE(review): ydd*/yd*, Pe, z, H, TFCDict, NLLS, x1/x2, y1/y2, soln
# and xp come from earlier in the script.
L1 = lambda z, xi: ydd1(z, xi) - Pe * yd1(z, xi)
L2 = lambda z, xi: ydd2(z, xi) - Pe * yd2(z, xi)
# Full residual vector handed to the solver: both segments stacked.
L = lambda xi: np.hstack((L1(z, xi), L2(z, xi)))

# Initial guess for the unknowns: zero coefficients for each segment's basis
# (H(z).shape[1] columns each), zero y and yd, and b = sqrt(2 / 0.5).
xi1 = onp.zeros(H(z).shape[1])
xi2 = onp.zeros(H(z).shape[1])
y = onp.zeros(1)
yd = onp.zeros(1)
b = onp.ones(1) * np.sqrt(2. / 0.5)

xi = TFCDict({'xi1': xi1, 'xi2': xi2, 'y': y, 'yd': yd, 'b': b})

## SOLVE THE SYSTEM *************************************************
xi, it, time = NLLS(xi, L, timer=True)

# Assemble the solution over both segments.
X = np.hstack((x1(z, xi), x2(z, xi)))
Y = np.hstack((y1(z, xi), y2(z, xi)))

# p1 = MakePlot(onp.array([['x (m)']]),onp.array([['y (m)']]))
# p1.ax[0].plot(X,Y,label='TFC Solution')
# p1.ax[0].plot(X,soln(X),label='Analytical Solution')
# p1.ax[0].legend()
# p1.show()

# One '&'-separated row (presumably for a LaTeX table): max error against the
# analytical solution, max residual, xp(xi)[0], and solve time.
print('{:.2e} & {:.2e} & {:.5f} & {:.2f}'.format(np.max(np.abs(Y - soln(X))), np.max(np.abs(L(xi))), xp(xi)[0].tolist(), time))
def _train_step(self):
    """Runs a single training step.

    Nothing is trained until the replay buffer holds more than
    ``min_replay_history`` transitions. After that:
      * every ``update_period`` steps a batch is sampled and passed to
        ``train``. With prioritized replay, the loss weights are
        1/sqrt(sampling probability) normalized to a max of 1, and the sampled
        items' priorities are refreshed to sqrt(loss).
      * every ``target_update_period`` steps the target network is synced.
    ``training_steps`` is incremented on every call.
    """
    if self._replay.add_count > self.min_replay_history:
        if self.training_steps % self.update_period == 0:
            self._sample_from_replay_buffer()

            if self._replay_scheme == 'prioritized':
                # The original prioritized experience replay uses a linear
                # exponent schedule 0.4 -> 1.0. Comparing the schedule to a
                # fixed exponent of 0.5 on 5 games (Asterix, Pong, Q*Bert,
                # Seaquest, Space Invaders) suggested a fixed exponent actually
                # performs better, except on Pong.
                probs = self.replay_elements['sampling_probabilities']
                # Weight the loss by the inverse priorities.
                loss_weights = 1.0 / jnp.sqrt(probs + 1e-10)
                loss_weights /= jnp.max(loss_weights)
            else:
                loss_weights = jnp.ones(self.replay_elements['state'].shape[0])

            self.optimizer, aux_losses = train(
                self.network_def, self.target_network_params, self.optimizer,
                self.replay_elements['state'], self.replay_elements['action'],
                self.replay_elements['next_state'],
                self.replay_elements['reward'],
                self.replay_elements['terminal'], loss_weights, self._support,
                self.cumulative_gamma, self._mico_weight, self._distance_fn)
            loss = aux_losses.pop('loss')

            # The priority update happens exactly once per sampled batch. A
            # duplicated second block that recomputed the weights and set the
            # same priorities again has been removed; its only other output
            # was an unused `loss`.
            if self._replay_scheme == 'prioritized':
                # Rainbow and prioritized replay are parametrized by an
                # exponent alpha, but in both cases it is set to 0.5 - for
                # simplicity's sake we leave it as is here, using the more
                # direct sqrt(). Taking the square root "makes sense", as we
                # are dealing with a squared loss. Add a small nonzero value to
                # the loss to avoid 0 priority items. While technically this
                # may be okay, setting all items to 0 priority will cause
                # troubles, and also result in 1.0 / 0.0 = NaN correction terms.
                self._replay.set_priority(self.replay_elements['indices'],
                                          jnp.sqrt(loss + 1e-10))

            if self.summary_writer is not None:
                values = []
                for k in aux_losses:
                    values.append(
                        tf.compat.v1.Summary.Value(
                            tag=f'Losses/{k}', simple_value=aux_losses[k]))
                summary = tf.compat.v1.Summary(value=values)
                self.summary_writer.add_summary(summary, self.training_steps)

        if self.training_steps % self.target_update_period == 0:
            self._sync_weights()

    self.training_steps += 1
+x[0]*((1.+x[1]**3)*np.exp(-1.)-np.dot(H(np.ones_like(x[0]),x[1]),xi)) u = lambda xi,*x: u1(xi,*x)\ +(1.-x[1])*(x[0]*np.exp(-x[0])-u1(xi,x[0],np.zeros_like(x[1])))\ +x[1]*(np.exp(-x[0])*(x[0]+1.)-u1(xi,x[0],np.ones_like(x[1]))) # Create the residual laplace = lambda xi,*x: egrad(egrad(u,1),1)(xi,*x)+egrad(egrad(u,2),2)(xi,*x) L = lambda xi,*x: laplace(xi,*x)-np.exp(-x[0])*(x[0]-2.+x[1]**3+6.*x[1]) # Calculate the xi values zXi = np.zeros(H(*x).shape[1]) A = jacfwd(L,0)(zXi,*x) B = -L(zXi,*x) xi = np.dot(np.linalg.pinv(A),B) # Calculate the error dark = np.meshgrid(np.linspace(x0[0],xf[0],n),np.linspace(x0[1],xf[1],n)) x = (dark[0].flatten(),dark[1].flatten()) ur = real(*x) ue = u(xi,*x) err = ur-ue testErr[j,k] = np.max(np.abs(err)) # Print results as a table tab = table.SimpleTable(testErr) print(tab) f = open("TfcData.txt","w") f.write(tab) f.close()
def softmax(x, axis=-1):
    """Softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted first (kept as a broadcastable
    dimension) so the exponentials cannot overflow.
    """
    shifted = x - np.max(x, axis, keepdims=True)
    weights = np.exp(shifted)
    normalizer = np.sum(weights, axis, keepdims=True)
    return weights / normalizer
def get_particle_lims(particles):
    """Symmetric axis limits that enclose every particle.

    Args:
      particles: (n, 2) array of 2-D points.

    Returns:
      Tuple ``(-a, a)`` where ``a`` is the largest absolute coordinate, so the
      square with corners at +-a contains all particles.
    """
    half_width = np.max(np.abs(particles))
    return -half_width, half_width
def matrix_inverse_pth_root(mat_g,
                            p,
                            iter_count=100,
                            error_tolerance=1e-6,
                            ridge_epsilon=1e-6):
    """Computes mat_g^(-1/p), where p is a positive integer.

    Coupled newton iterations for matrix inverse pth root.

    Args:
      mat_g: the symmetric PSD matrix whose power is to be computed.
      p: exponent, for p a positive integer. When p is a Python int, any
        positive value is supported. For any other type (e.g. a traced
        value), only p in {1, 2, 4, 8} yields the correct inner power.
      iter_count: Maximum number of iterations.
      error_tolerance: Error indicator, useful for early termination.
      ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
        It is scaled by the largest-eigenvalue estimate from `power_iter`.

    Returns:
      A tuple (mat_g^(-1/p), error). `error` is max|M - I| of the final
      iterate M, or 0 for a 1x1 input.
    """
    mat_g_size = mat_g.shape[0]
    alpha = jnp.asarray(-1.0 / p, _INVERSE_PTH_ROOT_DATA_TYPE)
    identity = jnp.eye(mat_g_size, dtype=_INVERSE_PTH_ROOT_DATA_TYPE)
    _, max_ev, _ = power_iter(mat_g)
    # Make the ridge relative to the top eigenvalue estimate. The floor keeps
    # the scale factor strictly positive.
    ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16)

    def _unrolled_mat_pow_1(mat_m):
        """Computes mat_m^1."""
        return mat_m

    def _unrolled_mat_pow_2(mat_m):
        """Computes mat_m^2."""
        return jnp.matmul(mat_m, mat_m, precision=_INVERSE_PTH_ROOT_PRECISION)

    def _unrolled_mat_pow_4(mat_m):
        """Computes mat_m^4."""
        mat_pow_2 = _unrolled_mat_pow_2(mat_m)
        return jnp.matmul(
            mat_pow_2, mat_pow_2, precision=_INVERSE_PTH_ROOT_PRECISION)

    def _unrolled_mat_pow_8(mat_m):
        """Computes mat_m^8."""
        mat_pow_4 = _unrolled_mat_pow_4(mat_m)
        return jnp.matmul(
            mat_pow_4, mat_pow_4, precision=_INVERSE_PTH_ROOT_PRECISION)

    def _static_mat_power(mat_m, exponent):
        """Computes mat_m^exponent for a static positive int by repeated squaring."""
        # Unrolled at trace time. For 1, 2, 4 and 8 this performs the same
        # matmuls, in the same order, as the _unrolled_mat_pow_* helpers.
        result = None
        base = mat_m
        while True:
            if exponent & 1:
                if result is None:
                    result = base
                else:
                    result = jnp.matmul(
                        result, base, precision=_INVERSE_PTH_ROOT_PRECISION)
            exponent >>= 1
            if not exponent:
                return result
            base = jnp.matmul(base, base, precision=_INVERSE_PTH_ROOT_PRECISION)

    def mat_power(mat_m, p):
        """Computes mat_m^p.

        Args:
          mat_m: a square matrix
          p: a positive integer. Python ints of any size are handled exactly.
            Other exponent types go through the unrolled dispatch, which is
            only correct for p == 1, 2, 4 or 8.

        Returns:
          mat_m^p
        """
        if isinstance(p, int) and p >= 1:
            return _static_mat_power(mat_m, p)
        # We unrolled the loop for performance reasons. Note: lax.switch clamps
        # out-of-range indices, so other exponents silently get the wrong power.
        exponent = jnp.round(jnp.log2(p))
        return lax.switch(jnp.asarray(exponent, jnp.int32), [
            _unrolled_mat_pow_1,
            _unrolled_mat_pow_2,
            _unrolled_mat_pow_4,
            _unrolled_mat_pow_8,
        ], (mat_m))

    def _iter_condition(state):
        (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error,
         run_step) = state
        error_above_threshold = jnp.logical_and(error > error_tolerance, run_step)
        return jnp.logical_and(i < iter_count, error_above_threshold)

    def _iter_body(state):
        (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state
        mat_m_i = (1 - alpha) * identity + alpha * mat_m
        new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m,
                               precision=_INVERSE_PTH_ROOT_PRECISION)
        new_mat_h = jnp.matmul(mat_h, mat_m_i,
                               precision=_INVERSE_PTH_ROOT_PRECISION)
        new_error = jnp.max(jnp.abs(new_mat_m - identity))
        # sometimes error increases after an iteration before decreasing and
        # converging. 1.2 factor is used to bound the maximal allowed increase.
        return (i + 1, new_mat_m, new_mat_h, mat_h, new_error,
                new_error < error * 1.2)

    if mat_g_size == 1:
        resultant_mat_h = (mat_g + ridge_epsilon)**alpha
        error = 0
    else:
        damped_mat_g = mat_g + ridge_epsilon * identity
        z = (1 + p) / (2 * jnp.linalg.norm(damped_mat_g))
        new_mat_m_0 = damped_mat_g * z
        new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
        new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
        init_state = tuple(
            [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True])
        _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
            _iter_condition, _iter_body, init_state)
        error = jnp.max(jnp.abs(mat_m - identity))
        # If the last step failed the 1.2x error-growth check, fall back to
        # the previous iterate's mat_h.
        is_converged = jnp.asarray(convergence, old_mat_h.dtype)
        resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
        resultant_mat_h = jnp.asarray(resultant_mat_h, mat_g.dtype)
    return resultant_mat_h, error
val['dxi'].block_until_ready() toc = timer() time += (toc-tic) it += 1 # print(xp) return np.max(np.abs(L(z,val['xi'],val['xp']))).tolist() xp_0 = 0.75 xp = optim.fsolve(fMin, xp_0, xtol=tol,epsfcn=tol) if xp > xpBound: xp = xpBound val['xp'] = xp val = nlls(val) xi = val['xi'] X = np.hstack((x1(z,xp), x2(z,xp))) Y = np.hstack((y1(z,xi,xp), y2(z,xi,xp) )) # p1 = MakePlot(onp.array([['x (m)']]),onp.array([['y (m)']])) # p1.ax[0].plot(X,Y,label='TFC Solution') # p1.ax[0].plot(X,soln(X),label='Analytical Solution') # p1.ax[0].legend() # p1.show() print('{:.2e} & {:.2e} & {:.5f} & {:.2f}'.format(np.max(np.abs(Y - soln(X))), np.max(np.abs(L(z,xi,xp))), xp, time ))
def minmax(x):
    """Find, for every query, the neighbouring entries of `x` around it.

    Uses the boolean `mask` from the enclosing scope, which this block does
    not define. The indexing implies mask has shape (..., x.shape[-1], Q),
    presumably with mask[..., j, k] True when query k lies at or beyond
    x[..., j] — TODO confirm against the caller. `x` is presumably sorted
    along its last axis.

    Returns:
      (x0, x1), each of shape (..., Q):
      - x0 is the largest x[..., j] with mask True (x[..., 0] if there is
        none), capped above by x[..., -2].
      - x1 is the smallest x[..., j] with mask False (x[..., -1] if there
        is none), floored below by x[..., 1].
    """
    column = x[..., None]
    first = x[..., :1, None]
    last = x[..., -1:, None]
    lower = jnp.max(jnp.where(mask, column, first), axis=-2)
    upper = jnp.min(jnp.where(~mask, column, last), axis=-2)
    # Clamp so that lower never passes the second-to-last entry and upper
    # never falls below the second entry.
    lower = jnp.minimum(lower, x[..., -2:-1])
    upper = jnp.maximum(upper, x[..., 1:2])
    return lower, upper
def apply(self, x, communication=Communication.NONE, train=True):
    """Forward pass: multi-scale patch scores, min-max normalized per example.

    Args:
      x: input batch. Its first axis is the batch. The einops patterns
        below treat the conv outputs as "b h w c", so x is presumably NHWC
        (TODO confirm).
      communication: optional cross-scale communication.
        - SQUEEZE_EXCITE_X: squeeze-excite applied to the input.
        - SQUEEZE_EXCITE_D: a squeeze-excite gate over the three feature maps.
        - TRANSFORMER: a 2-layer, 8-head transformer over the three maps,
          with learned positional encodings.
      train: forwarded as `is_training` to the transformer.

    Returns:
      (normalized_scores, stats).
      - normalized_scores: list of 6 + 6 + 9 single-channel score maps from
        the three "tidy" convs, normalized by `zeroone` with per-example
        min/max, and with the trailing channel axis squeezed.
      - stats["scores"]: the same maps before squeezing.
      - stats["raw_scores"]: all tidy outputs flattened per example and
        concatenated.
    """
    batch_size = x.shape[0]

    if communication is Communication.SQUEEZE_EXCITE_X:
        x = sample_patches.SqueezeExciteLayer(x)

    # Three 3x3, 128-channel conv + ReLU stages; the last two use stride 2.
    features = x
    stages = []
    for index, stride in enumerate((1, 2, 2), start=1):
        features = nn.relu(
            nn.Conv(features, 128, kernel_size=(3, 3), strides=(stride, stride),
                    bias=True, name=f"down{index}"))
        stages.append(features)
    d1, d2, d3 = stages

    def flatten_and_join(maps):
        """Flattens each map's spatial dims and concatenates along tokens."""
        flat = [einops.rearrange(m, "b h w c -> b (h w) c") for m in maps]
        return jnp.concatenate(flat, axis=1), [f.shape[1] for f in flat]

    def split_back(joined, sizes, maps):
        """Undoes flatten_and_join, restoring each map's original shape."""
        pieces = []
        start = 0
        for size, feature_map in zip(sizes, maps):
            pieces.append(joined[:, start:start + size].reshape(feature_map.shape))
            start += size
        return pieces

    if communication is Communication.SQUEEZE_EXCITE_D:
        joined, sizes = flatten_and_join(stages)
        num_channels = joined.shape[-1]
        # Channel gate computed from the mean over all tokens of all scales.
        gate = joined.mean(axis=1)
        gate = nn.Dense(gate, features=num_channels // 4, bias=False)
        gate = nn.relu(gate)
        gate = nn.Dense(gate, features=num_channels, bias=False)
        gate = nn.sigmoid(gate)
        joined = joined * gate[:, None, :]
        d1, d2, d3 = split_back(joined, sizes, stages)
    elif communication is Communication.TRANSFORMER:
        joined, sizes = flatten_and_join(stages)
        positional_encodings = self.param(
            "scale_ratio_position_encodings",
            shape=(1, ) + joined.shape[1:],
            initializer=jax.nn.initializers.normal(1. / joined.shape[-1]))
        joined = transformer.Transformer(joined + positional_encodings,
                                         num_layers=2,
                                         num_heads=8,
                                         is_training=train)
        d1, d2, d3 = split_back(joined, sizes, stages)

    # 1x1 "tidy" convs: 6, 6 and 9 score channels for the three scales.
    tidy_specs = ((d1, 6, "tidy1"), (d2, 6, "tidy2"), (d3, 9, "tidy3"))
    tidied = [
        nn.Conv(fmap, channels, kernel_size=(1, 1), strides=(1, 1), bias=True,
                name=conv_name)
        for fmap, channels, conv_name in tidy_specs
    ]
    raw_scores = []
    for tidy, (_, channels, _) in zip(tidied, tidy_specs):
        raw_scores.extend(jnp.split(tidy, channels, axis=-1))

    # The following is for normalization: per-example min and max over all
    # scores of all scales.
    t = jnp.concatenate([jnp.reshape(tidy, [batch_size, -1]) for tidy in tidied],
                        axis=1)
    t_min = jnp.reshape(jnp.min(t, axis=-1), [batch_size, 1, 1, 1])
    t_max = jnp.reshape(jnp.max(t, axis=-1), [batch_size, 1, 1, 1])
    normalized_scores = zeroone(raw_scores, t_min, t_max)

    stats = {"scores": normalized_scores, "raw_scores": t}

    # removes the split dimension. scores are now b x h' x w' shaped
    return [s.squeeze(-1) for s in normalized_scores], stats