Beispiel #1
0
    return [vals[n.name] for n in graph.output]


if __name__ == "__main__":
    from contextlib import closing
    from io import BytesIO

    # It seems that there are several ONNX proto versions (you had one job!) but
    # this implementation works with at least this one mnist example file.
    url = ('https://github.com/onnx/models/blob/'
           '81c4779096d1205edd0b809e191a924c58c38fef/'
           'mnist/model.onnx?raw=true')
    # Close the HTTP response as soon as the payload has been read.
    with closing(urlopen(url)) as response:
        download = response.read()
    # MD5 is only an integrity check against a pinned file, not a security
    # measure.
    if hashlib.md5(download).hexdigest() != 'bc8ad9bd19c5a058055dc18d0f089dad':
        print("onnx file checksum mismatch")
        sys.exit(1)
    # The payload is raw bytes, so it needs a binary buffer; on Python 3
    # StringIO only accepts str and raises TypeError for bytes.
    model = onnx.load(BytesIO(download))

    def predict(inputs):
        """Interpret the ONNX graph on ``inputs`` and return its first output."""
        return interpret_onnx(model.graph, inputs)[0]

    # Run inference in Numpy-backed interpreter
    print("interpreted:")
    print(predict(np.ones((1, 1, 28, 28))))

    # JIT compile to XLA device, run inference on device
    compiled_predict = jit(predict)
    print("compiled:")
    print(compiled_predict(np.ones((1, 1, 28, 28))))

    # The interpreter is differentiable too! Even the compiled one:
    def fun(inputs):
        """Scalar objective (sum of compiled predictions) to differentiate."""
        return np.sum(compiled_predict(inputs))

    print("a derivative with respect to inputs:")
    print(grad(fun)(np.ones((1, 1, 28, 28)))[..., :3, :3])
