Example #1
0
    'max_pool': jax_max_pool,
    'sum_pool': jax_sum_pool,
    'scan': _jax_scan,
    'cond': lax.cond,
    'lt': lax.lt,
    'stop_gradient': lax.stop_gradient,
    'jit': jax.jit,
    'grad': jax.grad,
    'pmap': jax.pmap,
    'psum': lax.psum,
    'abstract_eval': jax_abstract_eval,
    'random_uniform': jax_random.uniform,
    'random_randint': jax_randint,
    'random_normal': jax_random.normal,
    'random_bernoulli': jax_random.bernoulli,
    'random_get_prng': jax.jit(jax_random.PRNGKey),
    'random_split': jax_random.split,
    'dataset_as_numpy': tfds.as_numpy,
    'device_count': jax.local_device_count,
}


# Plain-NumPy backend table: `jit` is the identity, and the PRNG entries are
# stubs whose "keys" are always None.
_NUMPY_BACKEND = dict(
    name='numpy',
    np=onp,
    jit=lambda fun: fun,
    random_get_prng=lambda seed: None,
    random_split=lambda prng, num=2: tuple(None for _ in range(num)),
    # Logistic sigmoid 1 / (1 + e^{-x}).
    expit=lambda x: 1. / (1. + onp.exp(-x)),
)
Example #2
0
def test_generic_kmeans():
    """Visual smoke test for ellipsoid clustering on a toy 2-D likelihood.

    Draws uniform samples in the unit square, maps them through a uniform prior,
    keeps the samples whose likelihood exceeds a per-problem threshold, clusters
    the kept points with ``ellipsoid_clustering`` and plots the bounding
    ellipsoid of every cluster. ``data`` selects the toy problem ('eggbox' or
    'shells').

    Raises:
        ValueError: if ``data`` names an unknown toy problem.
    """
    from jaxns.prior_transforms import PriorChain, UniformPrior
    from jax import vmap, disable_jit, jit
    import pylab as plt

    data = 'shells'
    if data == 'eggbox':
        def log_likelihood(theta, **kwargs):
            return (2. + jnp.prod(jnp.cos(0.5 * theta))) ** 5

        prior_chain = PriorChain() \
            .push(UniformPrior('theta', low=jnp.zeros(2), high=jnp.pi * 10. * jnp.ones(2)))
        num_samples, threshold = 1000, 100.
    elif data == 'shells':
        def log_likelihood(theta, **kwargs):
            def log_circ(theta, c, r, w):
                return -0.5 * (jnp.linalg.norm(theta - c) - r) ** 2 / w ** 2 \
                    - jnp.log(jnp.sqrt(2 * jnp.pi * w ** 2))
            w1 = w2 = jnp.array(0.1)
            r1 = r2 = jnp.array(2.)
            c1 = jnp.array([0., -4.])
            c2 = jnp.array([0., 4.])
            return jnp.logaddexp(log_circ(theta, c1, r1, w1), log_circ(theta, c2, r2, w2))

        prior_chain = PriorChain() \
            .push(UniformPrior('theta', low=-12. * jnp.ones(2), high=12. * jnp.ones(2)))
        num_samples, threshold = 40000, 1.
    else:
        # Previously an unknown name fell through to a NameError on `select`.
        raise ValueError(f"Unknown toy problem {data!r}")

    # Uniform samples in the prior's unit hypercube, mapped to parameter space.
    U = vmap(lambda key: random.uniform(key, (prior_chain.U_ndims,)))(
        random.split(random.PRNGKey(0), num_samples))
    theta = vmap(lambda u: prior_chain(u))(U)
    lik = vmap(lambda theta: log_likelihood(**theta))(theta)
    select = lik > threshold

    print("Selecting", jnp.sum(select))
    # Log of the fraction of the unit cube occupied by the selected samples.
    log_VS = jnp.log(jnp.sum(select) / select.size)
    print("V(S)", jnp.exp(log_VS))

    points = U[select, :]
    sc = plt.scatter(U[:, 0], U[:, 1], c=jnp.exp(lik))
    plt.colorbar(sc)
    plt.show()
    with disable_jit():
        # Other back-ends can be swapped in here for manual comparison:
        # generic_kmeans(..., method='ellipsoid' / 'mahalanobis' / 'euclidean')
        # or hierarchical_clustering(...).
        # The key is now passed through instead of being ignored by the lambda.
        cluster_id, _ = \
            jit(lambda key, points, log_VS: ellipsoid_clustering(key, points, 7, log_VS)
                )(random.PRNGKey(0), points, log_VS)
        K = int(jnp.max(cluster_id) + 1)

    mu, C = vmap(lambda k: bounding_ellipsoid(points, cluster_id == k))(jnp.arange(K))
    radii, rotation = vmap(ellipsoid_params)(C)

    # Unit circle, mapped onto each ellipsoid by rotation @ diag(radii) + mu.
    angles = jnp.linspace(0., jnp.pi * 2, 100)
    unit_circle = jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=0)

    for i, (mu_i, radii_i, rotation_i) in enumerate(zip(mu, radii, rotation)):
        colour = plt.cm.jet(i / K)
        y = mu_i[:, None] + rotation_i @ jnp.diag(radii_i) @ unit_circle
        plt.plot(y[0, :], y[1, :], c=colour)
        in_cluster = cluster_id == i
        plt.scatter(points[in_cluster, 0], points[in_cluster, 1], c=jnp.atleast_2d(colour))
    plt.xlim(-1, 2)
    plt.ylim(-1, 2)
    plt.show()
Example #3
0
 def test_scatter_static(self, op):
     """Scatter `op` with static slice indices [::2, 3:] under jit."""
     # The index expression is a Python-level constant, so it stays static
     # inside the jitted function; only the array and the update are traced.
     f_jax = jax.jit(lambda arr, upd: op(arr, jax.ops.index[::2, 3:], upd))
     self.ConvertAndCompare(f_jax,
                            np.ones((5, 6), dtype=np.float32),
                            np.float32(6.),
                            with_function=True)
Example #4
0
 def f(x):
     """Return exp(x) + 1, computing exp(x) with a jit-compiled jnp.exp."""
     jitted_exp = jax.jit(jnp.exp)
     return jitted_exp(x) + 1.
Example #5
0
 def test_jit(self):
     """Check a jitted sin(cos(x)) at x = 0.7."""
     def composed(x):
         return jnp.sin(jnp.cos(x))
     self.ConvertAndCompare(jax.jit(composed), 0.7)
Example #6
0
 def test_function(self):
     """Same as test_jit but with ``with_function=True``."""
     def composed(x):
         return jnp.sin(jnp.cos(x))
     self.ConvertAndCompare(jax.jit(composed), 0.7, with_function=True)
Example #7
0
def pairwise_distances(dist, **arg):
    """Build a jitted function that evaluates ``dist`` on every pair of rows.

    The returned callable takes ``(X, Y)`` and returns ``D`` with
    ``D[i, j] = dist(X[i], Y[j], **arg)``; the leading axis of each input is
    the data index.
    """
    bound_dist = partial(dist, **arg)
    # Inner map: a single row of X against every row of Y.
    one_vs_all = vmap(bound_dist, in_axes=(None, 0))
    # Outer map: repeat for every row of X.
    all_vs_all = vmap(one_vs_all, in_axes=(0, None))
    return jit(all_vs_all)
