def sinkhorn_dual_solver(a,
                         b,
                         cost,
                         epsilon = 1e-2,
                         num_iters = 1000):
  """Runs Sinkhorn algorithm in log space for stability at small epsilon.

  Uses a fixed number of iterations to enable compatibility with jit.
  TODO(riannevdberg): implement dynamic num_iterations with jax control flow.

  Returns the regularized transport cost as in dual formula 4.30 of
  https://arxiv.org/abs/1803.00567

  Note that the unregularized transport cost for a finite-iteration solution of
  the potentials f and g is a lower bound on the regularized transport cost
  with a converged solution for f and g (see proposition 4.8 of the arxiv
  paper).

  If you need to compute the coupling P^L from the potentials
  after a finite number of iterations L, you should put
  the coupling through a rounding operation (round_coupling) to ensure that
  the coupling is doubly stochastic. Otherwise <C, P^L> is not a valid
  approximation of L_C(a, b). After applying the round_coupling function,
  P^L <-- round(P^L), you can then use <C, P^L> as a valid approximation.

  Args:
   a: np.ndarray<float>[n]: discrete probability distribution. The rows of the
     coupling matrix P must sum up to this vector. If a represents an empirical
     distribution with n samples, all entries should be equal to 1/n.
   b: np.ndarray<float>[m]: discrete probability distribution. The columns of
     the coupling matrix P must sum up to this vector. If b represents an
     empirical distribution with m samples, all entries should be equal to 1/m.
   cost: np.ndarray<float>[n, m]: the cost matrix cost[i,j] = c(x_i, y_j) where
     x_i and y_j are samples from distributions a and b respectively.
   epsilon: (float) the level of entropic regularization wanted.
   num_iters: (int32) the number of Sinkhorn iterations.

  Returns:
   transportation cost (eq. 4.48 of the paper), coupling (which still needs to
     be rounded with the round_coupling method if used to compute a loss or if
     you want to ensure it has the correct marginals a and b), and the error in
     the column marginal.
  """
  loga = jnp.expand_dims(jnp.log(a), axis=1)
  logb = jnp.expand_dims(jnp.log(b), axis=0)
  f = jnp.zeros_like(loga)  # epsilon * log_u
  g = jnp.zeros_like(logb)  # epsilon * log_v

  for _ in range(num_iters):
    # Note: If the update order is g before f, then check the error in b,
    # as this will be the largest error. If using the reverse order, then
    # check the error in a.

    # To carry out the logsumexp in a stable way we use the fact that
    # the matrix f + g - cost has all negative entries. We therefore use this
    # to add and subtract f and g in the respective updates in and outside the
    # logsumexp.

    g = epsilon * logb - epsilon * jax.scipy.special.logsumexp(
        (f + g - cost) / epsilon, axis=0, keepdims=True) + g
    f = epsilon * loga - epsilon * jax.scipy.special.logsumexp(
        (f + g - cost) / epsilon, axis=1, keepdims=True) + f

  # Compute error
  coupling = jnp.exp((f + g - cost) / epsilon)
  b_target = jnp.sum(coupling, axis=0)
  err = jnp.max(jnp.abs(b_target - b) / b, axis=None)

  # Compute unregularized cost according to eq. 4.48 of paper.
  # Note that if you want to compute the regularized cost of eq. 4.30
  # this only requires subtracting epsilon, as the double sum
  # < e^f/eps, K e^g/eps > = 1 for updates like in this sinkhorn algorithm.
  transport_cost = jnp.sum(f * a) + jnp.sum(g * b)
  return transport_cost, coupling, err
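A minimal usage sketch for the solver above (hypothetical inputs; assumes import jax and import jax.numpy as jnp at module level, and that the round_coupling helper referenced in the docstring takes the coupling and the two marginals):

n, m = 5, 7
a = jnp.ones(n) / n                              # uniform empirical weights
b = jnp.ones(m) / m
cost = (jnp.arange(n)[:, None] - jnp.arange(m)[None, :]) ** 2 / 10.0
transport_cost, coupling, err = sinkhorn_dual_solver(a, b, cost, epsilon=1e-2)
# Before using <C, P> as a loss, round the coupling so its marginals are exact:
# coupling = round_coupling(coupling, a, b)      # signature assumed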
Example #2
def main():
    total_secs = 10.0
    gamma = 0.9
    rng = random.PRNGKey(0)

    ### Set up the problem/environment
    # xdot = Ax + Bu
    # u = - Kx
    # cost = xQx + uRu + 2xNu

    A = jp.eye(2)
    B = jp.eye(2)
    Q = jp.eye(2)
    R = jp.eye(2)
    N = jp.zeros((2, 2))

    # rngA, rngB, rngQ, rngR, rng = random.split(rng, 5)
    # # A = random.normal(rngA, (2, 2))
    # A = -1 * random_psd(rngA, 2)
    # B = random.normal(rngB, (2, 2))
    # Q = random_psd(rngQ, 2) + 0.1 * jp.eye(2)
    # R = random_psd(rngR, 2) + 0.1 * jp.eye(2)
    # N = jp.zeros((2, 2))

    # x_dim, u_dim = B.shape

    dynamics_fn = lambda x, u: A @ x + B @ u
    cost_fn = lambda x, u: x.T @ Q @ x + u.T @ R @ u + 2 * x.T @ N @ u

    ### Solve the Riccati equation to get the infinite-horizon optimal solution.
    K, _, _ = control.lqr(A, B, Q, R, N)
    K = jp.array(K)

    t0 = time.time()
    rng_eval, rng = random.split(rng)
    x0_eval = random.normal(rng_eval, (1000, 2))
    opt_all_costs = vmap(lambda x0: policy_integrate_cost(
        dynamics_fn, cost_fn, lambda _, x: -K @ x, gamma)
                         (None, x0, total_secs))(x0_eval)
    opt_cost = jp.mean(opt_all_costs)
    print(f"opt_cost = {opt_cost} in {time.time() - t0}s")

    ### Set up the learned policy model.
    policy_init, policy = stax.serial(
        Dense(64),
        Relu,
        Dense(64),
        Relu,
        Dense(2),
    )
    # policy_init, policy = DenseNoBias(2)

    rng_init_params, rng = random.split(rng)
    _, init_policy_params = policy_init(rng_init_params, (2, ))

    cost_and_grad = jit(
        value_and_grad(
            policy_integrate_cost(dynamics_fn, cost_fn, policy, gamma)))
    opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)

    def multiple_steps(num_steps):
        """Return a jit-able function that runs `num_steps` iterations."""
        def body(_, stuff):
            rng, _, opt = stuff
            rng_x0, rng = random.split(rng)
            x0 = random.normal(rng_x0, (2, ))
            cost, g = cost_and_grad(opt.value, x0, total_secs)

            # Gradient clipping
            # g = tree_map(lambda x: jp.clip(x, -10, 10), g)
            # g = optimizers.clip_grads(g, 64)

            return rng, cost, opt.update(g)

        return lambda rng, opt: lax.fori_loop(0, num_steps, body,
                                              (rng, jp.zeros(()), opt))

    multi_steps = 1
    run = jit(multiple_steps(multi_steps))

    ### Main optimization loop.
    costs = []
    for i in range(25000):
        t0 = time.time()
        rng, cost, opt = run(rng, opt)
        print(f"Episode {(i + 1) * multi_steps}:")
        print(f"    excess cost = {cost - opt_cost}")
        print(f"    elapsed = {time.time() - t0}")
        costs.append(float(cost))

    print(f"Opt solution cost from starting point: {opt_cost}")
    # print(f"Gradient at opt solution: {opt_g}")

    # Print the identified and optimal policy. Note that layers multiply
    # on the right instead of the left, so we need a transpose.
    print(f"Est solution parameters: {opt.value}")
    print(f"Opt solution parameters: {K.T}")

    est_all_costs = vmap(
        lambda x0: policy_integrate_cost(dynamics_fn, cost_fn, policy, gamma)
        (opt.value, x0, total_secs))(x0_eval)

    ### Scatter plot of learned policy performance vs optimal policy performance.
    plt.figure()
    plt.scatter(est_all_costs, opt_all_costs)
    plt.plot([-100, 100], [-100, 100], color="gray")
    plt.xlim(0, jp.max(est_all_costs))
    plt.ylim(0, jp.max(opt_all_costs))
    plt.xlabel("Learned policy cost")
    plt.ylabel("Optimal cost")
    plt.title("Performance relative to the direct LQR solution")

    ### Plot performance per iteration, incl. average optimal policy performance.
    plt.figure()
    plt.plot(costs)
    plt.axhline(opt_cost, linestyle="--", color="gray")
    plt.yscale("log")
    plt.xlabel("Iteration")
    plt.ylabel(f"Cost (T = {total_secs}s)")
    plt.legend(["Learned policy", "Direct LQR solution"])
    plt.title("ODE control of LQR problem")

    ### Example rollout plots (learned policy vs optimal policy).
    x0 = jp.array([1.0, 2.0])
    framerate = 30
    timesteps = jp.linspace(0, total_secs, num=int(total_secs * framerate))
    est_policy_rollout_states = ode.odeint(
        lambda x, _: dynamics_fn(x, policy(opt.value, x)), y0=x0, t=timesteps)
    est_policy_rollout_controls = vmap(lambda x: policy(opt.value, x))(
        est_policy_rollout_states)

    opt_policy_rollout_states = ode.odeint(lambda x, _: dynamics_fn(x, -K @ x),
                                           y0=x0,
                                           t=timesteps)
    opt_policy_rollout_controls = vmap(lambda x: -K @ x)(
        opt_policy_rollout_states)

    plt.figure()
    plt.plot(est_policy_rollout_states[:, 0],
             est_policy_rollout_states[:, 1],
             marker='.')
    plt.plot(opt_policy_rollout_states[:, 0],
             opt_policy_rollout_states[:, 1],
             marker='.')
    plt.xlabel("x_1")
    plt.ylabel("x_2")
    plt.legend(["Learned policy", "Direct LQR solution"])
    plt.title("Phase space trajectory")

    plt.figure()
    plt.plot(timesteps, jp.sqrt(jp.sum(est_policy_rollout_controls**2,
                                       axis=-1)))
    plt.plot(timesteps, jp.sqrt(jp.sum(opt_policy_rollout_controls**2,
                                       axis=-1)))
    plt.xlabel("time")
    plt.ylabel("control input (L2 norm)")
    plt.legend(["Learned policy", "Direct LQR solution"])
    plt.title("Policy control over time")

    ### Plot quiver field showing dynamics under learned policy.
    plot_policy_dynamics(dynamics_fn, cost_fn, lambda x: policy(opt.value, x))

    plt.show()
Example #3
File: util.py Project: lumip/numpyro
def multinomial(key, p, n, shape=()):
    n_max = int(jnp.max(n))
    return _multinomial(key, p, n, n_max, shape)
Example #4
    def train_epoch(self, evaluate=True):
        """Train one PPO epoch."""
        epoch_start_time = time.time()

        # Evaluate the policy.
        policy_eval_start_time = time.time()
        if evaluate and (self._epoch + 1) % self._eval_every_n == 0:
            self._rng, key = jax_random.split(self._rng, num=2)
            self.evaluate()

        policy_eval_time = ppo.get_time(policy_eval_start_time)

        trajectory_collection_start_time = time.time()
        logging.vlog(1, "PPO epoch [% 6d]: collecting trajectories.",
                     self._epoch)
        self._rng, key = jax_random.split(self._rng)
        trajs, _, timing_info, self._model_state = self.collect_trajectories(
            train=True, temperature=1.0)
        trajs = [(t[0], t[1], t[2], t[4]) for t in trajs]
        self._should_reset = False
        trajectory_collection_time = ppo.get_time(
            trajectory_collection_start_time)

        logging.vlog(1, "Collecting trajectories took %0.2f msec.",
                     trajectory_collection_time)

        rewards = np.array([np.sum(traj[2]) for traj in trajs])
        avg_reward = np.mean(rewards)
        std_reward = np.std(rewards)
        max_reward = np.max(rewards)
        min_reward = np.min(rewards)

        self._train_sw.scalar("train/reward_mean_truncated",
                              avg_reward,
                              step=self._epoch)
        if evaluate and not self._separate_eval:
            metrics = {"raw": {1.0: {"mean": avg_reward, "std": std_reward}}}
            ppo.write_eval_reward_summaries(metrics, self._eval_sw,
                                            self._epoch)

        logging.vlog(1,
                     "Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s",
                     avg_reward, max_reward, min_reward,
                     [float(np.sum(traj[2])) for traj in trajs])

        logging.vlog(
            1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
            float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
            max(len(traj[0]) for traj in trajs),
            min(len(traj[0]) for traj in trajs))
        logging.vlog(2, "Trajectory Lengths: %s",
                     [len(traj[0]) for traj in trajs])

        padding_start_time = time.time()
        (_, reward_mask, padded_observations, padded_actions, padded_rewards,
         padded_infos) = ppo.pad_trajectories(trajs, boundary=self._boundary)
        padding_time = ppo.get_time(padding_start_time)

        logging.vlog(1, "Padding trajectories took %0.2f msec.",
                     ppo.get_time(padding_start_time))
        logging.vlog(1, "Padded Observations' shape [%s]",
                     str(padded_observations.shape))
        logging.vlog(1, "Padded Actions' shape [%s]",
                     str(padded_actions.shape))
        logging.vlog(1, "Padded Rewards' shape [%s]",
                     str(padded_rewards.shape))

        if padded_actions.ndim == 2:
            # Add control axis.
            padded_actions = np.expand_dims(padded_actions, axis=-1)

        # Some assertions.
        B, T, C = padded_actions.shape  # pylint: disable=invalid-name
        assert (B, T) == padded_rewards.shape
        assert (B, T) == reward_mask.shape
        assert (B, T + 1) == padded_observations.shape[:2]
        assert ((B, T + 1) + self.train_env.observation_space.shape ==
                padded_observations.shape)

        log_prob_recompute_start_time = time.time()
        assert ("log_prob_actions" in padded_infos
                and "value_predictions" in padded_infos)
        # These are the actual log-probabs and value predictions seen while picking
        # the actions.
        actual_log_probabs_traj = padded_infos["log_prob_actions"]
        actual_value_predictions_traj = padded_infos["value_predictions"]

        assert (B, T, C) == actual_log_probabs_traj.shape[:3]
        A = actual_log_probabs_traj.shape[3]  # pylint: disable=invalid-name
        assert (B, T, 1) == actual_value_predictions_traj.shape

        # TODO(afrozm): log-probabs doesn't need to be (B, T+1, C, A) it can do with
        # (B, T, C, A), so make that change throughout.

        # NOTE: We don't have the log-probabs and value-predictions for the last
        # observation, so we re-calculate for everything, but use the original ones
        # for all but the last time-step.
        self._rng, key = jax_random.split(self._rng)

        log_probabs_traj, value_predictions_traj, self._model_state, _ = (
            self._get_predictions(padded_observations,
                                  self._model_state,
                                  rng=key))

        assert (B, T + 1, C, A) == log_probabs_traj.shape
        assert (B, T + 1, 1) == value_predictions_traj.shape

        # Concatenate the last time-step's log-probabs and value predictions to the
        # actual log-probabs and value predictions and use those going forward.
        log_probabs_traj = np.concatenate(
            (actual_log_probabs_traj, log_probabs_traj[:, -1:, :]), axis=1)
        value_predictions_traj = np.concatenate(
            (actual_value_predictions_traj, value_predictions_traj[:, -1:, :]),
            axis=1)

        log_prob_recompute_time = ppo.get_time(log_prob_recompute_start_time)

        # Compute value and ppo losses.
        self._rng, key1 = jax_random.split(self._rng, num=2)
        logging.vlog(2, "Starting to compute P&V loss.")
        loss_compute_start_time = time.time()
        (cur_combined_loss, component_losses, summaries,
         self._model_state) = (ppo.combined_loss(
             self._policy_and_value_net_params,
             log_probabs_traj,
             value_predictions_traj,
             self._policy_and_value_net_apply,
             padded_observations,
             padded_actions,
             padded_rewards,
             reward_mask,
             gamma=self._gamma,
             lambda_=self._lambda_,
             c1=self._c1,
             c2=self._c2,
             state=self._model_state,
             rng=key1))
        loss_compute_time = ppo.get_time(loss_compute_start_time)
        (cur_ppo_loss, cur_value_loss, cur_entropy_bonus) = component_losses
        logging.vlog(
            1,
            "Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.",
            cur_combined_loss, cur_ppo_loss, cur_value_loss, cur_entropy_bonus,
            ppo.get_time(loss_compute_start_time))

        self._rng, key1 = jax_random.split(self._rng, num=2)
        logging.vlog(1, "Policy and Value Optimization")
        optimization_start_time = time.time()
        keys = jax_random.split(key1, num=self._n_optimizer_steps)
        opt_step = 0
        for key in keys:
            k1, k2, k3 = jax_random.split(key, num=3)
            t = time.time()
            # Update the optimizer state.
            self._policy_and_value_opt_state, self._model_state = (
                ppo.policy_and_value_opt_step(
                    # We pass the optimizer slots between PPO epochs, so we need to
                    # pass the optimization step as well, so for example the
                    # bias-correction in Adam is calculated properly. Alternatively we
                    # could reset the slots and the step in every PPO epoch, but then
                    # the moment estimates in adaptive optimizers would never have
                    # enough time to warm up. So it makes sense to reuse the slots,
                    # even though we're optimizing a different loss in every new
                    # epoch.
                    self._total_opt_step,
                    self._policy_and_value_opt_state,
                    self._policy_and_value_opt_update,
                    self._policy_and_value_get_params,
                    self._policy_and_value_net_apply,
                    log_probabs_traj,
                    value_predictions_traj,
                    padded_observations,
                    padded_actions,
                    padded_rewards,
                    reward_mask,
                    c1=self._c1,
                    c2=self._c2,
                    gamma=self._gamma,
                    lambda_=self._lambda_,
                    state=self._model_state,
                    rng=k1))
            opt_step += 1
            self._total_opt_step += 1

            # Compute the approx KL for early stopping.
            (log_probab_actions_new,
             _), self._model_state = (self._policy_and_value_net_apply(
                 padded_observations,
                 self._policy_and_value_net_params,
                 self._model_state,
                 rng=k2))

            approx_kl = ppo.approximate_kl(log_probab_actions_new,
                                           log_probabs_traj, reward_mask)

            early_stopping = approx_kl > 1.5 * self._target_kl
            if early_stopping:
                logging.vlog(
                    1,
                    "Early stopping policy and value optimization after %d steps, "
                    "with approx_kl: %0.2f", opt_step, approx_kl)
                # We don't return right away; we want the code below to execute
                # on the last iteration.

            t2 = time.time()
            if (opt_step % self._print_every_optimizer_steps == 0
                    or opt_step == self._n_optimizer_steps or early_stopping):
                # Compute and log the loss.
                (combined_loss, component_losses, _,
                 self._model_state) = (ppo.combined_loss(
                     self._policy_and_value_net_params,
                     log_probabs_traj,
                     value_predictions_traj,
                     self._policy_and_value_net_apply,
                     padded_observations,
                     padded_actions,
                     padded_rewards,
                     reward_mask,
                     gamma=self._gamma,
                     lambda_=self._lambda_,
                     c1=self._c1,
                     c2=self._c2,
                     state=self._model_state,
                     rng=k3))
                logging.vlog(
                    1, "One Policy and Value grad desc took: %0.2f msec",
                    ppo.get_time(t, t2))
                (ppo_loss, value_loss, entropy_bonus) = component_losses
                logging.vlog(
                    1, "Combined Loss(value, ppo, entropy_bonus) [%10.2f] ->"
                    " [%10.2f(%10.2f,%10.2f,%10.2f)]", cur_combined_loss,
                    combined_loss, ppo_loss, value_loss, entropy_bonus)

            if early_stopping:
                break

        optimization_time = ppo.get_time(optimization_start_time)

        logging.vlog(
            1, "Total Combined Loss reduction [%0.2f]%%",
            (100 *
             (cur_combined_loss - combined_loss) / np.abs(cur_combined_loss)))

        summaries.update({
            "n_optimizer_steps": opt_step,
            "approx_kl": approx_kl,
        })
        for (name, value) in summaries.items():
            self._train_sw.scalar("train/{}".format(name),
                                  value,
                                  step=self._epoch)

        logging.info(
            "PPO epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined"
            " Loss(ppo, value, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]",
            self._epoch, min_reward, max_reward, avg_reward, combined_loss,
            ppo_loss, value_loss, entropy_bonus)

        # Bump the epoch counter before saving a checkpoint, so that a call to
        # save() after the training loop is a no-op if a checkpoint was saved last
        # epoch - otherwise it would bump the epoch counter on the checkpoint.
        last_epoch = self._epoch
        self._epoch += 1

        # Save parameters every time we see the end of at least a fraction of batch
        # number of trajectories that are done (not completed -- completed includes
        # truncated and done).
        # Also don't save too frequently, enforce a minimum gap.
        policy_save_start_time = time.time()
        # TODO(afrozm): Refactor to trax.save_state.
        if (self._n_trajectories_done >=
                self._done_frac_for_policy_save * self.train_env.batch_size
                and self._epoch % self._save_every_n == 0) or self._async_mode:
            self.save()
        policy_save_time = ppo.get_time(policy_save_start_time)

        epoch_time = ppo.get_time(epoch_start_time)

        timing_dict = {
            "epoch": epoch_time,
            "policy_eval": policy_eval_time,
            "trajectory_collection": trajectory_collection_time,
            "padding": padding_time,
            "log_prob_recompute": log_prob_recompute_time,
            "loss_compute": loss_compute_time,
            "optimization": optimization_time,
            "policy_save": policy_save_time,
        }

        timing_dict.update(timing_info)

        for k, v in timing_dict.items():
            self._timing_sw.scalar("timing/%s" % k, v, step=last_epoch)

        max_key_len = max(len(k) for k in timing_dict)
        timing_info_list = [
            "%s : % 10.2f" % (k.rjust(max_key_len + 1), v)
            for k, v in sorted(timing_dict.items())
        ]
        logging.info("PPO epoch [% 6d], Timings: \n%s", last_epoch,
                     "\n".join(timing_info_list))

        # Flush summary writers once in a while.
        if self._epoch % 1000 == 0:
            self.flush_summaries()
Example #5
def _estimate_cell_capacity(R, box_size, cell_size, buffer_size_multiplier):
    # TODO(schsam): We might want to do something more sophisticated here or at
    # least expose this constant.
    spatial_dim = R.shape[-1]
    cell_capacity = np.max(count_cell_filling(R, box_size, cell_size))
    return int(cell_capacity * buffer_size_multiplier)
Example #6
File: binomial.py Project: tblazina/mcx
def sample(self, rng_key, sample_shape=()):
    shape = sample_shape + self.batch_shape + self.event_shape
    n_max = jnp.max(self.n).item()
    return _random_binomial(rng_key, self.p, self.n, n_max, shape)
Example #7
def tree_max(tree):
    return np.max(tree_flatten(tree_map(lambda arr: np.max(arr), tree))[0])
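A self-contained usage sketch for tree_max (assuming np here is jax.numpy and that tree_flatten and tree_map come from jax.tree_util):

import jax.numpy as np
from jax.tree_util import tree_flatten, tree_map

params = {"w": np.array([[1.0, -3.0], [2.0, 0.5]]), "b": np.array([0.1, 4.0])}
largest = tree_max(params)   # 4.0 -- the largest entry across all leaves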
Example #8
File: smap.py Project: berkonat/jax-md
def pairwise(fn,
             metric,
             species=None,
             reduce_axis=None,
             keepdims=False,
             **kwargs):
    """Promotes a function that acts on a pair to one that acts on a set.

  Args:
    fn: A function that takes an ndarray of pairwise distances or displacements
      of shape [n, m] or [n, m, d_in] respectively as well as kwargs specifying
      parameters for the function. fn returns an ndarray of evaluations of shape
      [n, m, d_out].
    metric: A function that takes two ndarray of positions of shape
      [n, spatial_dimension] and [m, spatial_dimension] respectively and returns
      an ndarray of distances or displacements of shape [n, m, d_in]. The metric
      can optionally take a floating point time as a third argument.
    species: A list of species for the different particles. This should either
      be None (in which case it is assumed that all the particles have the same
      species), an integer ndarray of shape [n] with species data, or Dynamic
      in which case the species data will be specified dynamically. Note that
      dynamic species specification is less efficient, because we cannot
      specialize shape information.
    reduce_axis: A list of axes to reduce over. This is supplied to np.sum and
      so the same convention is used.
    keepdims: A boolean specifying whether the empty dimensions should be kept
      upon reduction. This is supplied to np.sum and so the same convention is
      used.
    kwargs: Arguments providing parameters to the mapped function. In cases
      where no species information is provided these should be either 1) a
      scalar, 2) an ndarray of shape [n], 3) an ndarray of shape [n, n]. If
      species information is provided then the parameters should be specified as
      either 1) a scalar or 2) an ndarray of shape [max_species, max_species].

  Returns:
    A function fn_mapped.

    If species is None or statically specified then fn_mapped takes as arguments
    an ndarray of positions of shape [n, spatial_dimension].

    If species is Dynamic then fn_mapped takes as input an ndarray of shape
    [n, spatial_dimension], an integer ndarray of species of shape [n], and an
    integer specifying the maximum species.

    The mapped function can also optionally take keyword arguments that get
    threaded through the metric.
  """
    if species is None:
        kwargs = _kwargs_to_parameters(species, **kwargs)

        def fn_mapped(R, **dynamic_kwargs):
            dr = metric(R, R, **dynamic_kwargs)
            # NOTE(schsam): Currently we place a diagonal mask no matter what function
            # we are mapping. Should this be an option?
            return _high_precision_sum(_diagonal_mask(fn(dr, **kwargs)),
                                       axis=reduce_axis,
                                       keepdims=keepdims) * f32(0.5)
    elif isinstance(species, np.ndarray):
        _check_species_dtype(species)
        species_count = int(np.max(species))
        if reduce_axis is not None or keepdims:
            # TODO(schsam): Support reduce_axis with static species.
            raise ValueError

        def fn_mapped(R, **dynamic_kwargs):
            U = f32(0.0)
            for i in range(species_count + 1):
                for j in range(i, species_count + 1):
                    s_kwargs = _kwargs_to_parameters((i, j), **kwargs)
                    Ra = R[species == i]
                    Rb = R[species == j]
                    dr = metric(Ra, Rb, **dynamic_kwargs)
                    if j == i:
                        dU = _high_precision_sum(
                            _diagonal_mask(fn(dr, **s_kwargs)))
                        U = U + f32(0.5) * dU
                    else:
                        dU = _high_precision_sum(fn(dr, **s_kwargs))
                        U = U + dU
            return U
    elif species is quantity.Dynamic:

        def fn_mapped(R, species, species_count, **dynamic_kwargs):
            _check_species_dtype(species)
            U = f32(0.0)
            N = R.shape[0]
            dr = metric(R, R, **dynamic_kwargs)
            for i in range(species_count):
                for j in range(species_count):
                    s_kwargs = _kwargs_to_parameters((i, j), **kwargs)
                    mask_a = np.array(np.reshape(species == i, (N, )),
                                      dtype=R.dtype)
                    mask_b = np.array(np.reshape(species == j, (N, )),
                                      dtype=R.dtype)
                    mask = mask_a[:, np.newaxis] * mask_b[np.newaxis, :]
                    if i == j:
                        mask = mask * _diagonal_mask(mask)
                    dU = mask * fn(dr, **s_kwargs)
                    U = U + _high_precision_sum(
                        dU, axis=reduce_axis, keepdims=keepdims)
            return U / f32(2.0)
    else:
        raise ValueError(
            'Species must be None, an ndarray, or Dynamic. Found {}.'.format(
                species))
    return fn_mapped
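For intuition, the species is None branch above computes (up to the module's precision and masking helpers) the same quantity as this self-contained sketch, assuming a plain Euclidean metric:

import jax.numpy as np

def naive_pairwise_energy(pair_fn, R):
    # dr[i, j] = Euclidean distance between particles i and j
    dr = np.sqrt(np.sum((R[:, None, :] - R[None, :, :]) ** 2, axis=-1))
    energies = pair_fn(dr)
    # zero the self-interaction terms (the role of _diagonal_mask) and halve
    # the total, since every pair (i, j) is counted twice
    energies = energies * (1.0 - np.eye(R.shape[0]))
    return 0.5 * np.sum(energies)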
Example #9
File: smap.py Project: berkonat/jax-md
def grid(fn,
         box_size,
         minimum_cell_size,
         cell_capacity_or_example_positions,
         species=None,
         separate_build_and_apply=False,
         cells_per_iter=-1):
    r"""Returns a function that evaluates a function sparsely on a grid.

  Suppose f is a function of positions, f:R^{N\times D}\to R^{N\times M} such
  that f does not depend on particle pairs that are separated by at least a
  cutoff \sigma. It is efficient to compute f by evaluating it separately over
  cells of a spatial partition of the system into a grid whose side-length
  is given by \sigma. This function does this spatially partitioned evaluation
  for a wide range of functions fn.

  This is accomplished by composing two functions. First a function build_cells
  creates the spatial partition, then a function compute applies fn to each cell
  using JAX's autobatching and then copies the result back to an ndarray of
  shape [particle_count, output_dimension]. We also support the option to
  return the two functions separately.

  The grid is constructed so that each cell contains not only those particles
  in the given cell, but also those particles in a "halo" around the cell.
  Currently, we let the halo size be the same as the grid size so that each grid
  cell contains particles from neighboring cells. This is for easy grid
  construction, but future optimization might be to allow for different halo
  sizes.

  Since XLA requires that shapes be statically specified, we allocate a fixed
  sized buffer for each cell. The size of this buffer can either be specified
  manually or it can be estimated automatically from a set of positions. Note,
  if the structure of a system is changing significantly during the course of
  dynamics (e.g. during minimization) it is probably worth adjusting the buffer
  size over the course of the dynamics.

  Currently, the function must have a signature fn(R, species, species_count).
  It would be nice in the future to support functions with more general
  signature. TODO.

  This partitioning will likely form the groundwork for parallelizing
  simulations over different accelerators.

  Args:
    fn: A function that we would like to compute over the partition. Should take
      arguments R (an ndarray of floating point positions of shape
      [particle_count, spatial_dimension]), species (an ndarray of integer
      species of shape [particle_count], species count (an integer specifying
      the maximum species number). fn should return an ndarray of shape
      [particle_count, output_dimension] (where output_dimension can be 1, but
      should be present).
    box_size: A float or an ndarray of shape [spatial_dimension] specifying the
      size of the system. Note, this code is written for the case where the
      boundaries are periodic. If this is not the case, then the current code
      will be slightly less efficient.
    minimum_cell_size: A float specifying the minimum side length of each cell.
    cell_capacity_or_example_positions: Either an integer specifying the
      number of particles that can be stored in each cell or an ndarray of
      positions of shape [particle_count, spatial_dimension] that is used to
      estimate the cell_capacity.
    species: Either an ndarray of integers of shape [particle_count] with the
      species type of each particle or None, in which case it is assumed that
      all particles have the same species.
    separate_build_and_apply: A boolean specifying whether or not we would like
      to compose the build_cells and compute functions.
    cells_per_iter: Depending on the size of the system, it might be necessary
      to apply fn over batches of cells. cells_per_iter is an integer specifying
      the number of cells per batch. If cells_per_iter is -1 then all cells are
      computed together.

  Returns:
    If separate_build_and_apply is False then returns a single function
    fn_mapped that takes an ndarray of shape [particle_count, spatial_dimension]
    as well as optional kwargs and returns an ndarray of shape
    [particle_count, output_dimension].

    If separate_build_and_apply is True then returns two functions. A
    build_cells function that takes an ndarray of positions of shape
    [particle_count, spatial_dimension] and returns a Grid. It also returns a
    function compute that takes a Grid and computes fn over the grid.
  """
    fn = vmap(fn, (
        0,
        0,
    ), 0)

    if species is None:
        species_count = 1
    else:
        species_count = int(np.max(species) + 1)

    cell_capacity = cell_capacity_or_example_positions
    if _is_variable_compatible_with_positions(cell_capacity):
        cell_capacity = _estimate_cell_capacity(cell_capacity, box_size,
                                                minimum_cell_size)
    elif not isinstance(cell_capacity, int):
        msg = (
            'cell_capacity_or_example_positions must either be an integer '
            'specifying the cell capacity or a set of positions that will be used '
            'to estimate a cell capacity. Found {}.'.format(
                type(cell_capacity)))
        raise ValueError(msg)

    def build_cells(R):
        N = R.shape[0]
        dim = R.shape[1]
        neighborhood_tile_count = 3**dim

        _, cell_size, cells_per_side, cell_count = \
            _cell_dimensions(dim, box_size, minimum_cell_size)

        if species is None:
            _species = np.zeros((N, ), dtype=i32)
        else:
            _species = species

        hash_multipliers = _compute_hash_constants(dim, cells_per_side)

        # Create grid data.
        particle_id = lax.iota(np.int64, N)
        # NOTE(schsam): We use the convention that particles that come from the
        # center cell have their true id copied, whereas particles that come from
        # the halo have an id = N. Then when we copy data back from the grid,
        # we copy it to an array of shape [N + 1, output_dimension] and then
        # truncate it to an array of shape [N, output_dimension] which ignores the
        # halo particles.
        mask_id = np.ones((N, ), np.int64) * N
        cell_R = np.zeros((cell_count * cell_capacity, dim), dtype=R.dtype)
        # NOTE(schsam): empty_species_index is just supposed to be large enough that
        # we will never run into it. However, there might be a more robust way to do
        # this.
        empty_species_index = i16(1000)
        cell_species = empty_species_index * np.ones(
            (cell_count * cell_capacity, 1), dtype=_species.dtype)
        cell_id = N * np.ones((cell_count * cell_capacity, 1), dtype=i32)

        indices = np.array(R / cell_size, dtype=i32)

        # Create a copy of particle data for each neighboring cell shifting the hash
        # appropriately.
        # TODO(schsam): Replace with np.tile() when it gets implemented.
        tiled_R = R
        tiled_species = _species
        for _ in range(neighborhood_tile_count - 1):
            tiled_R = np.concatenate((tiled_R, R), axis=0)
            tiled_species = np.concatenate((tiled_species, _species), axis=0)
        tiled_hash = np.array([], dtype=i32)
        tiled_id = np.array([], dtype=i32)

        for dindex in _neighboring_cells(dim):
            tiled_indices = np.mod(indices + dindex, cells_per_side)
            tiled_hash = np.concatenate(
                (tiled_hash, np.sum(tiled_indices * hash_multipliers, axis=1)),
                axis=0)

            if np.all(dindex == 0):
                tiled_id = np.concatenate((tiled_id, particle_id), axis=0)
            else:
                tiled_id = np.concatenate((tiled_id, mask_id), axis=0)

        # Copy the particle data into the grid. Here we use a trick to allow us to
        # copy into all cells simultaneously using a single lax.scatter call. To do
        # this we first sort particles by their cell hash. We then assign each
        # particle to have a cell id = hash * cell_capacity + grid_id where grid_id
        # is a flat list that repeats 0, ..., cell_capacity - 1. So long as there are
        # fewer than cell_capacity particles per cell, each particle is guaranteed
        # to get a cell id that is unique.
        sort_map = np.argsort(tiled_hash)
        sorted_R = tiled_R[sort_map]
        sorted_species = tiled_species[sort_map]
        sorted_hash = tiled_hash[sort_map]
        sorted_id = tiled_id[sort_map]

        tiled_size = neighborhood_tile_count * N
        sorted_cell_id = np.mod(lax.iota(np.int64, tiled_size), cell_capacity)
        sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

        def copy_values_to_cell(cell_value, value, ids):
            scatter_indices = np.reshape(ids, (tiled_size, 1))
            dnums = lax.ScatterDimensionNumbers(
                update_window_dims=tuple([1]),
                inserted_window_dims=tuple([0]),
                scatter_dims_to_operand_dims=tuple([0]),
            )
            return lax.scatter(cell_value, scatter_indices, value, dnums)

        cell_R = copy_values_to_cell(cell_R, sorted_R, sorted_cell_id)
        sorted_species = np.reshape(sorted_species, (tiled_size, 1))
        cell_species = copy_values_to_cell(cell_species, sorted_species,
                                           sorted_cell_id)
        sorted_id = np.reshape(sorted_id, (tiled_size, 1))
        cell_id = copy_values_to_cell(cell_id, sorted_id, sorted_cell_id)

        cell_R = np.reshape(cell_R, (cell_count, cell_capacity, dim))
        cell_species = np.reshape(cell_species, (cell_count, cell_capacity))
        cell_id = np.reshape(cell_id, (cell_count, cell_capacity))

        return Grid(N, dim, cell_count, cell_R, cell_species, cell_id)

    def compute(cell_data, **kwargs):
        N, dim, cell_count, cell_R, cell_species, cell_id, = cell_data

        cell_output_shape = _grid_trace_shape(fn,
                                              cell_R,
                                              cell_species,
                                              species_count=species_count)
        output_dimension = cell_output_shape[-1]

        _cells_per_iter = cells_per_iter
        if cells_per_iter == -1:
            _cells_per_iter = cell_count

        def copy_values_from_cell(value, cell_value, cell_id):
            scatter_indices = np.reshape(cell_id,
                                         (_cells_per_iter * cell_capacity, 1))
            cell_value = np.reshape(
                cell_value,
                (_cells_per_iter * cell_capacity, output_dimension))
            dnums = lax.ScatterDimensionNumbers(
                update_window_dims=tuple([1]),
                inserted_window_dims=tuple([0]),
                scatter_dims_to_operand_dims=tuple([0]),
            )
            return lax.scatter(value, scatter_indices, cell_value, dnums)

        def compute_cell_block(start, value):
            start = _cells_per_iter * start

            compute_R = lax.dynamic_slice(
                cell_R, (start, 0, 0), (_cells_per_iter, cell_capacity, dim))
            compute_species = lax.dynamic_slice(
                cell_species, (start, 0), (_cells_per_iter, cell_capacity))
            compute_id = lax.dynamic_slice(cell_id, (start, 0),
                                           (_cells_per_iter, cell_capacity))

            cell_value = fn(compute_R,
                            compute_species,
                            species_count=species_count,
                            **kwargs)
            return copy_values_from_cell(value, cell_value, compute_id)

        value = np.zeros((N + 1, output_dimension), dtype=cell_R.dtype)
        if cells_per_iter > 0:
            return lax.fori_loop(
                0, int(math.ceil(float(cell_count) / cells_per_iter)),
                compute_cell_block, value)[:N]
        else:
            return compute_cell_block(0, value)[:N]

    if separate_build_and_apply:
        return build_cells, compute
    else:
        return lambda R, **kwargs: compute(build_cells(R), **kwargs)
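The scatter trick described in build_cells can be checked in isolation: after sorting particles by cell hash, particles sharing a hash are contiguous, so hash * cell_capacity + (index mod cell_capacity) yields a unique flat id as long as no cell holds more than cell_capacity particles (a standalone sketch, not the jax-md API):

import jax.numpy as np

cell_capacity = 4
hashes = np.array([2, 0, 2, 1, 0, 2])               # cell hash per tiled particle
sorted_hash = hashes[np.argsort(hashes)]
slots = np.arange(hashes.shape[0]) % cell_capacity  # repeats 0 .. cell_capacity - 1
cell_ids = sorted_hash * cell_capacity + slots      # [0, 1, 6, 11, 8, 9] -- all distinct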
Example #10
## NONLINEAR LEAST-SQUARES CLASS *****************************************************************************
nlls = NllsClass(xi, L, IC, tol=tol, maxIter=maxIter, timer=True)

data = pickle.load(open('data/EOL_IC.pickle', 'rb'))
sol = {
    'loss': onp.zeros((data['R0'].shape[0])),
    'it': onp.zeros((data['R0'].shape[0])),
    'time': onp.zeros((data['R0'].shape[0]))
}
## RUN TEST *************************************************************************************************
for i in tqdm.trange(data['R0'].shape[0]):
    R0 = data['R0'][i, :]
    V0 = data['V0'][i, :]

    ## scale initial conditions
    pscale = np.max(np.abs(R0))
    tscale = pscale / np.max(np.abs(V0))

    xi = TFCDictRobust({'xis':onp.zeros((Hs(z).shape[1],3)),\
                        'xic':onp.array([0.5 * (V0 - R0), -0.5*(V0 + R0)]),\
                        'b':np.sqrt(10.)*onp.ones(1)})

    IC['R0'] = R0 / pscale
    IC['V0'] = V0 * tscale / pscale
    IC['ag'] = np.array([0., 0., -1.62]) * tscale**2 / pscale

    xi, it, time = nlls.run(xi, IC)

    sol['loss'][i] = np.max(np.abs(L(xi, IC)))
    sol['it'][i] = it
    sol['time'][i] = time
def loss_fn(
    output_logits,
    targets,
    valid_mask,
    num_nodes,
    captured,
    negative_example_weight=1,
    focal_loss_gamma=0.0,
):
    """Compute loss and single-batch metrics for some outputs.

  Args:
    output_logits: Binary logits produced by the model.
    targets: Model targets.
    valid_mask: Mask determining which outputs are valid.
    num_nodes: How many nodes there are in each example.
    captured: Ignored
    negative_example_weight: Weight to assign to a negative example when
      computing the loss. Positive examples always get weight 1.
    focal_loss_gamma: Focusing parameter for the focal loss, as described in Lin
      et al. (2018). If zero, uses standard cross-entropy loss.

  Returns:
    Tuple (loss, metrics_dict).
  """
    del captured
    num_targets = jnp.count_nonzero(targets)
    # Compute cross entropy.
    unmasked_nll = model_util.binary_logit_cross_entropy(
        output_logits, targets)
    if focal_loss_gamma:
        # (1-p_correct)**gamma = (-(p-1))**gamma = (-expm1(log(p)))**gamma
        focus_term = jnp.power(-jnp.expm1(-unmasked_nll), focal_loss_gamma)
        unmasked_nll = unmasked_nll * focus_term
    # Mask the results so that they only count nodes that exist.
    masked_nll = unmasked_nll * valid_mask
    # Primary loss: Sum of nll over all nodes. We use sum because most of the
    # edges are easy negatives.
    positive_nll = jnp.sum(
        jnp.where(targets, masked_nll, jnp.zeros_like(masked_nll)))
    negative_nll = jnp.sum(
        jnp.where(targets, jnp.zeros_like(masked_nll), masked_nll))
    reweighted_nll = positive_nll + negative_example_weight * negative_nll
    binary_nll = jnp.sum(reweighted_nll)
    # Compute additional metrics to track learning progress.
    # Average NLL of target edges:
    avg_nll_per_target = positive_nll / num_targets
    # Average NLL of non-target edges:
    num_non_targets = num_nodes**2 - num_targets
    avg_nll_per_non_target = negative_nll / num_non_targets
    # Max error for any edge prediction:
    worst_nll = jnp.max(masked_nll)

    loss = binary_nll

    # Ratio of positive to negative targets. If this is equal to
    # negative_example_weight, the positive and negative examples will have the
    # same total weight.
    positive_per_negative = num_targets / num_non_targets
    # Precision and recall at 0.1 threshold
    thresholded_preds = output_logits > jax.scipy.special.logit(0.1)
    count_target_pred = jnp.count_nonzero(thresholded_preds & targets)
    count_pred = jnp.count_nonzero(thresholded_preds & valid_mask.astype(bool))
    precision = count_target_pred / count_pred
    recall = count_target_pred / num_targets
    return loss, {
        "avg_per_target":
        avg_nll_per_target,
        "avg_per_non_target":
        avg_nll_per_non_target,
        "worst":
        worst_nll,
        "positive_per_negative":
        positive_per_negative,
        "effective_p_model_given_target":
        jnp.exp(-avg_nll_per_target),
        "effective_p_model_given_nontarget":
        1 - jnp.exp(-avg_nll_per_non_target),
        "batch_clf_thresh_at_0.1/precision":
        precision,
        "batch_clf_thresh_at_0.1/recall":
        recall,
        "batch_clf_thresh_at_0.1/f1":
        2 * (precision * recall) / (precision + recall),
    }
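The focal-loss weighting used above can be checked in isolation. Since unmasked_nll = -log p_correct, the focus term (-expm1(-nll))**gamma equals (1 - p_correct)**gamma, which down-weights confidently correct predictions (a self-contained sketch; binary_logit_cross_entropy is written out explicitly here because model_util is not shown):

import jax
import jax.numpy as jnp

def binary_logit_nll(logits, targets):
    # -log p(target | logit) for independent binary targets
    return -jnp.where(targets, jax.nn.log_sigmoid(logits),
                      jax.nn.log_sigmoid(-logits))

logits = jnp.array([3.0, -1.0])
targets = jnp.array([True, True])
nll = binary_logit_nll(logits, targets)
focal_nll = jnp.power(-jnp.expm1(-nll), 2.0) * nll
# The confident correct prediction (logit 3.0) is scaled down by ~(0.05)**2,
# while the misclassified one (logit -1.0) keeps roughly half of its loss.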
Example #12
def pytree_func(params, x):
    return jnp.max(jnp.matmul(x, params["w"]) + params["b"], 0)
Example #13
def maxlike(y,
            x,
            model,
            params0,
            batch_size=4092,
            epochs=3,
            learning_rate=0.5,
            step=1e-4,
            output=None):
    # compute derivatives
    fg0_fun = jax.value_and_grad(model)
    g0_fun = jax.grad(model)
    h0_fun = jax.hessian(model)

    # generate functions
    fg_fun = jax.jit(fg0_fun)
    g_fun = jax.jit(g0_fun)
    h_fun = jax.jit(h0_fun)

    # construct dataset
    N, K = len(y), len(params0)
    data = DataLoader(y, x, batch_size)

    # initialize params
    params = params0.copy()

    # do training
    for ep in range(epochs):
        # epoch stats
        agg_loss, agg_batch = 0.0, 0

        # iterate over batches
        for y_bat, x_bat in data:
            # compute gradients
            loss, diff = fg_fun(params, y_bat, x_bat)

            # compute gradient-descent update (use a separate name so we don't
            # clobber the finite-difference step size `step` used below)
            update = -learning_rate * diff
            params += update

            # error
            gain = np.dot(update, diff)
            move = np.max(np.abs(gain))

            # compute statistics
            agg_loss += loss
            agg_batch += 1

        # display stats
        avg_loss = agg_loss / agg_batch
        print(f'{ep:3}: loss = {avg_loss}')

    # return to device
    if output == 'beta':
        return params.copy(), None

    try:
        # get hessian matrix
        hess = np.zeros((K, K))
        for y_bat, x_bat in data:
            hess += h_fun(params, y_bat, x_bat)
        hess *= batch_size / N
    except Exception as e:
        # our gods have failed us
        print(e)  # source of error
        print('Falling back to finite difference for hessian')
        hess_rows = [np.zeros(K) for i in range(K)]
        diff = step * np.eye(K)
        for y_bat, x_bat in data:
            g0_batch = g_fun(params, y_bat, x_bat)[None, :]
            for i in range(K):
                params1 = params + diff[i, :]
                hess_rows[i] += g_fun(params1, y_bat, x_bat) - g0_batch
        hess = np.vstack(hess_rows) * (batch_size / N) / step

    # get cov matrix
    sigma = np.linalg.inv(hess) / N

    # return all
    return params.copy(), sigma.copy()
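A minimal call sketch for maxlike (a hypothetical Gaussian linear model; assumes np above is jax.numpy and that the module's DataLoader takes (y, x, batch_size) as used in the function body):

import jax.numpy as np

def normal_model(params, y, x):
    # average negative log-likelihood of y ~ Normal(x @ params, 1), up to a constant
    resid = y - x @ params
    return 0.5 * np.mean(resid ** 2)

# params_hat, sigma = maxlike(y, x, normal_model, params0=np.zeros(x.shape[1]))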
def makeInputs(OMap,
               r_cent,
               contrasts,
               X,
               Y,
               gridsizedeg=4,
               gridperdeg=5,
               AngWidth=32):
    '''
    makes the input arrays for the various stimulus conditions
    all radii at the highest contrast - to test Surround Suppression
    all contrasts at the highest radius - to test contrast effect
    highest contrast and radius with a Gabor filter - to test Ray-Maunsell Effect
    
    OMap = orientation preference across the cortex
    r_cent = array of stimulus radii
    contrasts = array of stimulus contrasts
    X,Y = matrices of distances in X and Y in degrees
    various parameters of the network
    
    Outputs
    StimConds = array of dim Ne x stimCondition (the name is short for Stimulus Conditions)
    stimCondition = [max radii * varying contrasts, max contrast * vary radii, Gabor]
    
    '''
    rads = np.hstack(
        (np.max(r_cent) * np.ones(len(contrasts) - 1), r_cent
         ))  # cause I don't want to double up the Contrast = 100 condition
    Contrasts = np.hstack(
        (contrasts, np.ones(len(r_cent)) * np.max(contrasts))
    )  # need to add one for Gabor condition, but I would subtract one to not double up the C= 100 R= max condition

    gridsize = OMap.shape
    dx = gridsizedeg / gridsize[0]  # dx is degrees between neurons

    Mid1 = int(np.floor(gridsize[0] / 2))
    Mid2 = int(np.floor(gridsize[1] / 2))

    # Python does linear indexing weird, just going to use the found midpts
    # trgt = onp.ravel_multi_index((Mid1, Mid2), (Len[0], Len[1]))

    Orientation = OMap[Mid1, Mid2]

    dOri = np.abs(OMap - Orientation)
    dOri = np.where(dOri > 90, 180 - dOri, dOri)
    In0 = np.ravel(np.exp(-dOri**2 / (2 * AngWidth**2)))

    RFdecay = 0.8 / 2  #biologic decay is 0.8 mm, magfactor =2 mm/deg
    RFdecay = RFdecay / 10  #trying to find good SI, this parameter has a large impact on the suppression curve
    RFdecay = 0.04
    #RFdecay = dx
    #GaborSigma = 0.3*np.max(r_cent)
    GaborSigma = 0.5

    x0 = X[Mid1, Mid2]
    y0 = Y[Mid1, Mid2]

    x_space = X - x0
    y_space = Y - y0

    # find the distances across the cortex
    r_space = np.ravel(np.sqrt(x_space**2 + y_space**2))

    #find the spatial input for a constant grating
    InSr = (1 - (1 / (1 + np.exp(-(r_space - rads[:, None]) / RFdecay))))
    #find the spatial input for a Gabor
    InGabor = np.exp(-r_space**2 / 2 / GaborSigma**2)
    #include the contrasts with it
    if len(contrasts) > 1:
        StimConds = Contrasts[:, None] * np.vstack((InSr, InGabor))
    else:
        StimConds = Contrasts[:, None] * InSr
    StimConds = StimConds * In0
    #include the relative drive between E and I cells  -- nixing this cause gE and gI are parameters
    #InSpace = np.hstack( (StimConds, gI*StimConds)).T #.T makes it neurons by stimcond

    #array to reference to find max contrasts, etc
    stimulus_condition = np.vstack((Contrasts, np.hstack(
        (rads, np.max(rads)))))

    return StimConds.T, stimulus_condition, InSr
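A hypothetical call sketch for makeInputs (a fabricated orientation map and coordinate grids, just to show the expected shapes; assumes np is jax.numpy):

import jax.numpy as np

gridsizedeg, gridperdeg = 4, 5
n_side = gridsizedeg * gridperdeg
OMap = np.linspace(0.0, 180.0, n_side * n_side).reshape(n_side, n_side)
xs = np.linspace(0.0, gridsizedeg, n_side)
X, Y = np.meshgrid(xs, xs)
r_cent = np.array([0.25, 0.5, 1.0, 2.0])
contrasts = np.array([25.0, 50.0, 100.0])
StimConds, stim_cond, InSr = makeInputs(OMap, r_cent, contrasts, X, Y)
# StimConds: (n_side**2, n_conditions); stim_cond rows give contrast and radius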
Example #15
def cond(val):
    return np.all(np.array([
                np.max(np.abs(L(z,val['xi'],val['xp']))) > tol,
                val['it'] < 30,
                np.max(np.abs(val['dxi'])) > tol]))
Example #16
    def train_epoch(self, evaluate=True):
        """Train one PPO epoch."""
        epoch_start_time = time.time()

        # Evaluate the policy.
        policy_eval_start_time = time.time()
        if evaluate and (self._epoch + 1) % self._eval_every_n == 0:
            self._rng, key = jax_random.split(self._rng, num=2)
            self.evaluate()

        policy_eval_time = ppo.get_time(policy_eval_start_time)

        trajectory_collection_start_time = time.time()
        logging.vlog(1, 'PPO epoch [% 6d]: collecting trajectories.',
                     self._epoch)
        self._rng, key = jax_random.split(self._rng)
        trajs, _, timing_info, self._model_state = self.collect_trajectories(
            train=True, temperature=1.0)
        trajs = [(t[0], t[1], t[2], t[4]) for t in trajs]
        self._should_reset = False
        trajectory_collection_time = ppo.get_time(
            trajectory_collection_start_time)

        logging.vlog(1, 'Collecting trajectories took %0.2f msec.',
                     trajectory_collection_time)

        rewards = np.array([np.sum(traj[2]) for traj in trajs])
        avg_reward = np.mean(rewards)
        std_reward = np.std(rewards)
        max_reward = np.max(rewards)
        min_reward = np.min(rewards)

        self._log('train', 'train/reward_mean_truncated', avg_reward)
        if evaluate and not self._separate_eval:
            metrics = {'raw': {1.0: {'mean': avg_reward, 'std': std_reward}}}
            ppo.write_eval_reward_summaries(metrics, self._log, self._epoch)

        logging.vlog(1,
                     'Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s',
                     avg_reward, max_reward, min_reward,
                     [float(np.sum(traj[2])) for traj in trajs])

        logging.vlog(
            1, 'Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]',
            float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
            max(len(traj[0]) for traj in trajs),
            min(len(traj[0]) for traj in trajs))
        logging.vlog(2, 'Trajectory Lengths: %s',
                     [len(traj[0]) for traj in trajs])

        preprocessing_start_time = time.time()
        (padded_observations, padded_actions, padded_rewards, reward_mask,
         padded_infos) = self._preprocess_trajectories(trajs)
        preprocessing_time = ppo.get_time(preprocessing_start_time)

        logging.vlog(1, 'Preprocessing trajectories took %0.2f msec.',
                     ppo.get_time(preprocessing_start_time))
        logging.vlog(1, 'Padded Observations\' shape [%s]',
                     str(padded_observations.shape))
        logging.vlog(1, 'Padded Actions\' shape [%s]',
                     str(padded_actions.shape))
        logging.vlog(1, 'Padded Rewards\' shape [%s]',
                     str(padded_rewards.shape))

        # Some assertions.
        B, RT = padded_rewards.shape  # pylint: disable=invalid-name
        B, AT = padded_actions.shape  # pylint: disable=invalid-name
        assert (B, RT) == reward_mask.shape
        assert B == padded_observations.shape[0]

        log_prob_recompute_start_time = time.time()
        # TODO(pkozakowski): The following commented out code collects the network
        # predictions made while stepping the environment and uses them in PPO
        # training, so that we can use non-deterministic networks (e.g. with
        # dropout). This does not work well with serialization, so instead we
        # recompute all network predictions. Let's figure out a solution that will
        # work with both serialized sequences and non-deterministic networks.

        # assert ('log_prob_actions' in padded_infos and
        #         'value_predictions' in padded_infos)
        # These are the actual log-probabs and value predictions seen while picking
        # the actions.
        # actual_log_probabs_traj = padded_infos['log_prob_actions']
        # actual_value_predictions_traj = padded_infos['value_predictions']

        # assert (B, T, C) == actual_log_probabs_traj.shape[:3]
        # A = actual_log_probabs_traj.shape[3]  # pylint: disable=invalid-name
        # assert (B, T, 1) == actual_value_predictions_traj.shape

        del padded_infos

        # TODO(afrozm): log-probabs doesn't need to be (B, T+1, C, A) it can do with
        # (B, T, C, A), so make that change throughout.

        # NOTE: We don't have the log-probabs and value-predictions for the last
        # observation, so we re-calculate for everything, but use the original ones
        # for all but the last time-step.
        self._rng, key = jax_random.split(self._rng)

        (log_probabs_traj,
         value_predictions_traj) = (self._policy_and_value_net_apply(
             padded_observations,
             weights=self._policy_and_value_net_weights,
             state=self._model_state,
             rng=key,
         ))

        assert (B, AT) == log_probabs_traj.shape[:2]
        assert (B, AT) == value_predictions_traj.shape

        # TODO(pkozakowski): Commented out for the same reason as before.

        # Concatenate the last time-step's log-probabs and value predictions to the
        # actual log-probabs and value predictions and use those going forward.
        # log_probabs_traj = np.concatenate(
        #     (actual_log_probabs_traj, log_probabs_traj[:, -1:, :]), axis=1)
        # value_predictions_traj = np.concatenate(
        #     (actual_value_predictions_traj, value_predictions_traj[:, -1:, :]),
        #     axis=1)

        log_prob_recompute_time = ppo.get_time(log_prob_recompute_start_time)

        # Compute value and ppo losses.
        self._rng, key1 = jax_random.split(self._rng, num=2)
        logging.vlog(2, 'Starting to compute P&V loss.')
        loss_compute_start_time = time.time()
        (cur_combined_loss, component_losses, summaries,
         self._model_state) = (ppo.combined_loss(
             self._policy_and_value_net_weights,
             log_probabs_traj,
             value_predictions_traj,
             self._policy_and_value_net_apply,
             padded_observations,
             padded_actions,
             self._rewards_to_actions,
             padded_rewards,
             reward_mask,
             nontrainable_params=self._nontrainable_params,
             state=self._model_state,
             rng=key1))
        loss_compute_time = ppo.get_time(loss_compute_start_time)
        (cur_ppo_loss, cur_value_loss, cur_entropy_bonus) = component_losses
        logging.vlog(
            1,
            'Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.',
            cur_combined_loss, cur_ppo_loss, cur_value_loss, cur_entropy_bonus,
            ppo.get_time(loss_compute_start_time))

        self._rng, key1 = jax_random.split(self._rng, num=2)
        logging.vlog(1, 'Policy and Value Optimization')
        optimization_start_time = time.time()
        keys = jax_random.split(key1, num=self._n_optimizer_steps)
        opt_step = 0
        opt_batch_size = min(self._optimizer_batch_size, B)
        index_batches = ppo.shuffled_index_batches(dataset_size=B,
                                                   batch_size=opt_batch_size)
        for (index_batch, key) in zip(index_batches, keys):
            k1, k2, k3 = jax_random.split(key, num=3)
            t = time.time()
            # Update the optimizer state on the sampled minibatch.
            self._policy_and_value_opt_state, self._model_state = (
                ppo.policy_and_value_opt_step(
                    # We pass the optimizer slots between PPO epochs, so we need to
                    # pass the optimization step as well, so for example the
                    # bias-correction in Adam is calculated properly. Alternatively we
                    # could reset the slots and the step in every PPO epoch, but then
                    # the moment estimates in adaptive optimizers would never have
                    # enough time to warm up. So it makes sense to reuse the slots,
                    # even though we're optimizing a different loss in every new
                    # epoch.
                    self._total_opt_step,
                    self._policy_and_value_opt_state,
                    self._policy_and_value_opt_update,
                    self._policy_and_value_get_params,
                    self._policy_and_value_net_apply,
                    log_probabs_traj[index_batch],
                    value_predictions_traj[index_batch],
                    padded_observations[index_batch],
                    padded_actions[index_batch],
                    self._rewards_to_actions,
                    padded_rewards[index_batch],
                    reward_mask[index_batch],
                    nontrainable_params=self._nontrainable_params,
                    state=self._model_state,
                    rng=k1))
            opt_step += 1
            self._total_opt_step += 1

            # Compute the approx KL for early stopping. Use the whole dataset - as we
            # only do inference, it should fit in memory.
            (log_probab_actions_new, _) = (self._policy_and_value_net_apply(
                padded_observations,
                weights=self._policy_and_value_net_weights,
                state=self._model_state,
                rng=k2))

            action_mask = np.dot(np.pad(reward_mask, ((0, 0), (0, 1))),
                                 self._rewards_to_actions)
            approx_kl = ppo.approximate_kl(log_probab_actions_new,
                                           log_probabs_traj, action_mask)

            early_stopping = approx_kl > 1.5 * self._target_kl
            if early_stopping:
                logging.vlog(
                    1,
                    'Early stopping policy and value optimization after %d steps, '
                    'with approx_kl: %0.2f', opt_step, approx_kl)
                # We don't return right away; we want the code below to execute on the
                # last iteration.

            t2 = time.time()
            if (opt_step % self._print_every_optimizer_steps == 0
                    or opt_step == self._n_optimizer_steps or early_stopping):
                # Compute and log the loss.
                (combined_loss, component_losses, _,
                 self._model_state) = (ppo.combined_loss(
                     self._policy_and_value_net_weights,
                     log_probabs_traj,
                     value_predictions_traj,
                     self._policy_and_value_net_apply,
                     padded_observations,
                     padded_actions,
                     self._rewards_to_actions,
                     padded_rewards,
                     reward_mask,
                     nontrainable_params=self._nontrainable_params,
                     state=self._model_state,
                     rng=k3))
                logging.vlog(
                    1, 'One Policy and Value grad desc took: %0.2f msec',
                    ppo.get_time(t, t2))
                (ppo_loss, value_loss, entropy_bonus) = component_losses
                logging.vlog(
                    1, 'Combined Loss(ppo, value, entropy_bonus) [%10.2f] ->'
                    ' [%10.2f(%10.2f,%10.2f,%10.2f)]', cur_combined_loss,
                    combined_loss, ppo_loss, value_loss, entropy_bonus)

            if early_stopping:
                break

        optimization_time = ppo.get_time(optimization_start_time)

        logging.vlog(
            1, 'Total Combined Loss reduction [%0.2f]%%',
            (100 *
             (cur_combined_loss - combined_loss) / np.abs(cur_combined_loss)))

        summaries.update({
            'n_optimizer_steps': opt_step,
            'approx_kl': approx_kl,
        })
        for (name, value) in summaries.items():
            self._log('train', 'train/{}'.format(name), value)

        logging.info(
            'PPO epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined'
            ' Loss(ppo, value, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]',
            self._epoch, min_reward, max_reward, avg_reward, combined_loss,
            ppo_loss, value_loss, entropy_bonus)

        # Bump the epoch counter before saving a checkpoint, so that a call to
        # save() after the training loop is a no-op if a checkpoint was saved last
        # epoch - otherwise it would bump the epoch counter on the checkpoint.
        last_epoch = self._epoch
        self._epoch += 1

        # Save parameters every time at least a fraction (done_frac_for_policy_save)
        # of a batch's worth of trajectories have ended in done (not merely
        # completed -- completed includes both truncated and done).
        # Also don't save too frequently; enforce a minimum gap.
        policy_save_start_time = time.time()
        # TODO(afrozm): Refactor to trax.save_trainer_state.
        if (self._n_trajectories_done >=
                self._done_frac_for_policy_save * self.train_env.batch_size
                and self._epoch % self._save_every_n == 0) or self._async_mode:
            self.save()
        policy_save_time = ppo.get_time(policy_save_start_time)

        epoch_time = ppo.get_time(epoch_start_time)

        timing_dict = {
            'epoch': epoch_time,
            'policy_eval': policy_eval_time,
            'trajectory_collection': trajectory_collection_time,
            'preprocessing': preprocessing_time,
            'log_prob_recompute': log_prob_recompute_time,
            'loss_compute': loss_compute_time,
            'optimization': optimization_time,
            'policy_save': policy_save_time,
        }

        timing_dict.update(timing_info)

        if self._should_write_summaries:
            for k, v in timing_dict.items():
                self._timing_sw.scalar('timing/%s' % k, v, step=last_epoch)

        max_key_len = max(len(k) for k in timing_dict)
        timing_info_list = [
            '%s : % 10.2f' % (k.rjust(max_key_len + 1), v)
            for k, v in sorted(timing_dict.items())
        ]
        logging.info('PPO epoch [% 6d], Timings: \n%s', last_epoch,
                     '\n'.join(timing_info_list))

        # Flush summary writers once in a while.
        if self._epoch % 1000 == 0:
            self.flush_summaries()
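As a side note on the early-stopping test above: ppo.approximate_kl compares the new and old policy log-probabilities under the action mask, and optimization stops once the estimate exceeds 1.5 * target_kl. A minimal sketch of such a masked estimator, assuming log-probabilities of the taken actions shaped [batch, time] and a 0/1 mask of the same shape (the actual trax helper may differ):

import jax.numpy as jnp

def masked_approximate_kl(log_p_new, log_p_old, mask):
    # Sample estimate of KL(old || new): mean of (log p_old - log p_new)
    # over the valid (unmasked) timesteps.
    return jnp.sum((log_p_old - log_p_new) * mask) / jnp.sum(mask)

# Early-stopping rule used in the loop above (target_kl is a hyperparameter):
# stop = masked_approximate_kl(log_p_new, log_p_old, mask) > 1.5 * target_kl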
Example #17
def max(a, axis=None, keepdims=None, initial=None, where=None):
  if isinstance(a, JaxArray): a = a.value
  r = jnp.max(a, axis=axis, keepdims=keepdims, initial=initial, where=where)
  return r if axis is None else JaxArray(r)
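A quick usage sketch for the wrapper above, assuming JaxArray is a thin container that exposes the underlying jax.numpy array as .value (as the isinstance check suggests):

import jax.numpy as jnp

x = jnp.array([[1., 5.], [3., 2.]])
max(x)          # full reduction -> scalar 5.0, returned unwrapped
max(x, axis=0)  # per-column maxima [3., 5.], wrapped back into a JaxArray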
Example #18
def tei_array(geom, basis):
    """
    Build two electron integral array from a jax.numpy array of the cartesian geometry in Bohr, 
    and a basis dictionary as defined by basis_utils.build_basis_set
    We have to loop over primitives rather than shells because JAX needs intermediates to be consistent 
    sizes in order to compile.
    """
    # Smush primitive data together into vectors
    coeffs, exps, atoms, ams, indices, dims = flatten_basis_data(basis)
    nbf = get_nbf(basis)
    max_am = jnp.max(ams)
    max_am_idx = max_am * 4 + 1 
    # TODO: raise an exception if the angular momentum is too high.
    B_vals = jnp.zeros(4*max_am+1)  
    nprim = coeffs.shape[0]
    # Obtain all possible primitive quartet index combinations 
    primitive_quartets = cartesian_product(jnp.arange(nprim), jnp.arange(nprim), jnp.arange(nprim), jnp.arange(nprim))

    #print("Number of basis functions: ", nbf)
    #print("Number of primitve quartets: ", primitive_quartets.shape[0])

    #TODO Experimental: precompute quantities and lookup inside loop
    # Compute all possible Gaussian products for this basis set
    aa_plus_bb = jnp.broadcast_to(exps, (nprim,nprim)) + jnp.transpose(jnp.broadcast_to(exps, (nprim,nprim)), (1,0))
    aa_times_A = jnp.einsum('i,ij->ij', exps, geom[atoms])
    aaxA_plus_bbxB = aa_times_A[:,None,:] + aa_times_A[None,:,:]
    gaussian_products = jnp.einsum('ijk,ij->ijk', aaxA_plus_bbxB, 1/aa_plus_bb)  

    # Compute all rab2 (rcd2), every possible jnp.dot(A-B,A-B)
    natom = geom.shape[0]
    tmpA = jnp.broadcast_to(geom, (natom,natom,3))
    AminusB = (tmpA - jnp.transpose(tmpA, (1,0,2)))
    AmBdot = jnp.einsum('ijk,ijk->ij', AminusB, AminusB) # shape: (natom,natom)

    # Compute all differences between gaussian product centers with all atom centers
    tmpP = jnp.tile(gaussian_products, natom).reshape(nprim,nprim,natom,3)
    PminusA = tmpP - jnp.broadcast_to(geom, tmpP.shape)

    # Compute all powers (up to max_am) of differences between Gaussian product centers and atom centers
    # Shape: (nprim, nprim, natom, 3, max_am+1). In loop index PA_pow as [p1,p2,atoms[p1],:,:]
    PminusA_pow = jnp.power(jnp.transpose(jnp.broadcast_to(PminusA, (max_am+1,nprim,nprim,natom,3)), (1,2,3,4,0)), jnp.arange(max_am+1))

    with loops.Scope() as s:
      s.G = jnp.zeros((nbf,nbf,nbf,nbf))
      s.a = 0  # center A angular momentum iterator 
      s.b = 0  # center B angular momentum iterator 
      s.c = 0  # center C angular momentum iterator 
      s.d = 0  # center D angular momentum iterator 

      # Loop over primitive quartets, compute integral, add to appropriate index in G
      for prim_quar in s.range(primitive_quartets.shape[0]):
        # Load in primitive indices, coeffs, exponents, centers, angular momentum index, and leading placement index in TEI array
        p1,p2,p3,p4 = primitive_quartets[prim_quar] 
        coef = coeffs[p1] * coeffs[p2] * coeffs[p3] * coeffs[p4]
        aa, bb, cc, dd = exps[p1], exps[p2], exps[p3], exps[p4]
        ld1, ld2, ld3, ld4 = am_leading_indices[ams[p1]],am_leading_indices[ams[p2]],am_leading_indices[ams[p3]],am_leading_indices[ams[p4]]
        idx1, idx2, idx3, idx4 = indices[p1], indices[p2], indices[p3], indices[p4]
        #A, B, C, D = geom[atoms[p1]], geom[atoms[p2]], geom[atoms[p3]], geom[atoms[p4]]

        # Compute common intermediates before looping over AM distributions.
        # Avoids redundant recomputations/reassignment for all classes other than (ss|ss).
        #AB = A - B
        #CD = C - D
        #rab2 = jnp.dot(AB,AB)
        #rcd2 = jnp.dot(CD,CD)
        #P = (aa * A + bb * B) / gamma1
        #Q = (cc * C + dd * D) / gamma2
        gamma1 = aa + bb
        gamma2 = cc + dd

        #TODO
        P = gaussian_products[p1,p2]
        Q = gaussian_products[p3,p4]
        rab2 = AmBdot[atoms[p1],atoms[p2]]
        rcd2 = AmBdot[atoms[p3],atoms[p4]]
        #PA = PminusA[p1,p2,atoms[p1]]
        #PB = PminusA[p1,p2,atoms[p2]]
        #QC = PminusA[p3,p4,atoms[p3]]
        #QD = PminusA[p3,p4,atoms[p4]]
        #TODO

        PQ = P - Q
        rpq2 = jnp.dot(PQ,PQ)
        delta = 0.25*(1/gamma1+1/gamma2)

        boys_arg = 0.25 * rpq2 / delta
        boys_eval = boys(jnp.arange(max_am_idx), boys_arg) 

        # Need all powers of Pi-Ai,Pi-Bi,Qi-Ci,Qi-Di (i=x,y,z) up to max_am and Qi-Pi up to max_am_idx
        # Note: this computes unnecessary quantities for lower angular momentum,
        # but avoids repeated computation of the same quantities in loops for higher angular momentum.

        #PA_pow = jnp.power(jnp.broadcast_to(P-A, (max_am+1,3)).T, jnp.arange(max_am+1))
        #PB_pow = jnp.power(jnp.broadcast_to(P-B, (max_am+1,3)).T, jnp.arange(max_am+1))
        #QC_pow = jnp.power(jnp.broadcast_to(Q-C, (max_am+1,3)).T, jnp.arange(max_am+1))
        #QD_pow = jnp.power(jnp.broadcast_to(Q-D, (max_am+1,3)).T, jnp.arange(max_am+1))

        PA_pow = PminusA_pow[p1,p2,atoms[p1],:,:]
        PB_pow = PminusA_pow[p1,p2,atoms[p2],:,:]
        QC_pow = PminusA_pow[p3,p4,atoms[p3],:,:]
        QD_pow = PminusA_pow[p3,p4,atoms[p4],:,:]

        QP_pow = jnp.power(jnp.broadcast_to(Q-P, (max_am_idx,3)).T, jnp.arange(max_am_idx))
        # Powers of 4*gamma run from 0 down to -(l1+l2).
        # Build the array so that indexing with a negative exponent returns the matching negative power.
        g1_pow = jnp.power(4*gamma1, -jnp.roll(jnp.flip(jnp.arange(2*max_am+1)),1)) 
        g2_pow = jnp.power(4*gamma2, -jnp.roll(jnp.flip(jnp.arange(2*max_am+1)),1)) 
        oodelta_pow = jnp.power(1 / delta, jnp.arange(max_am_idx))  # l1 + l2 + l3 + l4 + 1

        prefactor = 34.986836655249726 / (gamma1*gamma2*jnp.sqrt(gamma1+gamma2)) \
                    * jnp.exp(-aa*bb*rab2/gamma1 + -cc*dd*rcd2/gamma2) * coef
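        # The 34.986836655249726 literal above equals 2 * pi**2.5, the
        # 2*pi^(5/2) factor in the closed-form Gaussian two-electron integral.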

        # TODO is there symmetry here?
        s.a = 0
        for _ in s.while_range(lambda: s.a < dims[p1]):
          s.b = 0
          for _ in s.while_range(lambda: s.b < dims[p2]):
            s.c = 0
            for _ in s.while_range(lambda: s.c < dims[p3]):
              s.d = 0
              for _ in s.while_range(lambda: s.d < dims[p4]):
                # Collect angular momentum and index in G
                la, ma, na = angular_momentum_combinations[s.a + ld1]
                lb, mb, nb = angular_momentum_combinations[s.b + ld2]
                lc, mc, nc = angular_momentum_combinations[s.c + ld3]
                ld, md, nd = angular_momentum_combinations[s.d + ld4]
                i = idx1 + s.a
                j = idx2 + s.b
                k = idx3 + s.c
                l = idx4 + s.d
                # Compute the primitive quartet tei and add to appropriate index in G
                Bx = B_array(la,lb,lc,ld,PA_pow[0],PB_pow[0],QC_pow[0],QD_pow[0],QP_pow[0],g1_pow,g2_pow,oodelta_pow,B_vals)
                By = B_array(ma,mb,mc,md,PA_pow[1],PB_pow[1],QC_pow[1],QD_pow[1],QP_pow[1],g1_pow,g2_pow,oodelta_pow,B_vals)
                Bz = B_array(na,nb,nc,nd,PA_pow[2],PB_pow[2],QC_pow[2],QD_pow[2],QP_pow[2],g1_pow,g2_pow,oodelta_pow,B_vals)

                with loops.Scope() as S:
                  S.primitive = 0.
                  S.I = 0
                  S.J = 0
                  S.K = 0
                  for _ in S.while_range(lambda: S.I < la + lb + lc + ld + 1):
                    S.J = 0 
                    tmp = Bx[S.I] 
                    for _ in S.while_range(lambda: S.J < ma + mb + mc + md + 1):
                      S.K = 0 
                      tmp *= By[S.J] 
                      for _ in S.while_range(lambda: S.K < na + nb + nc + nd + 1):
                        tmp *= Bz[S.K] * boys_eval[S.I + S.J + S.K]
                        S.primitive += tmp
                        S.K += 1
                      S.J += 1
                    S.I += 1
                tei = prefactor * S.primitive
                s.G = jax.ops.index_add(s.G, jax.ops.index[i,j,k,l], tei) 

                s.d += 1
              s.c += 1
            s.b += 1
          s.a += 1
      return s.G
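cartesian_product above enumerates every primitive quartet index combination; one plausible implementation (an assumption, not necessarily the helper used by the original basis utilities) stacks a meshgrid:

import jax.numpy as jnp

def cartesian_product(*index_arrays):
    # Every combination of the input index arrays; row r of the result holds
    # one tuple of primitive indices, e.g. (p1, p2, p3, p4).
    grids = jnp.meshgrid(*index_arrays, indexing='ij')
    return jnp.stack([g.ravel() for g in grids], axis=-1)

# cartesian_product(jnp.arange(2), jnp.arange(2)) -> [[0 0] [0 1] [1 0] [1 1]]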
Example #19
def _softmax(X):
    """Compute the softmax of matrix X in a numerically stable way."""
    shiftx = X - np.max(X, axis=1).reshape(-1, 1)
    exps = np.exp(shiftx)
    return exps / np.sum(exps, axis=1).reshape(-1, 1)
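Subtracting the per-row maximum leaves the softmax value unchanged but keeps np.exp from overflowing. A quick check of the function above with large logits (assuming np is numpy or jax.numpy, as in the snippet):

X = np.array([[1000., 1001., 1002.]])
_softmax(X)   # ~[[0.090, 0.245, 0.665]], all finite
# A naive np.exp(X) / np.exp(X).sum() would overflow to inf and yield nan here.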
Example #20
    #A = np.identity(5)

    B = np.array([[1., 0., 0.], [0.5, 0.5, 0.], [0.5, 0.0, 0.5], [0., 1., 0.],
                  [0., 0.5, 0.5]])

    A = 0.9 * A / np.linalg.norm(A, ord='nuc')
    B = 0.9 * B / np.linalg.norm(B, ord='nuc')

    #W = np.zeros(T_0 + T)
    W = (np.sin(np.arange((T_0 + T) * m) / (20 * np.pi)).reshape(
        (T_0 + T), m) @ np.ones((m, n))).reshape((T_0 + T), n, 1)

    sysid = SystemID()
    sysid.initialize(n, m)
    x = x0
    for t in range(T_0):
        u = sysid.get_action(x)
        x = A @ x + B @ u + W[t]
    A_id, B_id = sysid.system_id()
    print("A versus A_id")
    print(A)
    print(A_id)
    print("B versus B_id")
    print(B)
    print(B_id)
    print("max diff for A, B: ", np.max(np.abs(A - A_id)),
          np.max(np.abs(B - B_id)))
    print("norm diff for A, B: ", np.linalg.norm(A - A_id),
          np.linalg.norm(B - B_id))
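SystemID.system_id is not shown here; as a rough sketch of what such an identification step could compute (an assumption: stack the observed transitions x_{t+1} = A x_t + B u_t and solve the joint least-squares problem for [A B]):

import numpy as np

def least_squares_system_id(xs, us):
    # xs: states x_0 .. x_T as (n, 1) columns; us: inputs u_0 .. u_{T-1} as (m, 1) columns.
    Z = np.hstack([np.vstack([xs[t], us[t]]) for t in range(len(us))])  # (n + m, T)
    X_next = np.hstack(xs[1:])                                          # (n, T)
    AB = X_next @ np.linalg.pinv(Z)                                     # (n, n + m)
    n = xs[0].shape[0]
    return AB[:, :n], AB[:, n:]  # A_id, B_id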
Example #21
def tree_max_abs(tree):
    return np.max(
        tree_flatten(tree_map(lambda arr: np.max(np.abs(arr)), tree))[0])
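A small usage sketch: tree_map applies the per-leaf max-abs reduction, tree_flatten collects the resulting scalars, and the outer np.max reduces across leaves (assuming np is jax.numpy, matching the snippet):

from jax.tree_util import tree_flatten, tree_map
import jax.numpy as np

params = {'w': np.array([[1., -3.], [2., 0.]]), 'b': np.array([0.5, -0.25])}
tree_max_abs(params)  # -> 3.0, the largest absolute entry across all leaves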
Example #22
L1 = lambda z, xi: ydd1(z, xi) - Pe * yd1(z, xi)
L2 = lambda z, xi: ydd2(z, xi) - Pe * yd2(z, xi)

L = lambda xi: np.hstack((L1(z, xi), L2(z, xi)))

# Create the residual and jacobians
xi1 = onp.zeros(H(z).shape[1])
xi2 = onp.zeros(H(z).shape[1])
y = onp.zeros(1)
yd = onp.zeros(1)
b = onp.ones(1) * np.sqrt(2. / 0.5)

xi = TFCDict({'xi1': xi1, 'xi2': xi2, 'y': y, 'yd': yd, 'b': b})

## SOLVE THE SYSTEM *************************************************
xi, it, time = NLLS(xi, L, timer=True)

X = np.hstack((x1(z, xi), x2(z, xi)))
Y = np.hstack((y1(z, xi), y2(z, xi)))

# p1 = MakePlot(onp.array([['x (m)']]),onp.array([['y (m)']]))
# p1.ax[0].plot(X,Y,label='TFC Solution')
# p1.ax[0].plot(X,soln(X),label='Analytical Solution')
# p1.ax[0].legend()
# p1.show()

print('{:.2e} & {:.2e} & {:.5f} & {:.2f}'.format(np.max(np.abs(Y - soln(X))),
                                                 np.max(np.abs(L(xi))),
                                                 xp(xi)[0].tolist(), time))
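NLLS above comes from the TFC utilities and is not shown; as a rough sketch (an assumption, not the library's implementation), a nonlinear least-squares solve over a flat xi vector can be written as Gauss-Newton iterations with jacfwd, mirroring the pinv step used in a later example:

import jax.numpy as np
from jax import jacfwd

def nlls_sketch(xi, residual, num_iters=20, tol=1e-13):
    # Gauss-Newton: linearize the residual around xi and take the
    # least-squares step dxi = pinv(J(xi)) @ L(xi) until the residual stalls.
    for _ in range(num_iters):
        L_val = residual(xi)
        if np.max(np.abs(L_val)) < tol:
            break
        J = jacfwd(residual)(xi)
        xi = xi - np.linalg.pinv(J) @ L_val
    return xi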
Example #23
    def _train_step(self):
        """Runs a single training step."""
        if self._replay.add_count > self.min_replay_history:
            if self.training_steps % self.update_period == 0:
                self._sample_from_replay_buffer()

                if self._replay_scheme == 'prioritized':
                    # The original prioritized experience replay uses a linear exponent
                    # schedule 0.4 -> 1.0. Comparing the schedule to a fixed exponent of
                    # 0.5 on 5 games (Asterix, Pong, Q*Bert, Seaquest, Space Invaders)
                    # suggested a fixed exponent actually performs better, except on Pong.
                    probs = self.replay_elements['sampling_probabilities']
                    # Weight the loss by the inverse priorities.
                    loss_weights = 1.0 / jnp.sqrt(probs + 1e-10)
                    loss_weights /= jnp.max(loss_weights)
                else:
                    loss_weights = jnp.ones(
                        self.replay_elements['state'].shape[0])

                self.optimizer, aux_losses = train(
                    self.network_def, self.target_network_params,
                    self.optimizer, self.replay_elements['state'],
                    self.replay_elements['action'],
                    self.replay_elements['next_state'],
                    self.replay_elements['reward'],
                    self.replay_elements['terminal'], loss_weights,
                    self._support, self.cumulative_gamma, self._mico_weight,
                    self._distance_fn)

                loss = aux_losses.pop('loss')
                if self._replay_scheme == 'prioritized':
                    # Rainbow and prioritized replay are parametrized by an exponent
                    # alpha, but in both cases it is set to 0.5 - for simplicity's sake we
                    # leave it as is here, using the more direct sqrt(). Taking the square
                    # root "makes sense", as we are dealing with a squared loss.  Add a
                    # small nonzero value to the loss to avoid 0-priority items. While
                    # technically this may be okay, setting all items to 0 priority will
                    # cause trouble, and also result in 1.0 / 0.0 = NaN correction terms.
                    self._replay.set_priority(self.replay_elements['indices'],
                                              jnp.sqrt(loss + 1e-10))
                    # Weight the loss by the inverse priorities computed above.
                    loss = loss_weights * loss
                if self.summary_writer is not None:
                    values = []
                    for k in aux_losses:
                        values.append(
                            tf.compat.v1.Summary.Value(
                                tag=f'Losses/{k}', simple_value=aux_losses[k]))
                    summary = tf.compat.v1.Summary(value=values)
                    self.summary_writer.add_summary(summary,
                                                    self.training_steps)
            if self.training_steps % self.target_update_period == 0:
                self._sync_weights()

        self.training_steps += 1
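The inverse-sqrt weighting and sqrt priorities above amount to a prioritized-replay exponent of 0.5 on both the sampling and the importance-correction side. A small numeric sketch of the weight computation (the probability values are illustrative only):

import jax.numpy as jnp

probs = jnp.array([0.50, 0.25, 0.05])         # sampling probabilities from the buffer
loss_weights = 1.0 / jnp.sqrt(probs + 1e-10)  # rarely sampled items get larger weights
loss_weights /= jnp.max(loss_weights)         # normalize so the largest weight is 1.0
# -> approx [0.316, 0.447, 1.0]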
Example #24
                           +x[0]*((1.+x[1]**3)*np.exp(-1.)-np.dot(H(np.ones_like(x[0]),x[1]),xi))
        u = lambda xi,*x: u1(xi,*x)\
                          +(1.-x[1])*(x[0]*np.exp(-x[0])-u1(xi,x[0],np.zeros_like(x[1])))\
                          +x[1]*(np.exp(-x[0])*(x[0]+1.)-u1(xi,x[0],np.ones_like(x[1])))

        # Create the residual
        laplace = lambda xi,*x: egrad(egrad(u,1),1)(xi,*x)+egrad(egrad(u,2),2)(xi,*x)
        L = lambda xi,*x: laplace(xi,*x)-np.exp(-x[0])*(x[0]-2.+x[1]**3+6.*x[1])

        # Calculate the xi values
        zXi = np.zeros(H(*x).shape[1])
        A = jacfwd(L,0)(zXi,*x)
        B = -L(zXi,*x)
        xi = np.dot(np.linalg.pinv(A),B)

        # Calculate the error
        dark = np.meshgrid(np.linspace(x0[0],xf[0],n),np.linspace(x0[1],xf[1],n))
        x = (dark[0].flatten(),dark[1].flatten())

        ur = real(*x)
        ue = u(xi,*x)
        err = ur-ue
        testErr[j,k] = np.max(np.abs(err))

# Print results as a table
tab = table.SimpleTable(testErr)
print(tab)
f = open("TfcData.txt","w")
f.write(str(tab))
f.close()
Example #25
def softmax(x, axis=-1):
    unnormalized = np.exp(x - np.max(x, axis, keepdims=True))
    return unnormalized / np.sum(unnormalized, axis, keepdims=True)
Example #26
def get_particle_lims(particles):
    """particles is a (n, 2) array of 2d points.
    return lims (-a, a) st particles fit into square with
    corners at +- a."""
    a = np.max(np.abs(particles))
    return (-a, a)
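Quick check of the helper above (assuming np is jax.numpy, as in the snippet):

particles = np.array([[0.5, -2.0], [1.5, 0.3]])
get_particle_lims(particles)  # -> (-2.0, 2.0); suitable for symmetric plot limits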
Example #27
def matrix_inverse_pth_root(mat_g,
                            p,
                            iter_count=100,
                            error_tolerance=1e-6,
                            ridge_epsilon=1e-6):
    """Computes mat_g^(-1/p), where p is a positive integer.

  Coupled Newton iterations for matrix inverse pth root.

  Args:
    mat_g: the symmetric PSD matrix whose power is to be computed.
    p: exponent, for p a positive integer.
    iter_count: Maximum number of iterations.
    error_tolerance: Error indicator, useful for early termination.
    ridge_epsilon: Ridge epsilon added to make the matrix positive definite.

  Returns:
    mat_g^(-1/p)
  """
    mat_g_size = mat_g.shape[0]
    alpha = jnp.asarray(-1.0 / p, _INVERSE_PTH_ROOT_DATA_TYPE)
    identity = jnp.eye(mat_g_size, dtype=_INVERSE_PTH_ROOT_DATA_TYPE)
    _, max_ev, _ = power_iter(mat_g)
    ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16)

    def _unrolled_mat_pow_1(mat_m):
        """Computes mat_m^1."""
        return mat_m

    def _unrolled_mat_pow_2(mat_m):
        """Computes mat_m^2."""
        return jnp.matmul(mat_m, mat_m, precision=_INVERSE_PTH_ROOT_PRECISION)

    def _unrolled_mat_pow_4(mat_m):
        """Computes mat_m^4."""
        mat_pow_2 = _unrolled_mat_pow_2(mat_m)
        return jnp.matmul(mat_pow_2,
                          mat_pow_2,
                          precision=_INVERSE_PTH_ROOT_PRECISION)

    def _unrolled_mat_pow_8(mat_m):
        """Computes mat_m^4."""
        mat_pow_4 = _unrolled_mat_pow_4(mat_m)
        return jnp.matmul(mat_pow_4,
                          mat_pow_4,
                          precision=_INVERSE_PTH_ROOT_PRECISION)

    def mat_power(mat_m, p):
        """Computes mat_m^p, for p == 1, 2, 4 or 8.

    Args:
      mat_m: a square matrix
      p: a positive integer

    Returns:
      mat_m^p
    """
        # We unrolled the loop for performance reasons.
        exponent = jnp.round(jnp.log2(p))
        return lax.switch(jnp.asarray(exponent, jnp.int32), [
            _unrolled_mat_pow_1,
            _unrolled_mat_pow_2,
            _unrolled_mat_pow_4,
            _unrolled_mat_pow_8,
        ], (mat_m))

    def _iter_condition(state):
        (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error,
         run_step) = state
        error_above_threshold = jnp.logical_and(error > error_tolerance,
                                                run_step)
        return jnp.logical_and(i < iter_count, error_above_threshold)

    def _iter_body(state):
        (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state
        mat_m_i = (1 - alpha) * identity + alpha * mat_m
        new_mat_m = jnp.matmul(mat_power(mat_m_i, p),
                               mat_m,
                               precision=_INVERSE_PTH_ROOT_PRECISION)
        new_mat_h = jnp.matmul(mat_h,
                               mat_m_i,
                               precision=_INVERSE_PTH_ROOT_PRECISION)
        new_error = jnp.max(jnp.abs(new_mat_m - identity))
        # sometimes error increases after an iteration before decreasing and
        # converging. 1.2 factor is used to bound the maximal allowed increase.
        return (i + 1, new_mat_m, new_mat_h, mat_h, new_error,
                new_error < error * 1.2)

    if mat_g_size == 1:
        resultant_mat_h = (mat_g + ridge_epsilon)**alpha
        error = 0
    else:
        damped_mat_g = mat_g + ridge_epsilon * identity
        z = (1 + p) / (2 * jnp.linalg.norm(damped_mat_g))
        new_mat_m_0 = damped_mat_g * z
        new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
        new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
        init_state = tuple(
            [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True])
        _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
            _iter_condition, _iter_body, init_state)
        error = jnp.max(jnp.abs(mat_m - identity))
        is_converged = jnp.asarray(convergence, old_mat_h.dtype)
        resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
        resultant_mat_h = jnp.asarray(resultant_mat_h, mat_g.dtype)
    return resultant_mat_h, error
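power_iter is referenced above but not shown; a minimal power-iteration sketch that matches the _, max_ev, _ unpacking (an assumption about its return signature, the real helper may differ):

import jax.numpy as jnp

def power_iter(mat, num_iters=100):
    # Power iteration on a symmetric PSD matrix: repeatedly apply the matrix
    # and renormalize to estimate the dominant eigenvector and eigenvalue.
    v = jnp.ones((mat.shape[0],)) / jnp.sqrt(mat.shape[0])
    for _ in range(num_iters):
        mat_v = mat @ v
        v = mat_v / jnp.linalg.norm(mat_v)
    max_ev = v @ (mat @ v)
    return v, max_ev, num_iters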
Example #28
        val['dxi'].block_until_ready()
        toc = timer()
        time += (toc-tic)

    it += 1

    # print(xp)
    return np.max(np.abs(L(z,val['xi'],val['xp']))).tolist()

xp_0 = 0.75
xp = optim.fsolve(fMin, xp_0, xtol=tol,epsfcn=tol)

if xp > xpBound:
    xp = xpBound

val['xp'] = xp
val = nlls(val)
xi = val['xi']

X = np.hstack((x1(z,xp), x2(z,xp)))
Y = np.hstack((y1(z,xi,xp), y2(z,xi,xp) ))


# p1 = MakePlot(onp.array([['x (m)']]),onp.array([['y (m)']]))
# p1.ax[0].plot(X,Y,label='TFC Solution')
# p1.ax[0].plot(X,soln(X),label='Analytical Solution')
# p1.ax[0].legend()
# p1.show()

print('{:.2e} & {:.2e} & {:.5f} & {:.2f}'.format(np.max(np.abs(Y - soln(X))), np.max(np.abs(L(z,xi,xp))), xp, time ))
Example #29
 def minmax(x):
   x0 = jnp.max(jnp.where(mask, x[Ellipsis, None], x[Ellipsis, :1, None]), -2)
   x1 = jnp.min(jnp.where(~mask, x[Ellipsis, None], x[Ellipsis, -1:, None]), -2)
   x0 = jnp.minimum(x0, x[Ellipsis, -2:-1])
   x1 = jnp.maximum(x1, x[Ellipsis, 1:2])
   return x0, x1
Example #30
    def apply(self, x, communication=Communication.NONE, train=True):
        """Forward pass."""
        batch_size = x.shape[0]

        if communication is Communication.SQUEEZE_EXCITE_X:
            x = sample_patches.SqueezeExciteLayer(x)
        # end if squeeze excite x

        d1 = nn.relu(
            nn.Conv(x,
                    128,
                    kernel_size=(3, 3),
                    strides=(1, 1),
                    bias=True,
                    name="down1"))
        d2 = nn.relu(
            nn.Conv(d1,
                    128,
                    kernel_size=(3, 3),
                    strides=(2, 2),
                    bias=True,
                    name="down2"))
        d3 = nn.relu(
            nn.Conv(d2,
                    128,
                    kernel_size=(3, 3),
                    strides=(2, 2),
                    bias=True,
                    name="down3"))

        if communication is Communication.SQUEEZE_EXCITE_D:
            d1_flatten = einops.rearrange(d1, "b h w c -> b (h w) c")
            d2_flatten = einops.rearrange(d2, "b h w c -> b (h w) c")
            d3_flatten = einops.rearrange(d3, "b h w c -> b (h w) c")

            nd1 = d1_flatten.shape[1]
            nd2 = d2_flatten.shape[1]

            d_together = jnp.concatenate([d1_flatten, d2_flatten, d3_flatten],
                                         axis=1)

            num_channels = d_together.shape[-1]
            y = d_together.mean(axis=1)
            y = nn.Dense(y, features=num_channels // 4, bias=False)
            y = nn.relu(y)
            y = nn.Dense(y, features=num_channels, bias=False)
            y = nn.sigmoid(y)

            d_together = d_together * y[:, None, :]

            # split and reshape
            d1 = d_together[:, :nd1].reshape(d1.shape)
            d2 = d_together[:, nd1:nd1 + nd2].reshape(d2.shape)
            d3 = d_together[:, nd1 + nd2:].reshape(d3.shape)

        elif communication is Communication.TRANSFORMER:
            d1_flatten = einops.rearrange(d1, "b h w c -> b (h w) c")
            d2_flatten = einops.rearrange(d2, "b h w c -> b (h w) c")
            d3_flatten = einops.rearrange(d3, "b h w c -> b (h w) c")

            nd1 = d1_flatten.shape[1]
            nd2 = d2_flatten.shape[1]

            d_together = jnp.concatenate([d1_flatten, d2_flatten, d3_flatten],
                                         axis=1)

            positional_encodings = self.param(
                "scale_ratio_position_encodings",
                shape=(1, ) + d_together.shape[1:],
                initializer=jax.nn.initializers.normal(1. /
                                                       d_together.shape[-1]))
            d_together = transformer.Transformer(d_together +
                                                 positional_encodings,
                                                 num_layers=2,
                                                 num_heads=8,
                                                 is_training=train)

            # split and reshape
            d1 = d_together[:, :nd1].reshape(d1.shape)
            d2 = d_together[:, nd1:nd1 + nd2].reshape(d2.shape)
            d3 = d_together[:, nd1 + nd2:].reshape(d3.shape)

        t1 = nn.Conv(d1,
                     6,
                     kernel_size=(1, 1),
                     strides=(1, 1),
                     bias=True,
                     name="tidy1")
        t2 = nn.Conv(d2,
                     6,
                     kernel_size=(1, 1),
                     strides=(1, 1),
                     bias=True,
                     name="tidy2")
        t3 = nn.Conv(d3,
                     9,
                     kernel_size=(1, 1),
                     strides=(1, 1),
                     bias=True,
                     name="tidy3")

        raw_scores = (jnp.split(t1, 6, axis=-1) + jnp.split(t2, 6, axis=-1) +
                      jnp.split(t3, 9, axis=-1))

        # The following is for normalization.
        t = jnp.concatenate((jnp.reshape(
            t1, [batch_size, -1]), jnp.reshape(
                t2, [batch_size, -1]), jnp.reshape(t3, [batch_size, -1])),
                            axis=1)
        t_min = jnp.reshape(jnp.min(t, axis=-1), [batch_size, 1, 1, 1])
        t_max = jnp.reshape(jnp.max(t, axis=-1), [batch_size, 1, 1, 1])
        normalized_scores = zeroone(raw_scores, t_min, t_max)

        stats = {
            "scores": normalized_scores,
            "raw_scores": t,
        }
        # removes the split dimension. scores are now b x h' x w' shaped
        normalized_scores = [s.squeeze(-1) for s in normalized_scores]

        return normalized_scores, stats
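zeroone is not defined in this snippet; a plausible sketch (an assumption) that min-max normalizes each raw score map into [0, 1] using the per-example extrema computed just before it:

import jax.numpy as jnp

def zeroone(raw_scores, t_min, t_max):
    # t_min / t_max are shaped [batch, 1, 1, 1] and broadcast over each
    # [batch, h, w, 1] score map, rescaling every example into [0, 1].
    return [(s - t_min) / (t_max - t_min) for s in raw_scores]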