Beispiel #2
0
def training_loop(env=None,
                  epochs=EPOCHS,
                  policy_and_value_net_fun=None,
                  policy_and_value_optimizer_fun=None,
                  batch_size=BATCH_TRAJECTORIES,
                  num_optimizer_steps=NUM_OPTIMIZER_STEPS,
                  print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
                  target_kl=0.01,
                  boundary=20,
                  max_timestep=None,
                  max_timestep_eval=20000,
                  random_seed=None,
                  gamma=GAMMA,
                  lambda_=LAMBDA,
                  epsilon=EPSILON,
                  c1=1.0,
                  c2=0.01,
                  output_dir=None,
                  eval_every_n=1000,
                  eval_env=None,
                  done_frac_for_policy_save=0.5,
                  enable_early_stopping=True):
    """Runs the training loop for PPO, with fixed policy and value nets.

    Each epoch optionally evaluates the current policy and collects
    ``batch_size`` trajectories from ``env``. It then pads them, computes the
    combined PPO/value/entropy loss, and runs up to ``num_optimizer_steps``
    optimizer steps with approximate-KL early stopping. Training resumes from
    the iteration after the last checkpoint found in ``output_dir``, if any.

    Args:
      env: training environment; ``env.action_space`` must be a
        ``gym.spaces.Discrete``. Required.
      epochs: total number of epochs (including any restored ones).
      policy_and_value_net_fun: called as
        ``fun(rng, observation_shape, num_actions)``; returns
        ``(params, apply_fn)``.
      policy_and_value_optimizer_fun: called with the initial params; returns
        ``(opt_state, opt_update, get_params)``.
      batch_size: number of trajectories collected per epoch.
      num_optimizer_steps: maximum optimizer steps per epoch (may be 0).
      print_every_optimizer_steps: log losses every this many optimizer steps.
      target_kl: with early stopping enabled, optimization stops once the
        approximate KL exceeds ``1.5 * target_kl``.
      boundary: padding boundary passed to collection, evaluation and padding.
      max_timestep: passed as ``max_timestep`` to ``collect_trajectories``.
      max_timestep_eval: passed as ``max_timestep`` to ``evaluate_policy``.
      random_seed: seed for the JAX random number generator.
      gamma, lambda_, epsilon, c1, c2: loss hyperparameters forwarded to
        ``combined_loss`` and ``policy_and_value_opt_step``.
      output_dir: directory for summaries and pickled params. Required.
      eval_every_n: evaluate every this many epochs (and on the last epoch).
      eval_env: environment passed to ``evaluate_policy``.
        NOTE(review): defaults to None and is used as-is; confirm callers
        always supply it.
      done_frac_for_policy_save: params are pickled once at least this
        fraction of ``batch_size`` trajectories has finished since the last
        save (and always on the last epoch).
      enable_early_stopping: whether to apply the KL early-stopping rule.
    """
    assert env
    assert output_dir

    gfile.makedirs(output_dir)

    # Create summary writers and history.
    train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
    eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

    jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

    # Batch Observations Shape = [-1, -1] + OBS, because we will eventually call
    # policy and value networks on shape [B, T] +_OBS
    batch_observations_shape = (-1, -1) + env.observation_space.shape

    assert isinstance(env.action_space, gym.spaces.Discrete)
    num_actions = env.action_space.n

    jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)

    # Initialize the policy and value network.
    policy_and_value_net_params, policy_and_value_net_apply = (
        policy_and_value_net_fun(key1, batch_observations_shape, num_actions))

    # Maybe restore the policy params. If there is nothing to restore, then
    # iteration = 0 and policy_and_value_net_params are returned as is.
    restore, policy_and_value_net_params, iteration = (maybe_restore_params(
        output_dir, policy_and_value_net_params))

    if restore:
        logging.info("Restored parameters from iteration [%d]", iteration)
        # We should start from the next iteration.
        iteration += 1

    policy_and_value_net_apply = jit(policy_and_value_net_apply)

    # Initialize the optimizers.
    policy_and_value_optimizer = (
        policy_and_value_optimizer_fun(policy_and_value_net_params))
    (policy_and_value_opt_state, policy_and_value_opt_update,
     policy_and_value_get_params) = policy_and_value_optimizer

    num_trajectories_done = 0

    for i in range(iteration, epochs):

        # Params we'll use to collect the trajectories.
        policy_and_value_net_params = policy_and_value_get_params(
            policy_and_value_opt_state)

        # A function to get the policy and value predictions.
        def get_predictions(observations, rng=None):
            """Returns log-probs, value predictions and key back."""
            key, key1 = jax_random.split(rng, num=2)

            log_probs, value_preds = policy_and_value_net_apply(
                observations, policy_and_value_net_params, rng=key1)

            return log_probs, value_preds, key

        # Evaluate the policy.
        if (i % eval_every_n == 0) or (i == epochs - 1):
            jax_rng_key, key = jax_random.split(jax_rng_key, num=2)

            logging.vlog(1, "Epoch [% 6d] evaluating policy.", i)

            avg_reward = evaluate_policy(eval_env,
                                         get_predictions,
                                         boundary,
                                         max_timestep=max_timestep_eval,
                                         rng=key)
            for k, v in avg_reward.items():
                eval_sw.scalar("eval/mean_reward/%s" % k, v, step=i)
                logging.info("Epoch [% 6d] Policy Evaluation [%s] = %10.2f", i,
                             k, v)

        t = time.time()
        logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
        jax_rng_key, key = jax_random.split(jax_rng_key)
        trajs, num_done = collect_trajectories(
            env,
            policy_fun=get_predictions,
            num_trajectories=batch_size,
            max_timestep=max_timestep,
            boundary=boundary,
            rng=key,
            reset=(i == 0) or restore,
            epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.

        # Save parameters every time we see the end of at least a fraction of
        # batch number of trajectories that are done (not completed -- completed
        # includes truncated and done).
        # Or if this is the last iteration.
        num_trajectories_done += num_done
        if ((num_trajectories_done >= done_frac_for_policy_save * batch_size)
                or (i == epochs - 1)):
            logging.vlog(1, "Epoch [% 6d] saving model.", i)
            params_file = os.path.join(output_dir, "model-%06d.pkl" % i)
            with gfile.GFile(params_file, "wb") as f:
                pickle.dump(policy_and_value_net_params, f)
            # Reset this number.
            num_trajectories_done = 0

        logging.vlog(1, "Collecting trajectories took %0.2f msec.",
                     get_time(t))

        avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
        max_reward = max(np.sum(traj[2]) for traj in trajs)
        min_reward = min(np.sum(traj[2]) for traj in trajs)

        train_sw.scalar("train/mean_reward", avg_reward, step=i)

        logging.vlog(1,
                     "Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s",
                     avg_reward, max_reward, min_reward,
                     [float(np.sum(traj[2])) for traj in trajs])

        logging.vlog(
            1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
            float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
            max(len(traj[0]) for traj in trajs),
            min(len(traj[0]) for traj in trajs))
        logging.vlog(2, "Trajectory Lengths: %s",
                     [len(traj[0]) for traj in trajs])

        t = time.time()
        (_, reward_mask, padded_observations, padded_actions,
         padded_rewards) = pad_trajectories(trajs, boundary=boundary)

        logging.vlog(1, "Padding trajectories took %0.2f msec.", get_time(t))
        logging.vlog(1, "Padded Observations' shape [%s]",
                     str(padded_observations.shape))
        logging.vlog(1, "Padded Actions' shape [%s]",
                     str(padded_actions.shape))
        logging.vlog(1, "Padded Rewards' shape [%s]",
                     str(padded_rewards.shape))

        # Calculate log-probabilities and value predictions of the trajectories.
        # We'll pass these to the loss functions so as to not get recomputed.

        # NOTE:
        # There is a slight problem here, if the policy network contains
        # stochasticity in the log-probabilities (ex: dropout), then calculating
        # these again here is not going to be correct and should be done in the
        # collect function.

        jax_rng_key, key = jax_random.split(jax_rng_key)
        log_probabs_traj, value_predictions_traj, _ = get_predictions(
            padded_observations, rng=key)

        # Some assertions.
        B, T = padded_actions.shape  # pylint: disable=invalid-name
        assert (B, T) == padded_rewards.shape
        assert (B, T) == reward_mask.shape
        assert (B, T + 1) == padded_observations.shape[:2]
        assert (B, T +
                1) + env.observation_space.shape == padded_observations.shape

        # Constant epsilon (no annealing schedule).
        epsilon_schedule = epsilon

        # Compute value and ppo losses.
        jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)
        logging.vlog(2, "Starting to compute P&V loss.")
        t = time.time()
        cur_combined_loss, cur_ppo_loss, cur_value_loss, entropy_bonus = (
            combined_loss(policy_and_value_net_params,
                          log_probabs_traj,
                          value_predictions_traj,
                          policy_and_value_net_apply,
                          padded_observations,
                          padded_actions,
                          padded_rewards,
                          reward_mask,
                          gamma=gamma,
                          lambda_=lambda_,
                          epsilon=epsilon_schedule,
                          c1=c1,
                          c2=c2,
                          rng=key1))
        logging.vlog(
            1,
            "Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.",
            cur_combined_loss, cur_value_loss, cur_ppo_loss, entropy_bonus,
            get_time(t))

        jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)
        logging.vlog(1, "Policy and Value Optimization")
        t1 = time.time()
        # Start from the pre-optimization losses so the summary logging below
        # is well-defined even when num_optimizer_steps == 0 (previously a
        # NameError).
        loss_combined, loss_ppo, loss_value = (cur_combined_loss, cur_ppo_loss,
                                               cur_value_loss)
        keys = jax_random.split(key1, num=num_optimizer_steps)
        for j in range(num_optimizer_steps):
            k1, k2, k3 = jax_random.split(keys[j], num=3)
            t = time.time()
            # Update the optimizer state.
            policy_and_value_opt_state = policy_and_value_opt_step(
                j,
                policy_and_value_opt_state,
                policy_and_value_opt_update,
                policy_and_value_get_params,
                policy_and_value_net_apply,
                log_probabs_traj,
                value_predictions_traj,
                padded_observations,
                padded_actions,
                padded_rewards,
                reward_mask,
                c1=c1,
                c2=c2,
                gamma=gamma,
                lambda_=lambda_,
                epsilon=epsilon_schedule,
                rng=k1)

            # Compute the approx KL for early stopping.
            new_policy_and_value_net_params = policy_and_value_get_params(
                policy_and_value_opt_state)

            log_probab_actions_new, _ = policy_and_value_net_apply(
                padded_observations,
                new_policy_and_value_net_params,
                rng=k2)

            approx_kl = approximate_kl(log_probab_actions_new,
                                       log_probabs_traj, reward_mask)

            early_stopping = enable_early_stopping and approx_kl > 1.5 * target_kl
            if early_stopping:
                logging.vlog(
                    1,
                    "Early stopping policy and value optimization at iter: %d, "
                    "with approx_kl: %0.2f", j, approx_kl)
                # We don't return right-away, we want the below to execute on
                # the last iteration.

            t2 = time.time()
            if (((j + 1) % print_every_optimizer_steps == 0)
                    or (j == num_optimizer_steps - 1) or early_stopping):
                # Compute and log the loss.
                (loss_combined, loss_ppo, loss_value,
                 entropy_bonus) = (combined_loss(
                     new_policy_and_value_net_params,
                     log_probabs_traj,
                     value_predictions_traj,
                     policy_and_value_net_apply,
                     padded_observations,
                     padded_actions,
                     padded_rewards,
                     reward_mask,
                     gamma=gamma,
                     lambda_=lambda_,
                     epsilon=epsilon_schedule,
                     c1=c1,
                     c2=c2,
                     rng=k3))
                logging.vlog(
                    1, "One Policy and Value grad desc took: %0.2f msec",
                    get_time(t, t2))
                logging.vlog(
                    1,
                    "Combined Loss(value, ppo, entropy_bonus) [%10.2f] ->"
                    " [%10.2f(%10.2f,%10.2f,%10.2f)]", cur_combined_loss,
                    loss_combined, loss_value, loss_ppo, entropy_bonus)

            if early_stopping:
                break

        logging.vlog(1, "Total Combined Loss reduction [%0.2f]%%",
                     (100 * (cur_combined_loss - loss_combined) /
                      np.abs(cur_combined_loss)))

        logging.info(
            "Epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined"
            " Loss(value, ppo, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)], took "
            "[%2.5f msec].", i, min_reward, max_reward,
            avg_reward, loss_combined, loss_value, loss_ppo, entropy_bonus,
            get_time(t1))

        # Only the first epoch after a restore needs an environment reset.
        restore = False
Beispiel #3
0
 def f(x1):
     """Return ``sum(sin(sin(sin(x1))))``, applying ``jnp.sin`` elementwise."""
     y = x1
     for _ in range(3):
         y = jnp.sin(y)
     return jnp.sum(y)
Beispiel #4
0
 def g_fn(R):
   """Sum ``pairwise(d(R, R), dim)`` over the second axis, excluding i == j.

   ``dim`` is the size of the last axis of ``R``. The self-pair (diagonal)
   terms are multiplied by 0 before summing. ``pairwise`` and ``d`` are
   defined elsewhere. NOTE(review): presumably ``R`` holds per-particle
   positions of shape (N, dim); confirm.
   """
   n_rows = R.shape[0]
   off_diagonal = 1 - jnp.eye(n_rows, dtype=R.dtype)
   pair_terms = pairwise(d(R, R), R.shape[-1])
   return jnp.sum(off_diagonal[:, :, None] * pair_terms, axis=(1,))
 def f(c, a):
   """Compute a new carry and an output from carry ``c`` and input ``a``.

   Both results use the incoming ``c``; ``d`` is a name defined elsewhere.
   Returns ``(new_carry, output)``. NOTE(review): this looks like a
   ``lax.scan`` body; confirm with the caller.
   """
   prev = c
   output = np.sin(prev * np.sum(np.cos(d * a)))
   new_carry = 0.9 * np.cos(d * np.sum(np.sin(prev * a)))
   return new_carry, output