Example #8
0
def loss(r_surf, nn, sg, fc):
    """Evaluate ``quadratic_flux`` for the coils described by ``fc``.

    The curve functions ``r`` / ``r1`` are sampled on ``NS + 1`` equally spaced
    angles in [0, 2*PI]; the last sample along axis 1 is dropped (presumably
    the duplicate of angle 0 — confirm). ``r1`` samples are scaled by the angle
    step. The current vector passed to ``quadratic_flux`` is all ones, one
    entry per ``fc.shape[1]``.
    """
    num_coils = fc.shape[1]
    angles = np.linspace(0, 2 * PI, NS + 1)
    step = 2 * PI / NS
    positions = r(fc, angles)[:, :-1, :]
    segments = r1(fc, angles)[:, :-1, :] * step
    currents = np.ones(num_coils)
    return quadratic_flux(r_surf, currents, positions, segments, nn, sg)


# Fixed surface data, loaded once at import time from the working directory.
# NOTE(review): names suggest surface normals (nn), surface weights (sg) and
# surface points (r_surf) — confirm against the script that writes these files.
nn = np.load("nn.npy")
sg = np.load("sg.npy")
r_surf = np.load("r_surf.npy")

# Bind the surface data so the objective is a function of the coil array only.
loss_partial = partial(loss, r_surf, nn, sg)

# Compiled function returning (loss value, gradient w.r.t. fc) for a given fc.
loss_and_grad_func = jit(value_and_grad(loss_partial))


def main():
    """Run plain gradient descent on the coil array read from ``coils.hdf5``.

    Each step computes ``loss_and_grad_func(fc)`` and sets
    ``fc = fc - grad * lr``. The visible code neither returns nor saves the
    final ``fc``.
    """
    with tb.open_file("coils.hdf5", "r") as h5:
        fc = np.asarray(h5.root.coilSeries[:, :, :])

    num_steps = 100
    step_size = 1e-7

    t_init = time.time()

    for _ in range(num_steps):
        _, gradient = loss_and_grad_func(fc)
        fc = fc - gradient * step_size
Example #9
0
def main(unused_argv):
    """Train the model configured by FLAGS, logging, checkpointing and evaluating.

    Builds the train/test datasets and the model from FLAGS, restores state
    from FLAGS.train_dir via `checkpoints.restore_checkpoint`, then runs
    `jax.pmap` training steps from the restored step up to FLAGS.max_steps.
    Host 0 writes TensorBoard summaries and in-loop checkpoints. Every
    FLAGS.render_every steps a test image is rendered.

    Args:
      unused_argv: ignored; configuration is read from FLAGS.

    Raises:
      ValueError: if FLAGS.batch_size is not divisible by jax.device_count(),
        or if FLAGS.train_dir or FLAGS.data_dir is unset.
    """
    rng = random.PRNGKey(20200823)
    # Shift the numpy random seed by host_id() to shuffle data loaded by different
    # hosts.
    np.random.seed(20201473 + jax.host_id())

    if FLAGS.config is not None:
        utils.update_flags(FLAGS)
    if FLAGS.batch_size % jax.device_count() != 0:
        raise ValueError(
            "Batch size must be divisible by the number of devices.")
    if FLAGS.train_dir is None:
        raise ValueError("train_dir must be set. None set now.")
    if FLAGS.data_dir is None:
        raise ValueError("data_dir must be set. None set now.")
    dataset = datasets.get_dataset("train", FLAGS)
    test_dataset = datasets.get_dataset("test", FLAGS)

    rng, key = random.split(rng)
    # The model is initialized from one example batch of the training set.
    model, variables = models.get_model(key, dataset.peek(), FLAGS)
    optimizer = flax.optim.Adam(FLAGS.lr_init).create(variables)
    state = utils.TrainState(optimizer=optimizer)
    del optimizer, variables

    # Maps a step number to the learning rate for that step (called every step).
    learning_rate_fn = functools.partial(utils.learning_rate_decay,
                                         lr_init=FLAGS.lr_init,
                                         lr_final=FLAGS.lr_final,
                                         max_steps=FLAGS.max_steps,
                                         lr_delay_steps=FLAGS.lr_delay_steps,
                                         lr_delay_mult=FLAGS.lr_delay_mult)

    # Per-device training step, called below as (keys, state, batch, lr).
    # in_axes=None broadcasts the scalar lr to every device.
    # donate_argnums=(2,) donates the batch buffers.
    train_pstep = jax.pmap(functools.partial(train_step, model),
                           axis_name="batch",
                           in_axes=(0, 0, 0, None),
                           donate_argnums=(2, ))

    # Renders this device's rays, then all_gather over the "batch" axis so each
    # device returns the outputs of every device.
    def render_fn(variables, key_0, key_1, rays):
        return jax.lax.all_gather(model.apply(variables, key_0, key_1, rays,
                                              FLAGS.randomized),
                                  axis_name="batch")

    render_pfn = jax.pmap(
        render_fn,
        in_axes=(None, None, None, 0),  # Only distribute the data input.
        donate_argnums=(3, ),
        axis_name="batch",
    )

    # Compiling to the CPU because it's faster and more accurate.
    ssim_fn = jax.jit(functools.partial(utils.compute_ssim, max_val=1.),
                      backend="cpu")

    if not utils.isdir(FLAGS.train_dir):
        utils.makedirs(FLAGS.train_dir)
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    # Resume training at the step after the last checkpoint.
    init_step = state.optimizer.state.step + 1
    # Copy the state onto every local device for pmap.
    state = flax.jax_utils.replicate(state)

    # summary_writer exists only on host 0; every use below is guarded by a
    # jax.host_id() == 0 check.
    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)

    # Prefetch_buffer_size = 3 x batch_size
    pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
    n_local_devices = jax.local_device_count()
    rng = rng + jax.host_id()  # Make random seed separate across hosts.
    keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
    gc.disable()  # Disable automatic garbage collection for efficiency.
    # Per-step stats collected on host 0 for the running averages that are
    # printed (and then cleared) every FLAGS.print_every steps.
    stats_trace = []
    # When True, the next step restarts the throughput timer.
    reset_timer = True
    for step, batch in zip(range(init_step, FLAGS.max_steps + 1), pdataset):
        if reset_timer:
            t_loop_start = time.time()
            reset_timer = False
        lr = learning_rate_fn(step)
        state, stats, keys = train_pstep(keys, state, batch, lr)
        if jax.host_id() == 0:
            stats_trace.append(stats)
        # Manual collection, since automatic GC was disabled above.
        if step % FLAGS.gc_every == 0:
            gc.collect()

        # Log training summaries. This is put behind a host_id check because in
        # multi-host evaluation, all hosts need to run inference even though we
        # only use host 0 to record results.
        if jax.host_id() == 0:
            if step % FLAGS.print_every == 0:
                summary_writer.scalar("train_loss", stats.loss[0], step)
                summary_writer.scalar("train_psnr", stats.psnr[0], step)
                summary_writer.scalar("train_loss_coarse", stats.loss_c[0],
                                      step)
                summary_writer.scalar("train_psnr_coarse", stats.psnr_c[0],
                                      step)
                summary_writer.scalar("weight_l2", stats.weight_l2[0], step)
                avg_loss = np.mean(
                    np.concatenate([s.loss for s in stats_trace]))
                avg_psnr = np.mean(
                    np.concatenate([s.psnr for s in stats_trace]))
                stats_trace = []
                summary_writer.scalar("train_avg_loss", avg_loss, step)
                summary_writer.scalar("train_avg_psnr", avg_psnr, step)
                summary_writer.scalar("learning_rate", lr, step)
                steps_per_sec = FLAGS.print_every / (time.time() -
                                                     t_loop_start)
                reset_timer = True
                rays_per_sec = FLAGS.batch_size * steps_per_sec
                summary_writer.scalar("train_steps_per_sec", steps_per_sec,
                                      step)
                summary_writer.scalar("train_rays_per_sec", rays_per_sec, step)
                # Field width large enough to right-align any step number.
                precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
                print(("{:" + "{:d}".format(precision) + "d}").format(step) +
                      f"/{FLAGS.max_steps:d}: " +
                      f"i_loss={stats.loss[0]:0.4f}, " +
                      f"avg_loss={avg_loss:0.4f}, " +
                      f"weight_l2={stats.weight_l2[0]:0.2e}, " +
                      f"lr={lr:0.2e}, " + f"{rays_per_sec:0.0f} rays/sec")
            if step % FLAGS.save_every == 0:
                # Save device 0's copy of the replicated state.
                state_to_save = jax.device_get(
                    jax.tree_map(lambda x: x[0], state))
                checkpoints.save_checkpoint(FLAGS.train_dir,
                                            state_to_save,
                                            int(step),
                                            keep=100)

        # Test-set evaluation.
        if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
            # We reuse the same random number generator from the optimization step
            # here on purpose so that the visualization matches what happened in
            # training.
            t_eval_start = time.time()
            eval_variables = jax.device_get(jax.tree_map(
                lambda x: x[0], state)).optimizer.target
            test_case = next(test_dataset)
            pred_color, pred_disp, pred_acc = utils.render_image(
                functools.partial(render_pfn, eval_variables),
                test_case["rays"],
                keys[0],
                FLAGS.dataset == "llff",
                chunk=FLAGS.chunk)

            # Log eval summaries on host 0.
            if jax.host_id() == 0:
                psnr = utils.compute_psnr(
                    ((pred_color - test_case["pixels"])**2).mean())
                ssim = ssim_fn(pred_color, test_case["pixels"])
                eval_time = time.time() - t_eval_start
                num_rays = jnp.prod(
                    jnp.array(test_case["rays"].directions.shape[:-1]))
                rays_per_sec = num_rays / eval_time
                summary_writer.scalar("test_rays_per_sec", rays_per_sec, step)
                print(
                    f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec"
                )
                summary_writer.scalar("test_psnr", psnr, step)
                summary_writer.scalar("test_ssim", ssim, step)
                summary_writer.image("test_pred_color", pred_color, step)
                summary_writer.image("test_pred_disp", pred_disp, step)
                summary_writer.image("test_pred_acc", pred_acc, step)
                summary_writer.image("test_target", test_case["pixels"], step)

    # Save a final checkpoint unless the loop's last step already did, which
    # happens when max_steps is a multiple of save_every.
    # NOTE(review): unlike the in-loop save, this is not gated on host 0, so
    # every host writes it — confirm this is intended for multi-host runs.
    if FLAGS.max_steps % FLAGS.save_every != 0:
        state = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(FLAGS.train_dir,
                                    state,
                                    int(FLAGS.max_steps),
                                    keep=100)
