Example #1
def fori_loop(lower, upper, body_fun, init_val):
    if _DISABLE_CONTROL_FLOW_PRIM:
        val = init_val
        for i in range(int(lower), int(upper)):
            val = body_fun(i, val)
        return val
    else:
        return lax.fori_loop(lower, upper, body_fun, init_val)
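For orientation, a minimal standalone use of the same API (a sketch assuming only jax is installed; for concrete bounds both branches of the wrapper above agree):

from jax import lax

# Sum the integers 0..9; identical to the plain Python fallback above.
total = lax.fori_loop(0, 10, lambda i, acc: acc + i, 0)
assert total == 45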
Example #2
    def epoch_train(opt_state, rng):
        def body_fn(i, val):
            loss_sum, opt_state, rng = val
            rng, batch = binarize(rng, train_fetch(i, train_idx)[0])
            loss, opt_state, rng = svi_update(i, rng, opt_state, (batch,), (batch,))
            loss_sum += loss
            return loss_sum, opt_state, rng

        return lax.fori_loop(0, num_train, body_fn, (0., opt_state, rng))
Example #3
    def sum_first_n(arr, num):
      def body_fun(i, state):
        arr, total = state['arr'], state['total']
        arr_i = lax.dynamic_index_in_dim(arr, i, 0, False)
        return {'arr': arr, 'total': lax.add(total, arr_i)}

      init_val = {'arr': arr, 'total': 0.}
      out_val = lax.fori_loop(0, lax.min(arr.shape[0], num), body_fun, init_val)
      return out_val['total']
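A hypothetical call to the function above (assuming jax.numpy is imported as jnp):

import jax.numpy as jnp

print(sum_first_n(jnp.arange(5.), 3))  # 0. + 1. + 2. = 3.0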
Example #4
    def sum_first_n(arr, num):
      def body_fun(i, state):
        arr, total, _ = state
        arr_i = lax.dynamic_index_in_dim(arr, i, 0, False)
        return (arr, lax.add(total, arr_i), ())

      init_val = (arr, 0., ())
      _, tot, _ = lax.fori_loop(0, lax.min(arr.shape[0], num), body_fun, init_val)
      return tot
Example #5
def simulate(t, state, dt, space, t_h, parameters, goal, N_step, dx):
    # for i in range(N_step):
    #     t, state = step(t, state, dt, space, t_h, parameters)
    body_fun = lambda i, val: step(*val, dt, space, t_h, parameters)
    t, state = fori_loop(0, N_step, body_fun, (t, state))
    occ = occupation(state, goal, dx)
    occ = (occ * np.conj(occ)).real
    print(occ)
    return occ
Example #6
def iterate_leapfrogs(theta, phi, eps, M, L, grad_fun):

    init_val = {"theta": theta, "phi": phi, "eps": eps, "M": M}

    to_iterate = lambda i, val: body_fun_leapfrog(i, val, grad_fun)

    final_res = fori_loop(0, L, to_iterate, init_val)

    return final_res["theta"], final_res["phi"]
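lax.fori_loop carries arbitrary pytrees, which is why the dict init_val above works; a minimal sketch of the same pattern:

from jax import lax

# The body must return the same dict structure it receives.
out = lax.fori_loop(0, 4, lambda i, d: {"s": d["s"] + i}, {"s": 0})
assert out["s"] == 6  # 0 + 1 + 2 + 3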
Example #7
def nloglik(X):
    y = jnp.concatenate([data.T, jnp.ones(shape=(1, N))], axis=0).T

    def body(i, ll):
        Si = jnp.outer(y[i], y[i])
        return ll + jnp.log(1 + jnp.trace(jnp.linalg.solve(X, Si)))

    llik = - (df + p) * 0.5 * fori_loop(0, N, body, 0.)
    return llik - 0.5 * N * jnp.linalg.slogdet(X)[1]
Example #8
    def eval_test(svi_state, batchifier_state, num_batch):
        def body_fn(i, loss_sum):
            batch = test_fetch(i, batchifier_state)
            loss = svi.evaluate(svi_state, *batch)
            loss_sum += loss / (args.num_samples * num_batch)

            return loss_sum

        return lax.fori_loop(0, num_batch, body_fn, 0.)
Example #9
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            svi_state, batch_loss = svi.update(svi_state, *batch)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))
Example #10
def make_distanceMatrix(points, idx, distance, n):
    distmatrix = jnp.zeros(shape=(n, n))

    def bodyfun(i, dists):
        j, k = idx[i]
        return dists.at[j, k].set(distance(points[j], points[k]))

    distmatrix = fori_loop(0, len(idx), bodyfun, distmatrix)
    return distmatrix
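A hypothetical usage of the function above, assuming jax.numpy as jnp and a Euclidean distance; idx lists the (row, column) pairs to fill:

import jax.numpy as jnp

points = jnp.array([[0., 0.], [3., 4.]])
idx = jnp.array([[0, 1], [1, 0]])
dist = lambda a, b: jnp.linalg.norm(a - b)
print(make_distanceMatrix(points, idx, dist, 2))  # 5.0 at [0, 1] and [1, 0]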
Example #11
def test_mnist_data_load():
    def mean_pixels(i, mean_pix):
        batch, _ = fetch(i, idx)
        return mean_pix + np.sum(batch) / batch.size

    init, fetch = load_dataset(MNIST, batch_size=128, split='train')
    num_batches, idx = init()
    assert lax.fori_loop(0, num_batches, mean_pixels,
                         np.float32(0.)) / num_batches < 0.15
Example #12
    def run_epoch(rng, opt_state):
        def body_fun(i, opt_state):
            elbo_rng, data_rng = random.split(random.fold_in(rng, i))
            batch = binarize_batch(data_rng, i, train_images)
            loss = lambda params: -elbo(elbo_rng, params, batch) / batch_size
            g = grad(loss)(get_params(opt_state))
            return opt_update(i, g, opt_state)

        return lax.fori_loop(0, num_batches, body_fun, opt_state)
Example #13
    def p_sample_loop(self,
                      model_fn,
                      *,
                      shape,
                      rng,
                      num_timesteps=None,
                      return_x_init=False):
        """Ancestral sampling."""
        init_rng, body_rng = jax.random.split(rng)
        del rng

        noise_shape = shape + (self.num_pixel_vals, )

        def body_fun(i, x):
            t = jnp.full([shape[0]], self.num_timesteps - 1 - i)
            x, _ = self.p_sample(model_fn=model_fn,
                                 x=x,
                                 t=t,
                                 noise=jax.random.uniform(jax.random.fold_in(
                                     body_rng, i),
                                                          shape=noise_shape))
            return x

        if self.transition_mat_type in ['gaussian', 'uniform']:
            # Stationary distribution is a uniform distribution over all pixel values.
            x_init = jax.random.randint(init_rng,
                                        shape=shape,
                                        minval=0,
                                        maxval=self.num_pixel_vals)

        elif self.transition_mat_type == 'absorbing':
            # Stationary distribution is a kronecker delta distribution
            # with all its mass on the absorbing state.
            # Absorbing state is located at rgb values (128, 128, 128)
            x_init = jnp.full(shape=shape,
                              fill_value=self.num_pixel_vals // 2,
                              dtype=jnp.int32)
        else:
            raise ValueError(
                f"transition_mat_type must be 'gaussian', 'uniform', 'absorbing' "
                f", but is {self.transition_mat_type}")

        del init_rng

        if num_timesteps is None:
            num_timesteps = self.num_timesteps

        final_x = lax.fori_loop(lower=0,
                                upper=num_timesteps,
                                body_fun=body_fun,
                                init_val=x_init)
        assert final_x.shape == shape
        if return_x_init:
            return x_init, final_x
        else:
            return final_x
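The per-step noise above comes from jax.random.fold_in, which deterministically derives a distinct key from body_rng and the loop index; a minimal sketch of that pattern:

import jax

rng = jax.random.PRNGKey(0)
# One reproducible subkey per iteration, without threading keys through the carry.
keys = [jax.random.fold_in(rng, i) for i in range(3)]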
Example #14
def _outer_fun(n, vals):

    vals["cur_total"] = jnp.zeros_like(vals["es"][0])
    vals["cur_n"] = n

    result = fori_loop(1, n + 1, _inner_fun, vals)

    vals["es"] = index_update(vals["es"], n, result["cur_total"] / n)

    return vals
Example #15
    def epoch_train(svi_state, rng_key):
        def body_fn(i, val):
            loss_sum, svi_state = val
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])
            svi_state, loss = svi.update(svi_state, batch)
            loss_sum += loss
            return loss_sum, svi_state

        return lax.fori_loop(0, num_train, body_fn, (0., svi_state))
Example #16
 def run_epoch(rng, opt_state):
   def body_fun(i, rng__opt_state__images):
     (rng, opt_state, images) = rng__opt_state__images
     rng, elbo_rng, data_rng = random.split(rng, 3)
     batch = binarize_batch(data_rng, i, images)
     loss = lambda params: -elbo(elbo_rng, params, batch) / batch_size
     g = grad(loss)(minmax.get_params(opt_state))
     return rng, opt_update(i, g, opt_state), images
   init_val = rng, opt_state, train_images
   _, opt_state, _ = lax.fori_loop(0, num_batches, body_fun, init_val)
   return opt_state
Example #17
def backward_pass(x_trj, u_trj, regu, target):
    k_trj = np.empty_like(u_trj)
    K_trj = np.empty((TIME_STEPS-1, N_U, N_X))
    expected_cost_redu = 0.
    V_x, V_xx = derivative_final(x_trj[-1], target)
     
    # The body must return the same 9-element structure it receives, even
    # though only V_x, V_xx, k_trj and K_trj change from step to step.
    V_x, V_xx, k_trj, K_trj, x_trj, u_trj, expected_cost_redu, regu, target = lax.fori_loop(
        0, TIME_STEPS-1, backward_pass_looper, [V_x, V_xx, k_trj, K_trj, x_trj, u_trj, expected_cost_redu, regu, target]
    )
        
    return k_trj, K_trj, expected_cost_redu
Example #18
    def _sample(self, key, n_samps, factors, bits_to_fix=-1, values_to_fix=-1):
        """generate samples from a distributions with a given set of factors
        
        Parameters
        ----------
        key : jax.random.PRNGKey
            jax random number generator 
        n_samps : int
            number of samples to generate
        factors : array_like
            factors of the distribution
        bits_to_fix : array_like, optional
            indices of bits to clamp during sampling (-1 disables clamping)
        values_to_fix : array_like, optional
            values assigned to the clamped bits
        
        Returns
        -------
        array_like
            samples from the model
        """
        state = random.randint(key, minval=0, maxval=2, shape=(self.N,))
        # Note: the same key is reused for the uniforms below; splitting it
        # first would be the more idiomatic JAX pattern.
        unifs = random.uniform(key, shape=(n_samps*self.N,))
        all_states = np.zeros((n_samps,self.N))

        if bits_to_fix != -1:
            condition = True
            bits_to_keep = np.array([x for x in range(self.N) if x not in bits_to_fix])
            N = bits_to_keep.size
            values_to_fix = np.array(values_to_fix)
        else:
            condition = False
            bits_to_keep = np.arange(self.N)  
            N = self.N
        # @jit
        # def run_mh(j, loop_carry):
        #     state, all_states = loop_carry
        #     all_states = index_update(all_states,j//self.N,state)  # a bit wasteful
        #     state_flipped = index_update(state,j%self.N,1-state[j%self.N])
        #     dE = self.calc_e(factors,state_flipped)-self.calc_e(factors,state)
        #     accept = ((dE < 0) | (unifs[j] < np.exp(-dE)))
        #     state = np.where(accept, state_flipped, state)
        #     return state, all_states
        
        @jit
        def run_mh(j, loop_carry):
            state, all_states = loop_carry
            if condition:
                state = index_update(state, bits_to_fix, values_to_fix)
            all_states = index_update(all_states,j//N,state)  # a bit wasteful
            state_flipped = index_update(state,bits_to_keep[j%N],1-state[bits_to_keep[j%N]])
            dE = self.calc_e(factors,state_flipped)-self.calc_e(factors,state)
            accept = ((dE < 0) | (unifs[j] < np.exp(-dE)))
            state = np.where(accept, state_flipped, state)
            return state, all_states

        # fori_loop returns the final (state, all_states) carry; keep the samples.
        _, all_states = fori_loop(0, n_samps * N, run_mh, (state, all_states))
        return all_states
Example #19
        def f_op_jax():
            arr = jnp.zeros(5)

            def loop_body(i, acc_arr):
                arr1 = acc_arr.at[i].set(acc_arr[i] + 2.)
                return lax.cond(i % 2 == 0,
                                lambda a: a.at[i].set(a[i] + 1.),
                                lambda a: a,
                                arr1)

            arr = lax.fori_loop(0, arr.shape[0], loop_body, arr)
            return arr
Example #20
 def f_op_jax():
   arr = jnp.zeros(5)
   def loop_body(i, acc_arr):
     # ops.index_update and the five-argument lax.cond used here are legacy
     # APIs, removed in newer JAX; Example #19 shows the .at[]-based form.
     arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.)
     return lax.cond(i % 2 == 0,
                     arr1,
                     lambda arr1: ops.index_update(arr1, i, arr1[i] + 1.),
                     arr1,
                     lambda arr1: arr1)
   arr = lax.fori_loop(0, arr.shape[0], loop_body, arr)
   return arr
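Both versions compute the same result; a quick check of the expected output (a sketch, assuming jax.numpy as jnp):

import jax.numpy as jnp

# Every index gets +2; even indices additionally get +1.
assert jnp.allclose(f_op_jax(), jnp.array([3., 2., 3., 2., 3.]))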
Example #21
    def eval_test(svi_state: SVIState, rng_key: np.ndarray) -> jnp.ndarray:
        def body_fun(i: jnp.ndarray, loss_sum: jnp.ndarray) -> jnp.ndarray:
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, test_fetch(i, test_idx)[0])
            loss = svi.evaluate(svi_state, batch) / len(batch)
            loss_sum += loss
            return loss_sum

        loss = lax.fori_loop(0, num_test, body_fun, 0.0)
        loss = loss / num_test
        return loss
Example #22
    def eval_test(opt_state, rng):
        def body_fun(i, val):
            loss_sum, rng = val
            rng, = random.split(rng, 1)
            rng, batch = binarize(rng, test_fetch(i, test_idx)[0])
            loss = svi_eval(rng, opt_state, (batch,), (batch,)) / len(batch)
            loss_sum += loss
            return loss_sum, rng

        loss, _ = lax.fori_loop(0, num_test, body_fun, (0., rng))
        loss = loss / num_test
        return loss
Example #23
def forward_pass(x_trj, u_trj, k_trj, K_trj):
    u_trj = np.arcsin(np.sin(u_trj))
    
    x_trj_new = np.empty_like(x_trj)
    # jax.ops.index_update is the legacy spelling of x_trj_new.at[0].set(x_trj[0]).
    x_trj_new = jax.ops.index_update(x_trj_new, jax.ops.index[0], x_trj[0])
    u_trj_new = np.empty_like(u_trj)
    
    x_trj, u_trj, k_trj, K_trj, x_trj_new, u_trj_new = lax.fori_loop(
        0, TIME_STEPS-1, forward_pass_looper, [x_trj, u_trj, k_trj, K_trj, x_trj_new, u_trj_new]
    )

    return x_trj_new, u_trj_new
Example #24
    def eval_test(svi_state, rng_key):
        def body_fun(i, loss_sum):
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, test_fetch(i, test_idx)[0])
            # FIXME: does this lead to a requirement for an rng_key arg in svi_eval?
            loss = svi.evaluate(svi_state, batch) / len(batch)
            loss_sum += loss
            return loss_sum

        loss = lax.fori_loop(0, num_test, body_fun, 0.)
        loss = loss / num_test
        return loss
Example #25
    def epoch_train(svi_state: SVIState, rng_key: np.ndarray) -> Tuple[jnp.ndarray, SVIState]:
        def body_fun(
            i: jnp.ndarray, val: Tuple[jnp.ndarray, SVIState]
        ) -> Tuple[jnp.ndarray, SVIState]:
            loss_sum, svi_state = val
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])
            svi_state, loss = svi.update(svi_state, batch)
            loss_sum += loss
            return loss_sum, svi_state

        return lax.fori_loop(0, num_train, body_fun, (0.0, svi_state))
Example #26
def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
    """Apply the Threefry 2x32 hash.

  Args:
    keypair: a pair of 32bit unsigned integers used for the key.
    count: an array of dtype uint32 used for the counts.

  Returns:
    An array of dtype uint32 with the same shape as `count`.
  """
    x = [x1, x2]

    rotations = [
        np.array([13, 15, 26, 6], dtype=np.uint32),
        np.array([17, 29, 16, 24], dtype=np.uint32)
    ]
    ks = [key1, key2, key1 ^ key2 ^ np.uint32(0x1BD11BDA)]

    x[0] = x[0] + ks[0]
    x[1] = x[1] + ks[1]

    if use_rolled_loops:
        x, _, _ = lax.fori_loop(0, 5, rolled_loop_step,
                                (x, rotate_list(ks), rotations))

    else:
        for r in rotations[0]:
            x = apply_round(x, r)
        x[0] = x[0] + ks[1]
        x[1] = x[1] + ks[2] + np.uint32(1)

        for r in rotations[1]:
            x = apply_round(x, r)
        x[0] = x[0] + ks[2]
        x[1] = x[1] + ks[0] + np.uint32(2)

        for r in rotations[0]:
            x = apply_round(x, r)
        x[0] = x[0] + ks[0]
        x[1] = x[1] + ks[1] + np.uint32(3)

        for r in rotations[1]:
            x = apply_round(x, r)
        x[0] = x[0] + ks[1]
        x[1] = x[1] + ks[2] + np.uint32(4)

        for r in rotations[0]:
            x = apply_round(x, r)
        x[0] = x[0] + ks[2]
        x[1] = x[1] + ks[0] + np.uint32(5)

    return tuple(x)
Example #27
def test_rans_lax_fori_loop():
    size = 3
    tail_capacity = 100
    precision = 3
    n_data = 100
    data = jnp.array(rng.integers(0, 4, size=(n_data, size)))

    # x ~ Categorical(1 / 8, 2 / 8, 3 / 8, 2 / 8)
    m = m_init = rans.base_message(size, tail_capacity)
    choose = partial(jnp.choose, mode='clip')
    def enc_fun(x):
        return (choose(x, jnp.array([0, 1, 3, 6])),
                choose(x, jnp.array([1, 2, 3, 2])))

    def dec_fun(cf):
        return choose(cf, jnp.array([0, 1, 1, 2, 2, 2, 3, 3]))

    codec_push, codec_pop = rans.NonUniform(enc_fun, dec_fun, precision)

    _, freqs = enc_fun(data)

    # Encode
    def push_body(i, carry):
        m = carry
        m = codec_push(m, data[n_data - i - 1])
        return m
    m = lax.fori_loop(0, n_data, push_body, m)
    coded_arr = rans.flatten(m)
    assert coded_arr.dtype == np.uint8

    # Decode
    def pop_body(i, carry):
        m, xs = carry
        m, x = codec_pop(m)
        return m, lax.dynamic_update_index_in_dim(xs, x, i, 0)
    m = rans.unflatten(coded_arr, size, tail_capacity)
    m, data_decoded = lax.fori_loop(0, n_data, pop_body,
                                    (m, jnp.zeros((n_data, size), 'int32')))
    assert rans.message_equal(m, m_init)
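Since the rows were pushed in reverse order and rANS pops last-in-first-out, decoding should recover the original data; a plausible final check (an assumption, not part of the source test):

assert jnp.all(data_decoded == data)  # pops come back in the original row order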
Example #28
def trace(state, fn, num_steps, **_):
    """Implementation of `trace` operator, without the calling convention."""
    # We need the shapes and dtypes of the outputs of `fn`.
    _, untraced_spec, traced_spec = jax.eval_shape(
        fn, map_tree(lambda s: jax.ShapeDtypeStruct(s.shape, s.dtype), state))
    untraced_init = map_tree(lambda spec: np.zeros(spec.shape, spec.dtype),
                             untraced_spec)

    try:
        num_steps = int(num_steps)
        use_scan = True
    except TypeError:
        use_scan = False
        if flatten_tree(traced_spec):
            raise ValueError(
                'Cannot trace values when `num_steps` is not statically known. Pass '
                'False to `trace_mask` or return an empty structure (e.g. `()`) as '
                'the extra output.')

    if use_scan:

        def wrapper(state_untraced, _):
            state, _ = state_untraced
            state, untraced, traced = fn(state)
            return (state, untraced), traced

        (state, untraced), traced = lax.scan(
            wrapper,
            (state, untraced_init),
            xs=None,
            length=num_steps,
        )
    else:
        trace_arrays = map_tree(
            lambda spec: np.zeros((num_steps, ) + spec.shape, spec.dtype),
            traced_spec)

        def wrapper(i, state_untraced_traced):
            state, _, trace_arrays = state_untraced_traced
            state, untraced, traced = fn(state)
            trace_arrays = map_tree(lambda a, e: jax.ops.index_update(a, i, e),
                                    trace_arrays, traced)
            return (state, untraced, trace_arrays)

        state, untraced, traced = lax.fori_loop(
            np.asarray(0, num_steps.dtype),
            num_steps,
            wrapper,
            (state, untraced_init, trace_arrays),
        )
    return state, untraced, traced
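jax.eval_shape, used above to size the trace buffers, runs fn abstractly and returns only shapes and dtypes; a minimal sketch:

import jax
import jax.numpy as jnp

# No actual computation runs; the result is a pytree of jax.ShapeDtypeStruct.
spec = jax.eval_shape(lambda x: (x, x.sum()),
                      jax.ShapeDtypeStruct((3,), jnp.float32))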
Example #29
    def sample(model, key):
        def body(i, loop_carry):
            key, config = loop_carry
            out = model(config)
            probs = prob(out)
            key, subkey = random.split(key)
            sample = random.bernoulli(subkey, probs[:, i, 1]) * 2 - 1.0
            sample = sample[..., jnp.newaxis]
            config = jax.ops.index_update(config, jax.ops.index[:, i], sample)
            return key, config

        key, config = fori_loop(0, init_config.shape[1], body,
                                (key, init_config))
        return key, config
Example #30
    def test_nve_neighbor_list(self, spatial_dimension, dtype):
        Nx = particles_per_side = 8
        spacing = f32(1.25)

        tol = 5e-12 if dtype == np.float64 else 5e-3

        L = Nx * spacing
        if spatial_dimension == 2:
            R = np.stack([np.array(r) for r in onp.ndindex(Nx, Nx)]) * spacing
        elif spatial_dimension == 3:
            R = np.stack([np.array(r)
                          for r in onp.ndindex(Nx, Nx, Nx)]) * spacing

        R = np.array(R, dtype)

        displacement, shift = space.periodic(L)

        neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
            displacement, L)
        exact_energy_fn = energy.lennard_jones_pair(displacement)

        init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)
        exact_init_fn, exact_apply_fn = simulate.nve(exact_energy_fn, shift,
                                                     1e-3)

        nbrs = neighbor_fn(R)
        state = init_fn(random.PRNGKey(0), R, neighbor=nbrs)
        exact_state = exact_init_fn(random.PRNGKey(0), R)

        def body_fn(i, state):
            state, nbrs, exact_state = state
            nbrs = neighbor_fn(state.position, nbrs)
            state = apply_fn(state, neighbor=nbrs)
            return state, nbrs, exact_apply_fn(exact_state)

        step = 0
        for i in range(20):
            new_state, nbrs, new_exact_state = lax.fori_loop(
                0, 100, body_fn, (state, nbrs, exact_state))
            if nbrs.did_buffer_overflow:
                nbrs = neighbor_fn(state.position)
            else:
                state = new_state
                exact_state = new_exact_state
                step += 1
        assert state.position.dtype == dtype
        self.assertAllClose(state.position,
                            exact_state.position,
                            atol=tol,
                            rtol=tol)