Beispiel #6
0
def kinetic_energy(V, mass=1.0):
    """Return the total kinetic energy ``0.5 * sum(mass * V**2)``.

    Args:
        V: array of velocities; every element is squared and summed.
        mass: factor multiplying each squared velocity element (a scalar by
            default, or anything broadcastable against ``V``).

    Returns:
        The scalar kinetic energy.
    """
    squared_velocity = V ** 2
    return 0.5 * np.sum(mass * squared_velocity)
Beispiel #7
0
def nonbonded_v3(
    conf,
    params,
    box,
    lamb,
    charge_rescale_mask,
    lj_rescale_mask,
    beta,
    cutoff,
    lambda_plane_idxs,
    lambda_offset_idxs,
    runtime_validate=True,
):
    """Lennard-Jones + Coulomb, with a few important twists:
    * distances are computed in 4D, controlled by lambda, lambda_plane_idxs, lambda_offset_idxs
    * each pairwise LJ and Coulomb term can be multiplied by an adjustable rescale_mask parameter
    * Coulomb terms are multiplied by erfc(beta * distance)

    Parameters
    ----------
    conf : (N, 3) or (N, 4) np.array
        3D or 4D coordinates
        if 3D, will be converted to 4D using (x,y,z) -> (x,y,z,w)
            where w = cutoff * (lambda_plane_idxs + lambda_offset_idxs * lamb)
    params : (N, 3) np.array
        columns [charges, sigmas, epsilons], one row per particle.
        Pairwise parameters are combined as sigma_ij = sigma_i + sigma_j and
        eps_ij = eps_i * eps_j (no 1/2 factor, no square root). NOTE(review):
        presumably the stored values are already half-sigmas / sqrt-epsilons;
        confirm with the parameter producer.
    box : Optional 3x3 (or 4x4) np.array
        periodic box. A 3x3 box is embedded in a 4x4 box whose 4th diagonal
        entry is 1000, so the 4th dimension is roughly aperiodic; a box whose
        last dimension is not 3 is used unchanged.
    lamb : float
    charge_rescale_mask : (N, N) np.array
        the Coulomb contribution of pair (i,j) will be multiplied by charge_rescale_mask[i,j]
    lj_rescale_mask : (N, N) np.array
        the Lennard-Jones contribution of pair (i,j) will be multiplied by lj_rescale_mask[i,j]
    beta : float
        the charge product q_ij will be multiplied by erfc(beta*d_ij)
    cutoff : Optional float
        a pair of particles (i,j) will be considered non-interacting if the distance d_ij
        between their 4D coordinates exceeds cutoff
    lambda_plane_idxs : Optional (N,) np.array
    lambda_offset_idxs : Optional (N,) np.array
    runtime_validate: bool
        if True, assert that both rescale masks are symmetric and check whether
        beta is compatible with cutoff
        (if True, this function will currently not play nice with Jax JIT)
        TODO: is there a way to conditionally print a runtime warning inside
            of a Jax JIT-compiled function, without triggering a Jax ConcretizationTypeError?

    Returns
    -------
    energy : float
        sum over all ordered pairs (i != j) of the rescaled LJ + Coulomb terms,
        divided by 2 so each unordered pair is counted once.

    References
    ----------
    * Rodinger, Howell, Pomès, 2005, J. Chem. Phys. "Absolute free energy calculations by thermodynamic integration in four spatial
        dimensions" https://aip.scitation.org/doi/abs/10.1063/1.1946750
    * Darden, York, Pedersen, 1993, J. Chem. Phys. "Particle mesh Ewald: An N log(N) method for Ewald sums in large
        systems" https://aip.scitation.org/doi/abs/10.1063/1.470117
        * Coulomb interactions are treated using the direct-space contribution from eq 2
    """
    # Symmetric masks are required because the final sum halves an ordered-pair sum.
    if runtime_validate:
        assert (charge_rescale_mask == charge_rescale_mask.T).all()
        assert (lj_rescale_mask == lj_rescale_mask.T).all()

    N = conf.shape[0]

    # Lift 3D coordinates into 4D (see docstring for the w coordinate).
    if conf.shape[-1] == 3:
        conf = convert_to_4d(conf, lamb, lambda_plane_idxs, lambda_offset_idxs,
                             cutoff)

    # make 4th dimension of box large enough so its roughly aperiodic
    if box is not None:
        if box.shape[-1] == 3:
            box_4d = np.eye(4) * 1000
            box_4d = index_update(box_4d, index[:3, :3], box)
        else:
            box_4d = box
    else:
        box_4d = None

    box = box_4d

    charges = params[:, 0]
    sig = params[:, 1]
    eps = params[:, 2]

    # Broadcast (1, N) + (N, 1) -> (N, N) pairwise sigma.
    sig_i = np.expand_dims(sig, 0)
    sig_j = np.expand_dims(sig, 1)
    sig_ij = sig_i + sig_j

    # Broadcast (1, N) * (N, 1) -> (N, N) pairwise epsilon.
    eps_i = np.expand_dims(eps, 0)
    eps_j = np.expand_dims(eps, 1)

    eps_ij = eps_i * eps_j

    dij = distance(conf, box)

    # LJ mask: exclude self-pairs and pairs with zero epsilon. This mask is
    # computed from eps_ij before the cutoff below is applied.
    keep_mask = np.ones((N, N)) - np.eye(N)
    keep_mask = np.where(eps_ij != 0, keep_mask, 0)

    if cutoff is not None:
        if runtime_validate:
            validate_coulomb_cutoff(cutoff, beta, threshold=1e-2)
        # LJ is kept only for dij strictly below cutoff.
        # NOTE(review): the Coulomb cutoff further down zeroes only dij > cutoff,
        # so at dij == cutoff the Coulomb term survives while LJ does not.
        eps_ij = np.where(dij < cutoff, eps_ij, 0)

    # (ytz): this avoids a nan in the gradient in both jax and tensorflow
    sig_ij = np.where(keep_mask, sig_ij, 0)
    eps_ij = np.where(keep_mask, eps_ij, 0)

    # Diagonal distances are presumably 0 (giving inf here); they are
    # replaced by 0 on the next line.
    inv_dij = 1 / dij
    inv_dij = np.where(np.eye(N), 0, inv_dij)

    # sig6 = (sigma_ij / d_ij) ** 6
    sig2 = sig_ij * inv_dij
    sig2 *= sig2
    sig6 = sig2 * sig2 * sig2

    # 4 * eps * ((sig/d)^12 - (sig/d)^6), masked to the LJ keep_mask.
    eij_lj = 4 * eps_ij * (sig6 - 1.0) * sig6
    eij_lj = np.where(keep_mask, eij_lj, 0)

    qi = np.expand_dims(charges, 0)  # (1, N)
    qj = np.expand_dims(charges, 1)  # (N, 1)
    qij = np.multiply(qi, qj)

    # (ytz): trick used to avoid nans in the diagonal due to the 1/dij term.
    # This rebinds keep_mask: Coulomb uses only the off-diagonal mask,
    # not the epsilon-based LJ mask.
    keep_mask = 1 - np.eye(N)
    qij = np.where(keep_mask, qij, 0)
    dij = np.where(keep_mask, dij, 0)

    # funny enough lim_{x->0} erfc(x)/x = 0
    # Direct-space Coulomb term q_i q_j erfc(beta d_ij) / d_ij.
    eij_charge = np.where(keep_mask,
                          qij * erfc(beta * dij) * inv_dij,
                          0)  # zero out diagonals
    if cutoff is not None:
        eij_charge = np.where(dij > cutoff, 0, eij_charge)

    eij_total = eij_lj * lj_rescale_mask + eij_charge * charge_rescale_mask

    # Halve the ordered-pair sum so each unordered pair is counted once.
    return np.sum(eij_total / 2)
 def harmonic_bond(conf, params):
   """Return the sum of the elementwise product ``conf * params``.

   NOTE(review): despite the name, no harmonic (squared-displacement) term is
   computed here; this looks like a placeholder. Confirm the intent.
   """
   weighted = conf * params
   return np.sum(weighted)