Example #10
0
    def search(self,
               controller_id,
               controller_params,
               environment_id,
               environment_params,
               loss,
               search_space,
               trials=None,
               smoothing=10,
               min_steps=100,
               verbose=0):
        """
        Description: Search for optimal controller parameters
        Args:
            controller_id (string): id of controller
            controller_params (dict): initial controller parameters dict (updated by search space)
            environment_id (string): id of environment to try on
            environment_params (dict): environment parameters dict
            loss (function): a function mapping y_pred, y_true -> scalar loss
            search_space (dict): dict mapping parameter names to a finite set of options
            trials (int, None): number of random trials to sample from search space / try all parameters
            smoothing (int): loss computed over smoothing number of steps to decrease variance
            min_steps (int): minimum number of steps that the controller gets to run for
            verbose (int): if 1, print progress and current parameters
        Returns:
            (optimal_params, optimal_loss): the lowest-loss parameter dict seen
            (controller_params overridden by one search_space combination) and
            the loss returned for it by self._run_test.
        """
        self.controller_id = controller_id
        self.controller_params = controller_params
        self.environment_id = environment_id
        self.environment_params = environment_params
        self.loss = loss

        # Every combination of candidate values, in search_space key order.
        param_names = list(search_space.keys())
        param_list = list(itertools.product(*search_space.values()))
        # JAX random ops act on arrays, so shuffle an index array rather than the
        # list of tuples itself. For a 1-D input, permutation matches the
        # deprecated random.shuffle.
        index = np.arange(len(param_list))
        shuffled_index = random.permutation(generate_key(), index)
        param_order = [param_list[i] for i in shuffled_index]

        # helper controller
        def _update_smoothing(l, val):
            """Shift the smoothing window by one and write `val` at index 0."""
            # `.at[].set` replaces jax.ops.index_update, which JAX has removed.
            return np.roll(l, 1).at[0].set(val)

        self._update_smoothing = jit(_update_smoothing)

        # store optimal params and optimal loss
        optimal_params, optimal_loss = {}, None
        for t, params in enumerate(param_order, start=1):
            curr_params = controller_params.copy()
            curr_params.update(zip(param_names, params))
            trial_loss = self._run_test(curr_params,
                                        smoothing=smoothing,
                                        min_steps=min_steps,
                                        verbose=verbose)
            # Compare against None explicitly: a best loss of 0.0 is falsy and
            # must not be replaced by a worse result.
            if optimal_loss is None or trial_loss < optimal_loss:
                optimal_params = curr_params
                optimal_loss = trial_loss
            if t == trials:  # break after trials number of attempts, unless trials is None
                break
        return optimal_params, optimal_loss