Beispiel #9
0
    def train_epoch(self, evaluate=True):
        """Train one PPO epoch.

        The epoch proceeds in these steps:

        1. Optionally evaluate the policy.
        2. Collect training trajectories and preprocess (pad) them.
        3. Recompute log-probabilities and value predictions with the current
           params.
        4. Run up to ``self._n_optimizer_steps`` minibatch optimizer steps
           with approximate-KL early stopping.
        5. Log summaries, bump ``self._epoch``, possibly save a checkpoint,
           and log/write timings.

        Args:
          evaluate: if False, skip policy evaluation and the eval-reward
            summaries derived from the training rewards.
        """
        epoch_start_time = time.time()

        # Evaluate the policy.
        policy_eval_start_time = time.time()
        if evaluate and (self._epoch + 1) % self._eval_every_n == 0:
            # The subkey is unused; the split is kept so self._rng advances
            # exactly as before.
            self._rng, _ = jax_random.split(self._rng, num=2)
            self.evaluate()

        policy_eval_time = ppo.get_time(policy_eval_start_time)

        trajectory_collection_start_time = time.time()
        logging.vlog(1, 'PPO epoch [% 6d]: collecting trajectories.',
                     self._epoch)
        # As above: only advances self._rng.
        self._rng, _ = jax_random.split(self._rng)
        trajs, _, timing_info, self._model_state = self.collect_trajectories(
            train=True, temperature=1.0)
        # Keep elements 0, 1, 2 and 4 of each trajectory (element 3 is dropped).
        # Below, element 0's length is the trajectory length and element 2 is
        # summed as the reward.
        trajs = [(t[0], t[1], t[2], t[4]) for t in trajs]
        self._should_reset = False
        trajectory_collection_time = ppo.get_time(
            trajectory_collection_start_time)

        logging.vlog(1, 'Collecting trajectories took %0.2f msec.',
                     trajectory_collection_time)

        rewards = np.array([np.sum(traj[2]) for traj in trajs])
        avg_reward = np.mean(rewards)
        std_reward = np.std(rewards)
        max_reward = np.max(rewards)
        min_reward = np.min(rewards)

        self._log('train', 'train/reward_mean_truncated', avg_reward)
        if evaluate and not self._separate_eval:
            metrics = {'raw': {1.0: {'mean': avg_reward, 'std': std_reward}}}
            ppo.write_eval_reward_summaries(metrics, self._log, self._epoch)

        logging.vlog(1,
                     'Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s',
                     avg_reward, max_reward, min_reward,
                     [float(np.sum(traj[2])) for traj in trajs])

        logging.vlog(
            1, 'Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]',
            float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
            max(len(traj[0]) for traj in trajs),
            min(len(traj[0]) for traj in trajs))
        logging.vlog(2, 'Trajectory Lengths: %s',
                     [len(traj[0]) for traj in trajs])

        preprocessing_start_time = time.time()
        (padded_observations, padded_actions, padded_rewards, reward_mask,
         padded_infos) = self._preprocess_trajectories(trajs)
        preprocessing_time = ppo.get_time(preprocessing_start_time)

        logging.vlog(1, 'Preprocessing trajectories took %0.2f msec.',
                     preprocessing_time)
        logging.vlog(1, 'Padded Observations\' shape [%s]',
                     str(padded_observations.shape))
        logging.vlog(1, 'Padded Actions\' shape [%s]',
                     str(padded_actions.shape))
        logging.vlog(1, 'Padded Rewards\' shape [%s]',
                     str(padded_rewards.shape))

        # Some assertions.
        B, RT = padded_rewards.shape  # pylint: disable=invalid-name
        B, AT = padded_actions.shape  # pylint: disable=invalid-name
        assert (B, RT) == reward_mask.shape
        assert B == padded_observations.shape[0]

        log_prob_recompute_start_time = time.time()
        # TODO(pkozakowski): Network predictions made while stepping the
        # environment are not reused here (that would allow non-deterministic
        # networks, e.g. with dropout) because it does not work well with
        # serialization; instead all predictions are recomputed. Find a solution
        # that works with both serialized sequences and non-deterministic
        # networks.
        del padded_infos

        # TODO(afrozm): log-probabs doesn't need to be (B, T+1, C, A) it can do with
        # (B, T, C, A), so make that change throughout.

        # Recompute log-probabs and value predictions for every observation with
        # the current params.
        self._rng, key = jax_random.split(self._rng)

        (log_probabs_traj,
         value_predictions_traj) = (self._policy_and_value_net_apply(
             padded_observations,
             params=self._policy_and_value_net_params,
             state=self._model_state,
             rng=key,
         ))

        assert (B, AT) == log_probabs_traj.shape[:2]
        assert (B, AT) == value_predictions_traj.shape

        log_prob_recompute_time = ppo.get_time(log_prob_recompute_start_time)

        # Compute value and ppo losses.
        self._rng, key1 = jax_random.split(self._rng, num=2)
        logging.vlog(2, 'Starting to compute P&V loss.')
        loss_compute_start_time = time.time()
        (cur_combined_loss, component_losses, summaries,
         self._model_state) = (ppo.combined_loss(
             self._policy_and_value_net_params,
             log_probabs_traj,
             value_predictions_traj,
             self._policy_and_value_net_apply,
             padded_observations,
             padded_actions,
             self._rewards_to_actions,
             padded_rewards,
             reward_mask,
             nontrainable_params=self._nontrainable_params,
             state=self._model_state,
             rng=key1))
        loss_compute_time = ppo.get_time(loss_compute_start_time)
        (cur_ppo_loss, cur_value_loss, cur_entropy_bonus) = component_losses
        logging.vlog(
            1,
            'Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.',
            cur_combined_loss, cur_ppo_loss, cur_value_loss, cur_entropy_bonus,
            loss_compute_time)

        self._rng, key1 = jax_random.split(self._rng, num=2)
        logging.vlog(1, 'Policy and Value Optimization')
        optimization_start_time = time.time()
        keys = jax_random.split(key1, num=self._n_optimizer_steps)
        opt_step = 0
        opt_batch_size = min(self._optimizer_batch_size, B)
        index_batches = ppo.shuffled_index_batches(dataset_size=B,
                                                   batch_size=opt_batch_size)
        # Defaults so the logging/summaries after the loop are well-defined even
        # when no optimizer step runs; without an update the policy is
        # unchanged, so the approximate KL is 0.
        combined_loss = cur_combined_loss
        (ppo_loss, value_loss, entropy_bonus) = (cur_ppo_loss, cur_value_loss,
                                                 cur_entropy_bonus)
        approx_kl = 0.0
        # Loop-invariant: the action mask depends only on reward_mask.
        action_mask = np.dot(np.pad(reward_mask, ((0, 0), (0, 1))),
                             self._rewards_to_actions)
        for (index_batch, key) in zip(index_batches, keys):
            k1, k2, k3 = jax_random.split(key, num=3)
            t = time.time()
            # Update the optimizer state on the sampled minibatch.
            self._policy_and_value_opt_state, self._model_state = (
                ppo.policy_and_value_opt_step(
                    # We pass the optimizer slots between PPO epochs, so we need to
                    # pass the optimization step as well, so for example the
                    # bias-correction in Adam is calculated properly. Alternatively we
                    # could reset the slots and the step in every PPO epoch, but then
                    # the moment estimates in adaptive optimizers would never have
                    # enough time to warm up. So it makes sense to reuse the slots,
                    # even though we're optimizing a different loss in every new
                    # epoch.
                    self._total_opt_step,
                    self._policy_and_value_opt_state,
                    self._policy_and_value_opt_update,
                    self._policy_and_value_get_params,
                    self._policy_and_value_net_apply,
                    log_probabs_traj[index_batch],
                    value_predictions_traj[index_batch],
                    padded_observations[index_batch],
                    padded_actions[index_batch],
                    self._rewards_to_actions,
                    padded_rewards[index_batch],
                    reward_mask[index_batch],
                    nontrainable_params=self._nontrainable_params,
                    state=self._model_state,
                    rng=k1))
            opt_step += 1
            self._total_opt_step += 1

            # Compute the approx KL for early stopping. Use the whole dataset - as we
            # only do inference, it should fit in the memory.
            (log_probab_actions_new, _) = (self._policy_and_value_net_apply(
                padded_observations,
                params=self._policy_and_value_net_params,
                state=self._model_state,
                rng=k2))

            approx_kl = ppo.approximate_kl(log_probab_actions_new,
                                           log_probabs_traj, action_mask)

            early_stopping = approx_kl > 1.5 * self._target_kl
            if early_stopping:
                logging.vlog(
                    1,
                    'Early stopping policy and value optimization after %d steps, '
                    'with approx_kl: %0.2f', opt_step, approx_kl)
                # We don't return right-away, we want the below to execute on the last
                # iteration.

            t2 = time.time()
            if (opt_step % self._print_every_optimizer_steps == 0
                    or opt_step == self._n_optimizer_steps or early_stopping):
                # Compute and log the loss.
                (combined_loss, component_losses, _,
                 self._model_state) = (ppo.combined_loss(
                     self._policy_and_value_net_params,
                     log_probabs_traj,
                     value_predictions_traj,
                     self._policy_and_value_net_apply,
                     padded_observations,
                     padded_actions,
                     self._rewards_to_actions,
                     padded_rewards,
                     reward_mask,
                     nontrainable_params=self._nontrainable_params,
                     state=self._model_state,
                     rng=k3))
                logging.vlog(
                    1, 'One Policy and Value grad desc took: %0.2f msec',
                    ppo.get_time(t, t2))
                (ppo_loss, value_loss, entropy_bonus) = component_losses
                # Label order matches the argument order (ppo, value, entropy).
                logging.vlog(
                    1, 'Combined Loss(ppo, value, entropy_bonus) [%10.2f] ->'
                    ' [%10.2f(%10.2f,%10.2f,%10.2f)]', cur_combined_loss,
                    combined_loss, ppo_loss, value_loss, entropy_bonus)

            if early_stopping:
                break

        optimization_time = ppo.get_time(optimization_start_time)

        logging.vlog(
            1, 'Total Combined Loss reduction [%0.2f]%%',
            (100 *
             (cur_combined_loss - combined_loss) / np.abs(cur_combined_loss)))

        summaries.update({
            'n_optimizer_steps': opt_step,
            'approx_kl': approx_kl,
        })
        for (name, value) in summaries.items():
            self._log('train', 'train/{}'.format(name), value)

        logging.info(
            'PPO epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined'
            ' Loss(ppo, value, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]',
            self._epoch, min_reward, max_reward, avg_reward, combined_loss,
            ppo_loss, value_loss, entropy_bonus)

        # Bump the epoch counter before saving a checkpoint, so that a call to
        # save() after the training loop is a no-op if a checkpoint was saved last
        # epoch - otherwise it would bump the epoch counter on the checkpoint.
        last_epoch = self._epoch
        self._epoch += 1

        # Save parameters every time we see the end of at least a fraction of batch
        # number of trajectories that are done (not completed -- completed includes
        # truncated and done).
        # Also don't save too frequently, enforce a minimum gap.
        policy_save_start_time = time.time()
        # TODO(afrozm): Refactor to trax.save_trainer_state.
        if (self._n_trajectories_done >=
                self._done_frac_for_policy_save * self.train_env.batch_size
                and self._epoch % self._save_every_n == 0) or self._async_mode:
            self.save()
        policy_save_time = ppo.get_time(policy_save_start_time)

        epoch_time = ppo.get_time(epoch_start_time)

        timing_dict = {
            'epoch': epoch_time,
            'policy_eval': policy_eval_time,
            'trajectory_collection': trajectory_collection_time,
            'preprocessing': preprocessing_time,
            'log_prob_recompute': log_prob_recompute_time,
            'loss_compute': loss_compute_time,
            'optimization': optimization_time,
            'policy_save': policy_save_time,
        }

        timing_dict.update(timing_info)

        if self._should_write_summaries:
            for k, v in timing_dict.items():
                self._timing_sw.scalar('timing/%s' % k, v, step=last_epoch)

        max_key_len = max(len(k) for k in timing_dict)
        timing_info_list = [
            '%s : % 10.2f' % (k.rjust(max_key_len + 1), v)
            for k, v in sorted(timing_dict.items())
        ]
        logging.info('PPO epoch [% 6d], Timings: \n%s', last_epoch,
                     '\n'.join(timing_info_list))

        # Flush summary writers once in a while.
        if self._epoch % 1000 == 0:
            self.flush_summaries()
 def loss(params, inputs, labels):
     """Return the negative batch mean of ``sum(predictions * labels, axis=1)``.

     ``predictions`` come from ``predict_fn(params, inputs)``, defined
     elsewhere. NOTE(review): with one-hot ``labels`` and log-probability
     predictions this is the mean cross-entropy; confirm what predict_fn
     returns.
     """
     predictions = predict_fn(params, inputs)
     per_example = jnp.sum(predictions * labels, axis=1)
     return -jnp.mean(per_example)