Example #11
0
    def __init__(
        self,
        preprocessor,
        sample_network_input: jnp.ndarray,
        network,
        optimizer: optax.GradientTransformation,
        transition_accumulator: Any,
        replay,
        batch_size: int,
        exploration_epsilon: Callable[[int], float],
        min_replay_capacity_fraction: float,
        learn_period: int,
        target_network_update_period: int,
        grad_error_bound: float,
        rng_key,
    ):
        """Initialize network/optimizer state and build jitted pure functions.

        Args:
            preprocessor: stored as ``self._preprocessor``.
            sample_network_input: one unbatched example input; a leading batch
                axis is added before it is passed to ``network.init``.
            network: object with ``init(rng, x)`` and ``apply(params, rng, x)``;
                the output of ``apply`` has a ``q_values`` attribute.
            optimizer: optax transformation applied to the online parameters.
            transition_accumulator: stored as ``self._transition_accumulator``.
            replay: replay buffer; its ``capacity`` scales the minimum fill level.
            batch_size: expected number of transitions per update (asserted in
                the loss).
            exploration_epsilon: maps an int to the epsilon passed to the
                epsilon-greedy policy (presumably indexed by frame — confirm).
            min_replay_capacity_fraction: fraction of ``replay.capacity`` stored
                as ``self._min_replay_capacity``.
            learn_period: stored as ``self._learn_period``.
            target_network_update_period: stored as
                ``self._target_network_update_period``.
            grad_error_bound: bound given to ``rlax.clip_gradient`` on TD errors.
            rng_key: JAX PRNG key; split once here for network initialization.
        """
        self._preprocessor = preprocessor
        self._replay = replay
        self._transition_accumulator = transition_accumulator
        self._batch_size = batch_size
        self._exploration_epsilon = exploration_epsilon
        self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity
        self._learn_period = learn_period
        self._target_network_update_period = target_network_update_period

        # Initialize network parameters and optimizer.
        self._rng_key, network_rng_key = jax.random.split(rng_key)
        self._online_params = network.init(network_rng_key,
                                           sample_network_input[None, ...])
        # The target network starts as a copy of the online parameters.
        self._target_params = self._online_params
        self._opt_state = optimizer.init(self._online_params)

        # Other agent state: last action, frame count, etc.
        self._action = None
        self._frame_t = -1  # Current frame index.

        # Define jitted loss, update, and policy functions here instead of as
        # class methods, to emphasize that these are meant to be pure functions
        # and should not access the agent object's state via `self`.

        def loss_fn(online_params, target_params, transitions, rng_key):
            """Calculates loss given network parameters and transitions."""
            _, online_key, target_key = jax.random.split(rng_key, 3)
            q_tm1 = network.apply(online_params, online_key,
                                  transitions.s_tm1).q_values
            q_target_t = network.apply(target_params, target_key,
                                       transitions.s_t).q_values
            td_errors = _batch_q_learning(
                q_tm1,
                transitions.a_tm1,
                transitions.r_t,
                transitions.discount_t,
                q_target_t,
            )
            # Identity in the forward pass; gradients are clipped to the bound.
            td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                           grad_error_bound)
            losses = rlax.l2_loss(td_errors)
            # Use the closed-over batch_size rather than self._batch_size so the
            # function stays free of agent state, as stated above.
            assert losses.shape == (batch_size, )
            loss = jnp.mean(losses)
            return loss

        def update(rng_key, opt_state, online_params, target_params,
                   transitions):
            """Computes learning update from batch of replay transitions."""
            rng_key, update_key = jax.random.split(rng_key)
            d_loss_d_params = jax.grad(loss_fn)(online_params, target_params,
                                                transitions, update_key)
            updates, new_opt_state = optimizer.update(d_loss_d_params,
                                                      opt_state)
            new_online_params = optax.apply_updates(online_params, updates)
            return rng_key, new_opt_state, new_online_params

        self._update = jax.jit(update)

        def select_action(rng_key, network_params, s_t, exploration_epsilon):
            """Samples action from eps-greedy policy wrt Q-values at given state."""
            rng_key, apply_key, policy_key = jax.random.split(rng_key, 3)
            q_t = network.apply(network_params, apply_key,
                                s_t[None, ...]).q_values[0]
            a_t = rlax.epsilon_greedy().sample(policy_key, q_t,
                                               exploration_epsilon)
            return rng_key, a_t

        self._select_action = jax.jit(select_action)
Example #12
0
 def testJitComputationNaN(self):
     """A jitted 0/0 must raise FloatingPointError.

     Presumably relies on NaN checking being enabled for this test class —
     confirm where jax_debug_nans is set.
     """
     zero = jnp.array(0.)
     divide_into_zero = jax.jit(lambda x: 0. / x)
     with self.assertRaises(FloatingPointError):
         result = divide_into_zero(zero)
         result.block_until_ready()
Example #13
0
 def testJitComputationNoNaN(self):
     A = jnp.array([[1., 2.], [2., 3.]])
     ans = jax.jit(jnp.tanh)(A)
     ans.block_until_ready()
Example #14
0
def fori_collect(lower,
                 upper,
                 body_fun,
                 init_val,
                 transform=identity,
                 progbar=True,
                 return_last_val=False,
                 collection_size=None,
                 jit_model=True,
                 tqdm_position=0,
                 **progbar_opts):
    """
    This looping construct works like :func:`~jax.lax.fori_loop` but with the additional
    effect of collecting values from the loop body. In addition, this allows for
    post-processing of these samples via `transform`, and progress bar updates.
    Note that, `progbar=False` will be faster, especially when collecting a
    lot of samples. Refer to example usage in :func:`~numpyro.infer.mcmc.hmc`.

    :param int lower: the index to start the collective work. In other words,
        we will skip collecting the first `lower` values.
    :param int upper: number of times to run the loop body.
    :param body_fun: a callable that takes a collection of
        `np.ndarray` and returns a collection with the same shape and
        `dtype`.
    :param init_val: initial value to pass as argument to `body_fun`. Can
        be any Python collection type containing `np.ndarray` objects.
    :param transform: a callable to post-process the values returned by `body_fn`.
    :param progbar: whether to post progress bar updates.
    :param bool return_last_val: If `True`, the last value is also returned,
        as the second element of a ``(collection, last_val)`` tuple.
        It has the same type as `init_val`.
    :param int collection_size: Size of the returned collection. If not specified,
        the size will be ``upper - lower``. If the size is larger than
        ``upper - lower``, only the first ``upper - lower`` entries will be non-zero.
    :param bool jit_model: only used when `progbar=True`: whether each loop
        step is run through :func:`jax.jit`. The `progbar=False` path always
        runs inside :func:`~jax.lax.fori_loop`.
    :param int tqdm_position: passed as ``position`` to ``tqdm.trange``.
    :param `**progbar_opts`: optional additional progress bar arguments. A
        `diagnostics_fn` can be supplied which when passed the current value
        from `body_fun` returns a string that is used to update the progress
        bar postfix. Also a `progbar_desc` callable can be supplied which maps
        the iteration index to the string used to label the progress bar.
    :return: collection with the same type as `init_val` with values
        collected along the leading axis of `np.ndarray` objects.
    """
    assert lower <= upper
    collection_size = upper - lower if collection_size is None else collection_size
    assert collection_size >= upper - lower
    init_val_flat, unravel_fn = ravel_pytree(transform(init_val))

    @cached_by(fori_collect, body_fun, transform)
    def _body_fn(i, vals):
        val, collection, lower_idx = vals
        val = body_fun(val)
        # Warm-up iterations (i < lower_idx) all write slot 0. When upper > lower,
        # that slot is overwritten again at i == lower_idx.
        i = np.where(i >= lower_idx, i - lower_idx, 0)
        # `.at[].set` replaces ops.index_update, which JAX has removed.
        collection = collection.at[i].set(ravel_pytree(transform(val))[0])
        return val, collection, lower_idx

    collection = np.zeros((collection_size, ) + init_val_flat.shape)
    if not progbar:
        last_val, collection, _ = fori_loop(0, upper, _body_fn,
                                            (init_val, collection, lower))
    else:
        diagnostics_fn = progbar_opts.pop('diagnostics_fn', None)
        progbar_desc = progbar_opts.pop('progbar_desc', lambda x: '')
        # Wrap once instead of calling jit(...) again on every iteration.
        step_fn = jit(_body_fn) if jit_model else _body_fn

        vals = (init_val, collection, device_put(lower))
        if upper == 0:
            # special case, only compiling
            step_fn(0, vals)
        else:
            with tqdm.trange(upper, position=tqdm_position) as t:
                for i in t:
                    vals = step_fn(i, vals)
                    t.set_description(progbar_desc(i), refresh=False)
                    if diagnostics_fn:
                        t.set_postfix_str(diagnostics_fn(vals[0]),
                                          refresh=False)

        last_val, collection, _ = vals

    unravel_collection = vmap(unravel_fn)(collection)
    return (unravel_collection,
            last_val) if return_last_val else unravel_collection
Example #15
0
    def __init__(
        self,
        obs_spec: specs.Array,
        action_spec: specs.DiscreteArray,
        network: Callable[[jnp.ndarray], jnp.ndarray],
        num_ensemble: int,
        batch_size: int,
        discount: float,
        replay_capacity: int,
        min_replay_size: int,
        sgd_period: int,
        target_update_period: int,
        optimizer: optix.InitUpdate,
        mask_prob: float,
        noise_scale: float,
        epsilon_fn: Callable[[int], float] = lambda _: 0.,
        seed: int = 1,
    ):
        """Create `num_ensemble` Q-networks, each with its own training state.

        Every ensemble member gets independently initialized online and target
        parameters plus its own optimizer state. All members share one jitted
        SGD step, whose loss adds `noise_scale`-scaled noise to rewards and
        weights squared TD errors by a bootstrap mask.
        """
        # Haiku-transformed network; `apply` takes no RNG argument.
        pure_net = hk.without_apply_rng(hk.transform(network, apply_rng=True))

        def _masked_noisy_loss(params: hk.Params, target_params: hk.Params,
                               transitions: Sequence[jnp.ndarray]) -> jnp.ndarray:
            """Mean of mask * squared Q-learning TD error, with noise added to rewards."""
            o_tm1, a_tm1, r_t, d_t, o_t, m_t, z_t = transitions
            q_tm1 = pure_net.apply(params, o_tm1)
            q_t = pure_net.apply(target_params, o_t)
            noisy_r_t = r_t + noise_scale * z_t
            td_error = jax.vmap(rlax.q_learning)(q_tm1, a_tm1, noisy_r_t,
                                                 discount * d_t, q_t)
            return jnp.mean(m_t * td_error**2)

        @jax.jit
        def _sgd_step(state: TrainingState,
                      transitions: Sequence[jnp.ndarray]) -> TrainingState:
            """Apply one optimizer update to a single member's TrainingState."""
            grads = jax.grad(_masked_noisy_loss)(state.params,
                                                 state.target_params,
                                                 transitions)
            updates, new_opt_state = optimizer.update(grads, state.opt_state)
            return TrainingState(params=optix.apply_updates(state.params, updates),
                                 target_params=state.target_params,
                                 opt_state=new_opt_state,
                                 step=state.step + 1)

        # Draw all online-network keys first, then all target-network keys.
        key_seq = hk.PRNGSequence(seed)
        dummy_obs = np.zeros((1, *obs_spec.shape), jnp.float32)
        online_inits = [pure_net.init(next(key_seq), dummy_obs)
                        for _ in range(num_ensemble)]
        target_inits = [pure_net.init(next(key_seq), dummy_obs)
                        for _ in range(num_ensemble)]
        opt_inits = [optimizer.init(p) for p in online_inits]

        # Internalize state.
        self._ensemble = [
            TrainingState(params, target, opt, step=0)
            for params, target, opt in zip(online_inits, target_inits, opt_inits)
        ]
        self._forward = jax.jit(pure_net.apply)
        self._sgd_step = _sgd_step
        self._num_ensemble = num_ensemble
        self._optimizer = optimizer
        self._replay = replay.Replay(capacity=replay_capacity)

        # Agent hyperparameters.
        self._num_actions = action_spec.num_values
        self._batch_size = batch_size
        self._sgd_period = sgd_period
        self._target_update_period = target_update_period
        self._min_replay_size = min_replay_size
        self._epsilon_fn = epsilon_fn
        self._mask_prob = mask_prob

        # Agent state: member 0 is active first; no steps taken yet.
        self._active_head = self._ensemble[0]
        self._total_steps = 0
Example #16
0
  g0 = np.where(hps['do_tanh_latents'], np.tanh(g0), g0)
  ii_txi = np.where(hps['do_tanh_latents'], np.tanh(ii_txi), ii_txi)  
  ib = np.where(hps['do_tanh_latents'], np.tanh(ib), ib)
  
  prior_sample = {'g0' : g0, 'ib' : ib, 'ii_t' : ii_txi}
  decodes = decode_prior(params, hps, keys[1], prior_sample)

  return {'factor_t' : decodes['f'],
          'g0' : g0, 'gen_t' : decodes['g'],
          'ib' : decodes['ib'], 
          'ib_t' : decodes['ib'],
          'ii_t' : ii_txi,           
          'lograte_t' : decodes['lograte']}


encode_jit = jit(encode)
# Argument 1 (hps) is static, so the jitted function is specialized per hps.
decode_jit = jit(decode, static_argnums=(1,))
forward_pass_jit = jit(forward_pass, static_argnums=(1,))


# Batching accomplished by vectorized mapping. We simultaneously map over random
# keys for forward-pass randomness and inputs for batching.
batch_forward_pass = vmap(forward_pass, in_axes=(None, None, 0, 0, None))


def batch_forward_pass_prior(params, hps, keys):
    """Run ``forward_pass_from_prior`` once per key, mapping over axis 0 of `keys`.

    `params` and `hps` are captured by a closure, so only `keys` goes through
    vmap. Per the original author this works around vmap complaining about
    decompose_latent when they were mapped arguments — TODO confirm the exact
    cause.
    """
    return vmap(lambda key: forward_pass_from_prior(params, hps, key),
                in_axes=(0,))(keys)