Beispiel #11
0
def pmds_MAP(
    p_dists,
    n_samples,
    n_components=2,
    batch_size=0,
    random_state=42,
    lr=1e-3,
    epochs=20,
    debug_D_squareform=None,
    fixed_points=None,
    init_mu=None,
    hard_fix=False,
    method="MAP",
):
    """Probabilistic MDS according to Hefner model 1958.

    Runs ``epochs`` passes of plain gradient descent over all observed pairs,
    jointly estimating a location ``mu`` and a variance ``ss`` for every point.

    Parameters
    ----------
    p_dists : list(float) or list(tuple(float, tuple(int, int)))
        List of input pairwise distances. Either a list of scalars [d_{ij}]
        given in the order of ``combinations(range(n_samples), 2)``, or a list
        of distances paired with their indices [(d_{ij}, (i, j)), ...].
    n_samples : int
        Number of points in the dataset.
    n_components : int, defaults to 2
        Number of output dimensions in the LD space.
        Only 2 or 4 are accepted.
    batch_size : int, defaults to 0
        Currently unused: every observed pair is used in each epoch.
    random_state : int, defaults to 42
        Seed of the jax PRNG used to initialize ``mu``. The per-epoch pair
        shuffle uses ``random.sample`` and is not seeded by this argument.
    lr : float, defaults to 1e-3
        Learning rate for standard SGD.
    epochs : int, defaults to 20
        Number of epochs.
    debug_D_squareform : ndarray or None, defaults to None
        If given, ``stress(debug_D_squareform, mu)`` is printed every epoch.
    fixed_points : list(tuple(int, float, ...)) or None, defaults to None
        Points pinned to a given position: ``(index, x, y)`` in 2D, or the
        index followed by ``n_components`` coordinates. ``None`` means none.
    init_mu : ndarray[float], (n_samples, n_components), defaults to None
        Initial position for the embedding; ignored unless its shape matches.
    hard_fix : bool, defaults to False
        If True, fixed points are re-pinned after every epoch's update;
        otherwise they are only set at initialization.
    method : str, defaults to "MAP"
        Currently unused.

    Returns
    -------
    mu : ndarray (n_samples, n_components)
        Location estimation for points in LD space.
    ss : ndarray (n_samples,)
        Sigma square, variance estimation for each point.
    all_loss : list
        One loss value (JAX scalar) per epoch.
    """
    assert n_components in [2, 4]
    # `None` replaces a shared mutable `[]` default; empty means no fixed points.
    if fixed_points is None:
        fixed_points = []

    # init mu and sigma square. Transform unconstrained sigma square `ss_unc` to `ss`.
    # https://github.com/tensorflow/probability/issues/703
    # The split is kept (second key unused) so `key_m`, and thus the init, is unchanged.
    key_m, _ = jax.random.split(jax.random.PRNGKey(random_state))
    ss_unc = jnp.ones((n_samples, ))
    if init_mu is not None and init_mu.shape == (n_samples, n_components):
        mu = jnp.array(init_mu)
    else:
        mu = jax.random.normal(key_m, (n_samples, n_components))

    # fixed points: take the first `n_components` coordinates after the index
    # (exactly (x, y) in 2D) so that 4D embeddings can be pinned as well.
    if fixed_points:
        fixed_indices = [p[0] for p in fixed_points]
        fixed_pos = jnp.array(
            [list(p[1:1 + n_components]) for p in fixed_points])
        mu = jax.ops.index_update(mu, fixed_indices, fixed_pos)
        # NOTE(review): this sets the *unconstrained* value; the effective
        # variance is EPSILON + softplus(SCALE * EPSILON), not EPSILON.
        ss_unc = jax.ops.index_update(ss_unc, fixed_indices, EPSILON)

    # patch pairwise distances and indices of each pair together.
    # Anything that is not a tuple/list (python float, numpy or jax scalar) is
    # a bare distance; a float-only check sent e.g. float32 scalars down the
    # (d, (i, j)) path, which then failed to unpack.
    if not isinstance(p_dists[0], (tuple, list)):
        all_pairs = list(combinations(range(n_samples), 2))
        assert len(p_dists) == len(all_pairs)
        dists_with_indices = list(zip(p_dists, all_pairs))
    else:
        dists_with_indices = p_dists

    all_loss = []
    for epoch in range(epochs):
        # shuffle the observed pairs in each epoch
        batch = random.sample(dists_with_indices, k=len(p_dists))
        # unpatch pairwise distances and indices of points in each pair
        dists, pair_indices = list(zip(*batch))
        i0, i1 = list(zip(*pair_indices))
        i0, i1 = list(i0), list(i1)

        # get the params for related indices from global `mu` and `ss`
        mu_i, mu_j = mu[i0], mu[i1]
        ss_i = EPSILON + jax.nn.softplus(SCALE * ss_unc[i0])
        ss_j = EPSILON + jax.nn.softplus(SCALE * ss_unc[i1])

        # calculate loss and gradients of the log likelihood term
        loss_lllh, grads_lllh = loss_and_grads_lllh(mu_i, mu_j, ss_i, ss_j,
                                                    jnp.array(dists),
                                                    n_components)

        # calculate loss and gradients of prior term
        loss_log_mu, grads_log_mu = loss_and_grads_log_mu(mu)

        # accumulate log likelihood and log prior
        loss = jnp.sum(loss_lllh) + jnp.sum(loss_log_mu)

        # stack the (already lr-scaled) gradients of both pair endpoints so
        # they line up with `related_indices` = i0 + i1
        grads_mu = jnp.concatenate((lr * grads_lllh[0], lr * grads_lllh[1]),
                                   axis=0)
        grads_ss = jnp.concatenate((lr * grads_lllh[2], lr * grads_lllh[3]),
                                   axis=0)
        related_indices = i0 + i1
        assert grads_mu.shape[0] == grads_ss.shape[0] == len(related_indices)

        # update mu; index_add accumulates points appearing in several pairs
        mu = jax.ops.index_add(mu, related_indices, -grads_mu)
        mu = mu - lr * grads_log_mu[0]

        # update gradient for constrained variable ss
        # first, chain rule through softplus: d softplus(SCALE*u)/du = SCALE*sigmoid(SCALE*u)
        grads_ss_unc = (grads_ss *
                        jax.nn.sigmoid(SCALE * ss_unc[related_indices]) *
                        SCALE)
        # then, update the unconstrained variable ss_unc
        ss_unc = jax.ops.index_add(ss_unc, related_indices,
                                   -grads_ss_unc / len(i0))

        # correct gradient for fixed points
        if fixed_points and hard_fix:
            mu = jax.ops.index_update(mu, fixed_indices, fixed_pos)
            ss_unc = jax.ops.index_update(ss_unc, fixed_indices, EPSILON)

        mds_stress = (stress(debug_D_squareform, mu)
                      if debug_D_squareform is not None else 0.0)
        all_loss.append(loss)

        print(
            f"[DEBUG] epoch {epoch}, loss: {loss:.2f}, stress: {mds_stress:,.2f}"
        )

    ss = EPSILON + jax.nn.softplus(SCALE * ss_unc)
    print("[DEBUG] mean ss: ", float(jnp.mean(ss)))
    return mu, ss, all_loss