def losses(params, hps, key, x_bxt, kl_scale, keep_rate):
  """Compute the training loss of the LFADS autoencoder

  Arguments:
Example #17
0
 def testScalarCastInsideJitWorks(self):
     # jnp.int32(tracer) should work.
     self.assertEqual(jnp.int32(101),
                      jax.jit(lambda x: jnp.int32(x))(jnp.float32(101.4)))
    def initialize(self,
                   p=3,
                   q=3,
                   n=1,
                   d=2,
                   noise_list=None,
                   c=0,
                   noise_magnitude=0.1,
                   noise_distribution='normal'):
        """
        Description: Randomly initialize the hidden dynamics of the system.
        Args:
            p (int/numpy.ndarray/list): Autoregressive dynamics. If type int then randomly
                initializes a Gaussian length-p vector rescaled to have L1-norm 0.99.
                If p is a 1-dimensional array then uses it as dynamics vector. A list
                of 1-dimensional arrays gives time-varying dynamics, indexed by self.T.
            q (int/numpy.ndarray/list): Moving-average dynamics. If type int then randomly
                initializes a Gaussian length-q vector (no bound on norm). If q is a
                1-dimensional array then uses it as dynamics vector. A list of
                1-dimensional arrays gives time-varying dynamics, indexed by self.T.
            n (int): Dimension of values.
            d (int): Differencing order. For d > 1, a (d - 1, n) history of
                differences is kept and integrated at each step.
            noise_list (sequence or None): If given, its first q entries seed the
                moving-average noise history and the list is kept on self.noise_list.
            c (float/array or None): Constant term. The ARMA dynamics follow
                x_t = c + AR-part + MA-part + noise, so the series tends to be
                centered around c. If None, c is sampled from a standard normal
                with shape (n,). Defaults to 0.
            noise_magnitude (float): Scale of the randomly generated initial noise.
            noise_distribution (str): 'normal' or 'unif'; ignored if noise_list is given.
        Returns:
            The first value in the time-series: self.x[0] plus the sum of the
            difference history (if any).
        Raises:
            ValueError: if noise_list is None and noise_distribution is neither
                'normal' nor 'unif'.
        """
        # Validate before mutating any state, so a bad call does not leave the
        # object half-initialized (previously self.noise was silently left unset).
        if noise_list is None and noise_distribution not in ('normal', 'unif'):
            raise ValueError(
                "noise_distribution must be 'normal' or 'unif', got {!r}".format(
                    noise_distribution))

        self.initialized = True
        self.T = 0
        self.max_T = -1
        self.n = n
        self.d = d
        if isinstance(p, int):
            phi = random.normal(generate_key(), shape=(p, ))
            self.phi = 0.99 * phi / np.linalg.norm(phi, ord=1)
        else:
            self.phi = p
        if isinstance(q, int):
            self.psi = random.normal(generate_key(), shape=(q, ))
        else:
            self.psi = q
        # For time-varying (list) dynamics the order is read from the first entry.
        if isinstance(self.phi, list):
            self.p = self.phi[0].shape[0]
        else:
            self.p = self.phi.shape[0]
        if isinstance(self.psi, list):
            self.q = self.psi[0].shape[0]
        else:
            self.q = self.psi.shape[0]
        self.noise_magnitude, self.noise_distribution = noise_magnitude, noise_distribution
        # `is None` (not `== None`): an array-valued c would make `==` elementwise.
        self.c = random.normal(generate_key(),
                               shape=(self.n, )) if c is None else c
        self.x = random.normal(generate_key(), shape=(self.p, self.n))
        if self.d > 1:
            self.delta_i_x = random.normal(generate_key(),
                                           shape=(self.d - 1, self.n))
        else:
            self.delta_i_x = None

        self.noise_list = None
        if noise_list is not None:
            self.noise_list = noise_list
            self.noise = np.array(noise_list[0:self.q])
        elif noise_distribution == 'normal':
            self.noise = self.noise_magnitude * random.normal(
                generate_key(), shape=(self.q, self.n))
        else:  # 'unif' (validated above)
            self.noise = self.noise_magnitude * random.uniform(
                generate_key(), shape=(self.q, self.n), minval=-1., maxval=1.)

        self.feedback = 0.0

        def _step(x, delta_i_x, noise, eps):
            # NOTE(review): self.T is read while tracing; under jax.jit the traced
            # result is cached, so list-valued (time-varying) phi/psi may keep
            # using the coefficients from the first traced step — confirm intended.
            if isinstance(self.phi, list):
                x_ar = np.dot(x.T, self.phi[self.T])
            else:
                x_ar = np.dot(x.T, self.phi)

            if isinstance(self.psi, list):
                x_ma = np.dot(noise.T, self.psi[self.T])
            else:
                x_ma = np.dot(noise.T, self.psi)
            if delta_i_x is not None:
                x_delta_sum = np.sum(delta_i_x)
            else:
                x_delta_sum = 0.0
            x_delta_new = self.c + x_ar + x_ma + eps
            x_new = x_delta_new + x_delta_sum

            # Rolling the flattened (rows, n) history by n shifts it down by one
            # row; the newest entry is then written into row 0.
            next_x = np.roll(x, self.n).at[0].set(x_delta_new)  # x keeps the differenced value
            next_noise = np.roll(noise, self.n).at[0].set(eps)
            next_delta_i_x = None
            for i in range(d - 1):
                if i == 0:
                    next_delta_i_x = delta_i_x.at[i].set(
                        x_delta_new + delta_i_x[i])
                else:
                    # Build on the already-updated array; starting again from
                    # delta_i_x would discard the updates to rows < i.
                    next_delta_i_x = next_delta_i_x.at[i].set(
                        next_delta_i_x[i - 1] + next_delta_i_x[i])

            return (next_x, next_delta_i_x, next_noise, x_new)

        self._step = jax.jit(_step)
        if self.delta_i_x is not None:
            x_delta_sum = np.sum(self.delta_i_x)
        else:
            x_delta_sum = 0
        return self.x[0] + x_delta_sum
Example #19
0
    x = jnp.tanh(jnp.matmul(x, w[0]))
    x = sigmoid(jnp.matmul(x, w[1]))

    return x


@jit
def get_loss(x, w, y_tgts):
    """Score the network output for inputs ``x`` and weights ``w`` with ``ce_loss``.

    ``ce_loss`` is called as ``ce_loss(targets, predictions)``; presumably a
    cross-entropy loss — confirm against its definition.
    """
    return ce_loss(y_tgts, forward(x, w))


# Gradient with respect to the weights only (positional argument 1), plus a
# compiled version of the same gradient function.
get_grad = grad(get_loss, argnums=1)
jit_grad = jit(get_grad)

if __name__ == "__main__":

    # Synthetic problem: 1024 samples of 128 features with 0/1 targets.
    x = np.random.randn(1024, 128)
    y_tgts = np.random.randint(2, size=(1024, 1))

    # Weights for the two layers used by `forward` (w[0]: 128x128, w[1]: 128x1),
    # initialized at a small 1e-2 scale.
    w0 = 1e-2 * np.random.randn(128, 128)
    w1 = 1e-2 * np.random.randn(128, 1)
    w = [w0, w1]

    t0 = time.time()
    # NOTE(review): this loop calls the un-jitted `get_grad` although a compiled
    # `jit_grad` exists at module level; if the aim is to time the compiled
    # path, call `jit_grad` here instead.
    for ii in range(10000):

        my_grad = get_grad(x, w, y_tgts)
Example #20
0
 def test_nested_jit(self):
     # A jitted function that itself calls a jitted function must convert to
     # TF and produce the same value as the JAX original.
     inner_cos = jax.jit(jnp.cos)
     f_jax = jax.jit(lambda x: jnp.sin(inner_cos(x)))
     f_tf = jax_to_tf.convert(f_jax)
     arg = 0.7
     np.testing.assert_allclose(f_jax(arg), f_tf(arg))
## define jax friendly function for simulating the system during mpc
def simulate(xt, u, w, theta):
    a = theta['a']
    b = theta['b']
    [o, M, N] = w.shape
    x = jnp.zeros((o, M, N + 1))
    x = index_update(x, index[:, :, 0], xt)
    for k in range(N):
        x = index_update(x, index[:, :, k + 1],
                         a * x[:, :, k] + b * u[:, k] + w[:, :, k])
    return x[:, :, 1:]


# define MPC cost, gradient and hessian function
# Argument positions of log_barrier_cost that are compile-time constants; a
# change in any of them (e.g. the horizon N, per the original note) recompiles.
_STATIC_ARGNUMS = (11, 12, 13, 14, 15)

# Compiled MPC cost, its gradient and its Hessian, all taken with respect to
# the first argument z (uc, s).
cost = jit(log_barrier_cost, static_argnums=_STATIC_ARGNUMS)
gradient = jit(grad(log_barrier_cost, argnums=0),
               static_argnums=_STATIC_ARGNUMS)
hessian = jit(jacfwd(jacrev(log_barrier_cost, argnums=0)),
              static_argnums=_STATIC_ARGNUMS)

# Per-time-step logs of the state estimate and the a, b, q, r estimates, each
# an independent zero-filled buffer.
xt_est_save = np.zeros((1, M, T))
a_est_save, b_est_save, q_est_save, r_est_save = (
    np.zeros((M, T)) for _ in range(4))
mpc_result_save = []
### SIMULATE SYSTEM AND PERFORM MPC CONTROL
for t in tqdm(range(T),
Example #22
0
 def test_gather(self):
     values = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
     indices = np.array([0, 1], dtype=np.int32)
     for axis in (0, 1):
         f_jax = jax.jit(lambda v, i: jnp.take(v, i, axis=axis))  # pylint: disable=cell-var-from-loop
         self.ConvertAndCompare(f_jax, values, indices, with_function=True)
    # As this is tutorial code, we're passing everything around.
    return {
        'xenc_t': xenc_t,
        'ic_mean': ic_mean,
        'ic_logvar': ic_logvar,
        'ii_t': ii_t,
        'c_t': c_t,
        'ii_mean_t': ii_mean_t,
        'ii_logvar_t': ii_logvar_t,
        'gen_t': gen_t,
        'factor_t': factor_t,
        'lograte_t': lograte_t
    }


# Compiled encoder/decoder; the decoder treats argument 1 (the hyperparameters)
# as static.
lfads_encode_jit = jit(lfads_encode)
lfads_decode_jit = jit(lfads_decode, static_argnums=(1, ))

# Batch via vmap: the third and fourth arguments (random keys for forward-pass
# randomness, and the inputs) are mapped along axis 0; the rest are shared.
_LFADS_BATCH_AXES = (None, None, 0, 0, None)
batch_lfads = vmap(lfads, in_axes=_LFADS_BATCH_AXES)


def lfads_losses(params, lfads_hps, key, x_bxt, kl_scale, keep_rate):
    """Compute the training loss of the LFADS autoencoder

  Arguments:
    params: a dictionary of LFADS parameters
    lfads_hps: a dictionary of LFADS hyperparameters
    key: random.PRNGKey for random bits
Example #24
0
 def test_gather_rank_change(self):
     params = jnp.array([[1.0, 1.5, 2.0], [2.0, 2.5, 3.0], [3.0, 3.5, 4.0]])
     indices = jnp.array([[1, 1, 2], [0, 1, 0]])
     f_jax = jax.jit(lambda i: params[i])
     self.ConvertAndCompare(f_jax, indices, with_function=True)
Example #25
0
    def testInfeed(self):
        """Check lax.infeed under jit (single device) and pjit (sharded).

        Each infeed declares a 1-tuple shape ``(ShapedArray,)``, so every
        host-side ``transfer_to_infeed`` call sends a 1-tuple to match.
        """
        devices = np.array(jax.local_devices())
        nr_devices = len(devices)
        shape = (nr_devices * 3, nr_devices * 5)

        def f_for_jit(x):
            token = lax.create_token(x)
            (y, ), token = lax.infeed(token,
                                      shape=(jax.ShapedArray(
                                          x.shape, np.float32), ))
            (z, ), token = lax.infeed(token,
                                      shape=(jax.ShapedArray(
                                          x.shape, np.float32), ))
            (w, ), token = lax.infeed(token,
                                      shape=(jax.ShapedArray(
                                          x.shape, np.float32), ))

            return x + y + z + w

        x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
        y = x * 2.
        z = x * 3.
        w = x * 4.

        # Transfer data to infeed before executing the function. For GPUs, the
        # execution of the compiled function is blocking, so transferring data
        # to infeed before executing ensures that the execution does not deadlock
        # waiting for the infeed data.
        logging.info('Transfering to infeed for the jit call')
        d = devices[0]
        d.transfer_to_infeed((y, ))
        d.transfer_to_infeed((z, ))
        d.transfer_to_infeed((w, ))

        # JIT
        logging.info('Making jit call')
        res0 = jax.jit(f_for_jit)(x)
        self.assertAllClose(res0, x + y + z + w, check_dtypes=True)

        # PJIT
        def f_for_pjit(x):
            token = lax.create_token(x)
            # A replicated infeed
            (y, ), token = lax.infeed(token,
                                      shape=(jax.ShapedArray(
                                          x.shape, np.float32), ),
                                      partitions=(None, ))
            # An infeed sharded on first axis
            (z, ), token = lax.infeed(token,
                                      shape=(jax.ShapedArray(
                                          x.shape, np.float32), ),
                                      partitions=(P(nr_devices, 1), ))
            # An infeed sharded on second axis
            (w, ), token = lax.infeed(token,
                                      shape=(jax.ShapedArray(
                                          x.shape, np.float32), ),
                                      partitions=(P(1, nr_devices), ))
            return x + y + z + w

        logging.info('Transfering to infeed for the pjit call')
        for didx, d in enumerate(devices):
            # Transfer the whole array to all devices for replicated.
            d.transfer_to_infeed((y, ))
            # For sharded infeed, transfer only the needed slices to each device.
            # The trailing comma makes this a 1-tuple, matching the declared
            # infeed shape and the other two transfers.
            d.transfer_to_infeed((z[3 * didx:3 * didx + 3, :], ))
            d.transfer_to_infeed((w[:, 5 * didx:5 * didx + 5], ))

        with mesh(devices, ['d']):
            logging.info('Making pjit call')
            res = pjit(f_for_pjit,
                       in_axis_resources=(P('d'), ),
                       out_axis_resources=P('d'))(x)

        self.assertAllClose(res0, res, check_dtypes=True)
Example #26
0
 def test_large_device_constant(self):
     ans = jit(lambda x: 2 * x)(np.ones(int(2e6)))  # doesn't crash
     self.assertAllClose(ans, onp.ones(int(2e6)) * 2., check_dtypes=False)