Beispiel #12
0
 def fun(x):
     return x - jnp.log(jnp.sum(jnp.exp(x)))
Beispiel #13
0
    def fit_model(self, particle_weights: np.ndarray,
                  particles: np.ndarray) -> np.ndarray:
        """Fits a binary model using weighted particles.

        The model will be a sparse lower triangular logistic regression as in
        Procedure 5 from https://arxiv.org/pdf/1101.6037.pdf

        Args:
          particle_weights: a np.array<float> of simplicial weights, one per
            particle (row of `particles`).
          particles: np.array<bool>[groups, n_patients]

        Returns:
          A np.array<float>[n_patients, n_patients] model. Row i holds the
          regression coefficients on earlier variables, with the intercept on
          the diagonal; the same array is also stored in `self.model`.
        """
        n_groups, n_patients = particles.shape
        model = np.zeros((n_patients, n_patients))
        smoothing = 1e-5
        # keep track of basic stats; marginals are shrunk towards 0.5 so the
        # normalisation below never divides by zero.
        xbar = (1 - smoothing) * np.sum(
            particle_weights[:, np.newaxis] * particles,
            axis=0) + smoothing * 0.5
        xcov = np.matmul(np.transpose(particles),
                         particle_weights[:, np.newaxis] * particles)
        xb1mxb = xbar * (1.0 - xbar)
        cov_matrix = (xcov -
                      xbar[:, np.newaxis] * xbar[np.newaxis, :]) / np.sqrt(
                          xb1mxb[:, np.newaxis] * xb1mxb[np.newaxis, :])

        # TODO(oliviert): turn this into parameters.
        eps = 0.01
        delta = 0.05
        # np.array(...) guarantees writable numpy copies for the in-place
        # updates below (jax.ops.index_update no longer exists in JAX).
        indices_model = np.array(np.logical_and(xbar > eps, xbar < 1 - eps))
        indices_single = np.array(np.logical_or(xbar <= eps, xbar >= 1 - eps))
        # no regression for first variable
        indices_single[0] = True
        indices_model[0] = False

        # look for sparse blocks of variables to regress on others
        if self.sparse_model_lr:
            regressed, regressor = np.where(np.abs(cov_matrix) > delta)
            dic_regressors = collections.defaultdict(list)
            for i, j in zip(regressed, regressor):
                if j < i:
                    dic_regressors[i].append(j)

        # Where there exists cross-correlation we estimate a model
        # TODO(cuturi) : switch to predefined number of regressors (i.e. top k
        # corellated variables. From kth patient we can then jit this regression.
        for i in np.where(indices_model)[0]:
            if self.sparse_model_lr:
                # copy so that appending `i` below does not mutate the dict
                indices_i = list(dic_regressors[i])
            else:
                indices_i = list(range(i))

            regressors = np.concatenate(
                (particles[:, indices_i], np.ones((n_groups, 1))), axis=-1)
            y = particles[:, i]

            # initialize loop
            # TODO(oliviert): turn those hard coded constants into parameters
            b = np.zeros((regressors.shape[1], ))
            diff = 1e10
            iterations = 0
            reg = .05

            # Ridge-regularised Newton / IRLS iterations for the weighted
            # logistic regression of y on the regressors.
            while diff > 1e-2 and iterations < 30:
                iterations += 1
                regressorsb = np.dot(regressors, b)
                p = jax.scipy.special.expit(regressorsb)
                q = p * (1 - p)
                cov = np.matmul(
                    particle_weights[np.newaxis, :] * q[np.newaxis, :] *
                    np.transpose(regressors), regressors)
                cov = cov + reg * np.eye(len(indices_i) + 1)
                c = np.dot(
                    np.transpose(regressors) * particle_weights[np.newaxis, :],
                    q * regressorsb + y - p)
                bnew = np.linalg.solve(cov, c)
                diff = np.sum((bnew - b)**2)
                b = bnew
            # the constant (last coefficient) is stored in [i, i]
            columns = indices_i + [i]
            # update line i of model
            model[i, np.asarray(columns)] = b

        # Where there are no cross-correlations, or posterior is very peaked,
        # we flip randomly and individually
        v = np.zeros((n_patients, ))
        v[indices_single] = jax.scipy.special.logit(xbar[indices_single])
        model = model + np.diag(v)
        self.model = model
        return model
def train_model(rand_key,
                network_size,
                lr,
                iters,
                train_input,
                test_input,
                test_mask,
                optimizer,
                ab,
                name=''):
    """Train a network on (x, y) data, tracking PSNRs and NTK-theory predictions.

    Args:
      rand_key: random key passed to the network's ``init_fn``.
      network_size: positional arguments unpacked into ``make_network``.
      lr: learning rate given to ``optimizer``; ``i * lr`` is also passed to
        ``predict_psnr_basic`` as the training time at step ``i``.
      iters: number of optimizer steps.
      train_input: ``(x, y)`` training pair.
      test_input: ``(x, y)`` test pair.
      test_mask: index/mask selecting the evaluated entries of ``test_input``.
      optimizer: factory called as ``optimizer(lr)`` that returns
        ``(opt_init, opt_update, get_params)``.
      ab: ``None`` to feed raw inputs shifted by ``-.5``; otherwise extra
        arguments for ``input_encoder`` (and ``ntk_params=True``).
      name: progress-bar label.

    Returns:
      ``(params, train_psnrs, test_psnrs, errs, theories, xs)`` where the
      metric lists are recorded every 20 steps (``xs`` holds the step indices),
      ``errs`` are training residuals and ``theories`` is an array of the
      ``predict_psnr_basic`` outputs.
    """
    ntk_params = ab is not None
    init_fn, apply_fn, kernel_fn = make_network(*network_size,
                                                ntk_params=ntk_params)

    if ab is None:
        run_model = jit(lambda params, ab, x: np.squeeze(
            apply_fn(params, x[..., None] - .5)))
    else:
        run_model = jit(lambda params, ab, x: np.squeeze(
            apply_fn(params, input_encoder(x, *ab))))
    model_loss = jit(lambda params, ab, x, y: .5 * np.sum(
        (run_model(params, ab, x) - y)**2))
    # PSNR = -10 * log10(mean squared error)
    model_psnr = jit(lambda params, ab, x, y: -10 * np.log10(
        np.mean((run_model(params, ab, x) - y)**2)))
    model_grad_loss = jit(lambda params, ab, x, y: jax.grad(model_loss)
                          (params, ab, x, y))

    opt_init, opt_update, get_params = optimizer(lr)
    opt_update = jit(opt_update)

    if ab is None:
        _, params = init_fn(rand_key, (-1, 1))
    else:
        _, params = init_fn(rand_key,
                            (-1, input_encoder(train_input[0], *ab).shape[-1]))
    opt_state = opt_init(params)

    # Loop-invariant data. The kernel inputs must match exactly what
    # run_model feeds the network, so the raw branch centres BOTH train and
    # test inputs by -.5 (the test inputs previously missed the shift).
    test_x_raw = test_input[0][test_mask]
    test_y = test_input[1][test_mask]
    if ab is None:
        train_x = train_input[0][..., None] - .5
        test_x = test_x_raw[..., None] - .5
    else:
        train_x = input_encoder(train_input[0], *ab)
        test_x = input_encoder(test_x_raw, *ab)

    train_psnrs = []
    test_psnrs = []
    theories = []
    xs = []
    errs = []
    for i in tqdm(range(iters), desc=name):
        opt_state = opt_update(
            i, model_grad_loss(get_params(opt_state), ab, *train_input),
            opt_state)

        if i % 20 == 0:
            current = get_params(opt_state)
            train_psnr = model_psnr(current, ab, *train_input)
            test_psnr = model_psnr(current, ab, test_x_raw, test_y)
            train_fx = run_model(current, ab, train_input[0])
            test_fx = run_model(current, ab, test_x_raw)
            theory = predict_psnr_basic(kernel_fn, train_fx, test_fx,
                                        train_x, train_input[1], test_x,
                                        test_y, i * lr)

            train_psnrs.append(train_psnr)
            test_psnrs.append(test_psnr)
            theories.append(theory)
            # train_fx is exactly the training prediction, reuse it
            errs.append(train_fx - train_input[1])
            xs.append(i)
    return get_params(opt_state), train_psnrs, test_psnrs, errs, np.array(
        theories), xs
Beispiel #15
0
 def g(x):  # x: i32
     """Return ``sum(2. * f(3 * x, 4. * x as float32))``."""
     int_arg = 3 * x
     float_arg = 4. * x.astype("float32")
     return jnp.sum(2. * f(int_arg, float_arg))
 def loss(params, inputs, targets):
     """Sum of squared differences between ``rnn(params, inputs)`` and ``targets``."""
     residual = rnn(params, inputs) - targets
     return np.sum(residual ** 2)
Beispiel #17
0
 def f_jax(*, x):
     """Return the sum of every element of the keyword-only argument ``x``."""
     total = jnp.sum(x)
     return total
 def loss(test_params):
     """Sum of ``sin(1 - x)`` over the result of ``minimize_structure``."""
     x_final = minimize_structure(test_params)
     shifted = 1.0 - x_final
     return np.sum(np.sin(shifted))
Beispiel #19
0
 def f_jax(*, x=(1., 2.)):
     """Return ``sum(x[0]) + 2 * sum(x[1])`` for the keyword-only pair ``x``."""
     first, second = x[0], x[1]
     return jnp.sum(first) + 2. * jnp.sum(second)
 def f(c, a):
     """Scan-style body: return a new carry pair and an output from carry ``c`` and input ``a``."""
     a1, a2 = a
     c1, c2 = c
     cos_total = np.sum(np.cos(a1))
     tan_total = np.sum(np.tan(c2 * a2))
     out = cos_total * tan_total
     new_carry = (c1 * np.sin(np.sum(a1 * a2)), c2 * np.cos(np.sum(a1)))
     return new_carry, out
 def cross_entropy(logits, labels):
     """Per-example cross-entropy: ``-sum(labels * log_softmax(logits))`` over the last axis."""
     log_probs = jax.nn.log_softmax(logits, axis=-1)
     weighted = labels * log_probs
     return -jnp.sum(weighted, axis=-1)
Beispiel #22
0
def temperature(V, mass=1.0):
    """Computes the temperature of a system with some velocities.

    ``V`` is a 2-D ``(N, dim)`` array; the result is
    ``sum(mass * V**2) / (N * dim)``.
    """
    num_particles, spatial_dim = V.shape
    degrees_of_freedom = num_particles * spatial_dim
    return np.sum(mass * V ** 2) / degrees_of_freedom
Beispiel #23
0
def neg_log_perplexity(batch, model_predictions):
    """Calculate negative log perplexity.

    Mean, over all leading positions, of the prediction entry selected by the
    one-hot encoded target. NOTE(review): this is the mean log-likelihood only
    if ``model_predictions`` are log-probabilities — confirm with the caller.
    """
    _inputs, targets = batch
    vocab_size = model_predictions.shape[-1]
    hot_targets = one_hot(targets, vocab_size)
    picked = np.sum(model_predictions * hot_targets, axis=-1)
    return np.mean(picked)
Beispiel #24
0
def cross_entropy_loss(logits, labels):
    """Cross-entropy of ``logits`` against integer ``labels``, divided by ``labels.size``."""
    log_probs = jax.nn.log_softmax(logits)
    one_hot_labels = common_utils.onehot(labels, log_probs.shape[-1])
    total = jnp.sum(one_hot_labels * log_probs)
    return -total / labels.size