Example #27
0
    def function_type2(self):
        """Build (once) and return a jitted type-2 Q-function.

        The result is cached on ``self._function_type2``. The returned callable
        has signature ``func(params, state, rng, S, is_training)`` and returns
        ``(Q_s, new_state)``. ``is_training`` (position 4) is a static jit argument.

        Q is composed from the sub-models ``self.p`` (next state), ``self.r``
        (reward) and ``self.v`` (state value) as
        ``Q = f(R + params['gamma'] * f_inv(V(S_next)))``, where
        ``(f, f_inv) = self.value_transform``. ``Q_s`` has shape
        ``(batch, self.action_space.n)`` (asserted below).
        """
        if not hasattr(self, '_function_type2'):

            def func(params, state, rng, S, is_training):
                # NOTE: the order of next(rngs) calls below fixes which random
                # key each sub-model receives; do not reorder.
                rngs = hk.PRNGSequence(rng)
                new_state = dict(state)

                # s' ~ p(s'|s,.)  # note: S_next is replicated, one for each (discrete) action
                if is_stochastic(self.p):
                    dist_params_rep, new_state['p'] = self.p.function_type2(
                        params['p'], state['p'], next(rngs), S, is_training)
                    dist_params_rep = jax.tree_util.tree_map(
                        self._reshape_to_replicas, dist_params_rep)
                    S_next_rep = self.p.proba_dist.mean(dist_params_rep)
                else:
                    S_next_rep, new_state['p'] = self.p.function_type2(
                        params['p'], state['p'], next(rngs), S, is_training)
                    S_next_rep = jax.tree_util.tree_map(
                        self._reshape_to_replicas, S_next_rep)

                # r ~ p(r|s,a)  # note: R is replicated, one for each (discrete) action
                if is_stochastic(self.r):
                    dist_params_rep, new_state['r'] = self.r.function_type2(
                        params['r'], state['r'], next(rngs), S, is_training)
                    dist_params_rep = jax.tree_util.tree_map(
                        self._reshape_to_replicas, dist_params_rep)
                    R_rep = self.r.proba_dist.mean(dist_params_rep)
                    R_rep = self.r.proba_dist.postprocess_variate(
                        next(rngs), R_rep, batch_mode=True)
                else:
                    R_rep, new_state['r'] = self.r.function_type2(
                        params['r'], state['r'], next(rngs), S, is_training)
                    R_rep = jax.tree_util.tree_map(
                        self._reshape_to_replicas, R_rep)

                # v(s')  # note: since the input S_next is replicated, so is the output V
                if is_stochastic(self.v):
                    dist_params_rep, new_state['v'] = self.v.function(
                        params['v'], state['v'], next(rngs), S_next_rep,
                        is_training)
                    V_rep = self.v.proba_dist.mean(dist_params_rep)
                    V_rep = self.v.proba_dist.postprocess_variate(
                        next(rngs), V_rep, batch_mode=True)
                else:
                    V_rep, new_state['v'] = self.v.function(
                        params['v'], state['v'], next(rngs), S_next_rep,
                        is_training)

                # q = r + γ v(s')
                f, f_inv = self.value_transform
                Q_rep = f(R_rep + params['gamma'] * f_inv(V_rep))

                # reshape from (batch x num_actions, *) to (batch, num_actions, *)
                Q_s = self._reshape_from_replicas(Q_rep)
                assert Q_s.ndim == 2, f"bad shape: {Q_s.shape}"
                assert Q_s.shape[
                    1] == self.action_space.n, f"bad shape: {Q_s.shape}"

                # The returned state must keep the same pytree structure as the
                # input state so it can be fed back in on the next call.
                new_state = hk.data_structures.to_immutable_dict(new_state)
                assert jax.tree_util.tree_structure(
                    new_state) == jax.tree_util.tree_structure(state)

                return Q_s, new_state

            self._function_type2 = jax.jit(func, static_argnums=(4, ))

        return self._function_type2
Example #28
0
def test_xlog1py(x, y, jit_fn):
    fn = xlog1py if not jit_fn else jit(xlog1py)
    assert_allclose(fn(x, y), osp_special.xlog1py(x, y))
Example #29
0
    y = jnp.concatenate([samples.T, jnp.ones(shape=(1, N))], axis=0)
    eta = jnp.append(v, 0)

    # compute logq
    data_part = jnp.einsum('ik,jkh,hi->ij', y.T, jnp.linalg.inv(S), y)
    logdetS = jnp.linalg.slogdet(S)[1]
    log_q = -0.5 * (data_part + logdetS)

    # probability of belonging to each cluster
    alpha = jnp.exp(eta)
    alpha = alpha / jnp.sum(alpha)

    return jnp.sum(logsumexp(jnp.log(alpha) + log_q, axis=1))


def _neg_cost(x):
    """Negated ``costfunction``: minimizing this maximizes the original objective."""
    return -costfunction(x)


cost = jit(_neg_cost)
gr_cost = jit(grad(cost))

# Objective and gradient at the true parameters.
true_point = [true_S, true_eta]
f_tru = cost(true_point)
g_tru = gr_cost(true_point)
print(f_tru)

# Objective and gradient at the empirical estimates.
emp_point = [emp_S, emp_eta]
f_emp = cost(emp_point)
g_emp = gr_cost(emp_point)
print(f_emp)

man = Product([SPD(D + 1, M), Euclidean(M - 1)])

njobs = 10
# NOTE(review): `key` from this first split is immediately overwritten by the
# next line; the split is kept only because removing it would change the
# downstream random stream.
rng, key = random.split(rng)
rng, *key = random.split(rng, njobs + 1)
Example #30
0
import jax.numpy as np
from jax import grad, jit, vmap
from functools import partial

def predict(params, inputs):
  """Forward pass of a tanh MLP.

  Args:
    params: non-empty sequence of ``(W, b)`` pairs, one per layer.
    inputs: array whose last axis matches the first layer's ``W.shape[0]``
      (a single example or a batch).

  Returns:
    The final layer's pre-activation ``inputs @ W + b``; tanh is applied only
    between layers, not to the output.

  Raises:
    ValueError: if ``params`` is empty (previously an UnboundLocalError).
  """
  if not params:
    raise ValueError("predict requires at least one (W, b) layer")
  for W, b in params:
    outputs = np.dot(inputs, W) + b
    inputs = np.tanh(outputs)
  return outputs


def logprob_fun(params, inputs, targets):
  """Sum of squared errors between ``predict(params, inputs)`` and ``targets``.

  Despite the name, this is a squared-error loss, not a log-probability.
  """
  preds = predict(params, inputs)
  return np.sum((preds - targets)**2)


grad_fun = jit(grad(logprob_fun))  # compiled gradient w.r.t. params
# Fast per-example gradients: map grad_fun over the leading axis of inputs and
# targets while sharing params. (The old `vmap(f, inputs, targets)` passed the
# data as vmap's in_axes/out_axes instead of calling the mapped function.)
perex_grads = jit(lambda params, inputs, targets:
                  vmap(partial(grad_fun, params))(inputs, targets))
 def test_named_call_partial_function(self):
     # named_call must compose with functools.partial under jit: binding x=True
     # selects the `y` branch, so calling with 5 returns 5.
     named = stateful.named_call(lambda x, y: y if x else None)
     bound = jax.jit(functools.partial(named, True))
     self.assertEqual(bound(5), 5)