Beispiel #25
0
def loss(params, batch, model_predict):
    """Calculate loss.

    Negative mean of the prediction entry selected by the one-hot target.
    """
    inputs, targets = batch
    preds = model_predict(params, inputs)
    hot = one_hot(targets, preds.shape[-1])
    per_position = np.sum(preds * hot, axis=-1)
    return -np.mean(per_position)
	def initialize(self, p=3, q=3, n=1, d=2, noise_list=None, c=0, noise_magnitude=0.1, noise_distribution='normal'):
		"""
		Description: Randomly initialize the hidden dynamics of the system.
		Args:
			p (int/numpy.ndarray/list): Autoregressive dynamics. If type int then randomly
				initializes a Gaussian length-p vector rescaled to L1-norm 0.99.
				If p is a 1-dimensional numpy.ndarray then uses it as dynamics vector;
				a list of such vectors is indexed by self.T inside the step function.
			q (int/numpy.ndarray/list): Moving-average dynamics. If type int then randomly
				initializes a Gaussian length-q vector (no bound on norm). If q is a
				1-dimensional numpy.ndarray then uses it as dynamics vector.
			n (int): Dimension of values.
			d (int): Integration order. When d > 1, d-1 extra levels are kept in
				self.delta_i_x and added to the output.
			noise_list (sequence/None): If given, its first q entries become the
				initial noise and the list is stored in self.noise_list.
			c (float/numpy.ndarray/None): Constant term, 0 by default. If None it is
				drawn from a standard normal of shape (n,). The ARMA dynamics follows
				x_t = c + AR-part + MA-part + noise, and thus tends to be centered
				around c.
			noise_magnitude (float): Scale applied to the initial random noise.
			noise_distribution (str): 'normal' or 'unif' (uniform on [-1, 1]); only
				used when noise_list is None.
		Returns:
			The first value in the time-series: self.x[0] plus the sum of
			self.delta_i_x (when d > 1).
		Raises:
			ValueError: if noise_list is None and noise_distribution is unknown.
		"""
		self.initialized = True
		self.T = 0
		self.max_T = -1
		self.n = n
		self.d = d
		if isinstance(p, int):
			phi = random.normal(generate_key(), shape=(p,))
			# rescale so that the AR coefficients have L1-norm 0.99 (< 1)
			self.phi = 0.99 * phi / np.linalg.norm(phi, ord=1)
		else:
			self.phi = p
		if isinstance(q, int):
			self.psi = random.normal(generate_key(), shape=(q,))
		else:
			self.psi = q
		# a list of coefficient vectors means time-varying dynamics
		if isinstance(self.phi, list):
			self.p = self.phi[0].shape[0]
		else:
			self.p = self.phi.shape[0]
		if isinstance(self.psi, list):
			self.q = self.psi[0].shape[0]
		else:
			self.q = self.psi.shape[0]
		self.noise_magnitude, self.noise_distribution = noise_magnitude, noise_distribution
		# `is None`: `c == None` is ambiguous when c is an array
		self.c = random.normal(generate_key(), shape=(self.n,)) if c is None else c
		self.x = random.normal(generate_key(), shape=(self.p, self.n))
		if self.d > 1:
			self.delta_i_x = random.normal(generate_key(), shape=(self.d-1, self.n))
		else:
			self.delta_i_x = None

		self.noise_list = None
		if noise_list is not None:
			self.noise_list = noise_list
			self.noise = np.array(noise_list[0:self.q])
		elif noise_distribution == 'normal':
			self.noise = self.noise_magnitude * random.normal(generate_key(), shape=(self.q, self.n))
		elif noise_distribution == 'unif':
			self.noise = self.noise_magnitude * random.uniform(generate_key(), shape=(self.q, self.n),
				minval=-1., maxval=1.)
		else:
			# previously self.noise was silently left unset here
			raise ValueError("noise_distribution must be 'normal' or 'unif', got %r" % (noise_distribution,))

		self.feedback = 0.0

		def _step(x, delta_i_x, noise, eps):
			# NOTE(review): self.T is read at trace time; _step is jitted and T is
			# not an argument, so time-varying phi/psi lists stay at the T seen
			# when first traced — confirm this is intended.
			if isinstance(self.phi, list):
				x_ar = np.dot(x.T, self.phi[self.T])
			else:
				x_ar = np.dot(x.T, self.phi)

			if isinstance(self.psi, list):
				x_ma = np.dot(noise.T, self.psi[self.T])
			else:
				x_ma = np.dot(noise.T, self.psi)
			if delta_i_x is not None:
				x_delta_sum = np.sum(delta_i_x)
			else:
				x_delta_sum = 0.0
			x_delta_new = self.c + x_ar + x_ma + eps
			x_new = x_delta_new + x_delta_sum

			# shifting the flattened (p, n) history by n moves every row down one
			next_x = np.roll(x, self.n)
			next_noise = np.roll(noise, self.n)

			next_x = jax.ops.index_update(next_x, 0, x_delta_new)  # newest differenced value in row 0
			next_noise = jax.ops.index_update(next_noise, 0, eps)  # newest noise in row 0
			next_delta_i_x = None
			for i in range(self.d - 1):
				if i == 0:
					next_delta_i_x = jax.ops.index_update(delta_i_x, i, x_delta_new + delta_i_x[i])
				else:
					# build on next_delta_i_x (not delta_i_x) so earlier levels'
					# updates are kept when d >= 3
					next_delta_i_x = jax.ops.index_update(next_delta_i_x, i, next_delta_i_x[i-1] + next_delta_i_x[i])

			return (next_x, next_delta_i_x, next_noise, x_new)

		self._step = jax.jit(_step)
		if self.delta_i_x is not None:
			x_delta_sum = np.sum(self.delta_i_x)
		else:
			x_delta_sum = 0
		return self.x[0] + x_delta_sum
 def loss(a, b):
     """Sum of the solution ``x`` of ``matvec(x) = b`` with ``matvec(x) = high_precision_dot(a, x)``.

     The solve goes through ``lax.custom_linear_solve`` using
     ``explicit_jacobian_solve`` as the solver.
     """
     def matvec(v):
         return high_precision_dot(a, v)

     solution = lax.custom_linear_solve(matvec, b, explicit_jacobian_solve)
     return np.sum(solution)
Beispiel #28
0
def global_norm(pytree):
    """Return the L2 norm of all leaves of ``pytree`` taken together.

    Computes ``sqrt(sum over leaves of sum(leaf ** 2))``; an empty pytree gives
    0. Uses ``jax.tree_util.tree_leaves`` because the top-level alias
    ``jax.tree_leaves`` has been removed from recent JAX releases.
    """
    leaves = jax.tree_util.tree_leaves(pytree)
    squared_sums = [jnp.sum(jnp.square(leaf)) for leaf in leaves]
    return jnp.sqrt(jnp.sum(jnp.asarray(squared_sums)))
 def loss(A):
     """Apply ``x -> A @ x`` ten times via ``lax.scan`` from a zero vector and sum the result.

     Because the start vector is all zeros, the returned value is zero.
     """
     def step(carry, _):
         return np.matmul(A, carry), None

     start = np.zeros(A.shape[-1:])
     final, _ = lax.scan(step, start, np.arange(10))
     return np.sum(final)
Beispiel #30
0
def logprob_fun(params, inputs, targets):
    """Sum of squared differences between ``predict(params, inputs)`` and ``targets``."""
    residual = predict(params, inputs) - targets
    return np.sum(residual ** 2)
Beispiel #31
0
 def conditional_moments(self, f):
     """Return the conditional mean and variance for a stacked latent vector ``f``.

     The first half of ``f`` holds subbands; the second half is passed through
     ``self.link_fn`` to give modulators. The mean is
     ``sum(subbands * modulators)`` reshaped to ``(1, 1)``; the variance is
     ``[[self.variance]]``.
     """
     half = f.shape[0] // 2
     subbands = f[:half]
     modulators = self.link_fn(f[half:])
     mean = np.sum(subbands * modulators).reshape(-1, 1)
     return mean, np.array([[self.variance]])