Example #1
    def test_broadcast(self):
        if jax.device_count() < 3:
            self.skipTest("test requires 3 devices")
        devices = self.get_devices()

        z = 1 + jnp.ones((2, 3))
        self.assert_uncommitted_to_device(z, devices[0])
        y = jax.device_put(1, devices[2]) + jnp.ones((2, 3))
        self.assert_committed_to_device(y, devices[2])
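A minimal standalone sketch (not from the original test suite) of the committed/uncommitted distinction the test above asserts: results of plain jnp operations stay uncommitted on the default device, while an explicit jax.device_put commits the array, and computations involving it follow that device. Assumes at least two devices.

import jax
import jax.numpy as jnp

x = jnp.ones((2, 3))                         # uncommitted: default device
if jax.device_count() >= 2:
    y = jax.device_put(x, jax.devices()[1])  # committed to the second device
    z = y + 1                                # result follows y's committed device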
Example #2
def jit_simple_many_args_dispatch(n, state):
    args = [jax.device_put(i) for i in range(n)]
    f = jax.jit(lambda xs: functools.reduce(operator.add, xs))
    x = f(args)
    x.block_until_ready()

    while state:
        x = f(args)
    x.block_until_ready()
Example #3
def jit_big_matmul(state):
    x = np.random.uniform(size=(100, 100)).astype(np.float32)
    x = jax.device_put(x)

    f = jax.jit(lambda x: jnp.dot(x, x))
    f(x).block_until_ready()

    while state:
        f(x).block_until_ready()
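The `while state:` loop in the dispatch benchmarks above matches the google_benchmark idiom. A minimal sketch, assuming the google_benchmark Python package, of registering such a benchmark; the first call is a warm-up so compilation stays outside the timed loop.

import google_benchmark as benchmark
import jax
import jax.numpy as jnp
import numpy as np

@benchmark.register
def jit_matmul(state):
    x = jax.device_put(np.random.uniform(size=(100, 100)).astype(np.float32))
    f = jax.jit(lambda x: jnp.dot(x, x))
    f(x).block_until_ready()  # warm up: compile before the timed loop
    while state:
        f(x).block_until_ready()

if __name__ == "__main__":
    benchmark.main()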
Example #4
 def _update(i, af_state, af_inp):
     if trainable:
         af_inp = af_inp if isinstance(af_inp, tuple) else (af_inp, )
         af_inp = (af_inp + (0., ))[:2]
     else:
         af_inp = af_inp[0] if isinstance(af_inp, tuple) else af_inp
     af_inp = jax.device_put(af_inp)
     af_state, af_out = stop_gradient(update(i, af_state, af_inp))
     return af_state, af_out
Example #5
def jit_simple_pruned_args_dispatch(n, state):
    args = [jax.device_put(i) for i in range(n)]
    f = jax.jit(lambda *xs: xs[0] + 1)
    x = f(*args)
    x.block_until_ready()

    while state:
        x = f(*args)
    x.block_until_ready()
Example #6
def dbp_direct(y, H, c):
    y = device_put(y)
    H = device_put(H)
    c = device_put(c)

    steps = c.shape[0]

    fft = lambda x: jnp.fft.fft(x, axis=0)
    ifft = lambda x: jnp.fft.ifft(x, axis=0)

    D = jit(lambda y,H: ifft(fft(y) * H))
    N = jit(lambda y,c: y * jnp.exp(1j * (abs(y)**2 @ c)))

    for i in range(steps):
        y = D(y, H[i])
        y = N(y, c[i])

    return y
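Hypothetical inputs for dbp_direct (the shapes below are illustrative assumptions, not from the source): y is the sampled two-channel waveform, H holds one frequency-domain linear operator per step, and c one nonlinear coupling matrix per step.

import numpy as np

steps, n = 3, 256
y0 = (np.random.randn(n, 2) + 1j * np.random.randn(n, 2)).astype(np.complex64)
H0 = np.ones((steps, n, 1), np.complex64)    # broadcasts over both channels
c0 = 1e-3 * np.ones((steps, 2, 2), np.float32)
out = dbp_direct(y0, H0, c0)                 # shape (n, 2)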
Example #7
 def unroll_and_push(self, frame_count: int, params: hk.Params):
     """Run one unroll and send trajectory to learner."""
     params = jax.device_put(params)
     self._rng_key, subkey = jax.random.split(self._rng_key)
     act_out = self.unroll(rng_key=subkey,
                           frame_count=frame_count,
                           params=params,
                           unroll_length=self._unroll_length)
     self._learner.enqueue_traj(act_out)
Example #8
 def cb(cb_inp):
   self.assertLen(cb_inp, 4)
   dbs = []
   for inp in cb_inp:
     index, devices = inp
     self.assertLen(devices, 2)
     array = global_input_data[index]
     dbs.extend([jax.device_put(array, device) for device in devices])
   return dbs
Example #9
def scan_wrapper(
    f,
    init,
    xs,
    length,
    reverse,
    rng_key=None,
    substitute_stack=[],
    enum=False,
    history=1,
    first_available_dim=None,
):
    if length is None:
        length = tree_flatten(xs)[0][0].shape[0]

    if enum and history > 0:
        return scan_enum(
            f,
            init,
            xs,
            length,
            reverse,
            rng_key,
            substitute_stack,
            history,
            first_available_dim,
        )

    def body_fn(wrapped_carry, x):
        i, rng_key, carry = wrapped_carry
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (
            None, None)

        with handlers.block():

            # we need to tell the unconstrained messenger in the potential-energy
            # computation that only the item at time `i` is needed when transforming
            fn = handlers.infer_config(
                f, config_fn=lambda msg: {"_scan_current_index": i})

            seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn
            for subs_type, subs_map in substitute_stack:
                subs_fn = partial(_subs_wrapper, subs_map, i, length)
                if subs_type == "condition":
                    seeded_fn = handlers.condition(seeded_fn,
                                                   condition_fn=subs_fn)
                elif subs_type == "substitute":
                    seeded_fn = handlers.substitute(seeded_fn,
                                                    substitute_fn=subs_fn)

            with handlers.trace() as trace:
                carry, y = seeded_fn(carry, x)

        return (i + 1, rng_key, carry), (PytreeTrace(trace), y)

    wrapped_carry = device_put((0, rng_key, init))
    return lax.scan(body_fn, wrapped_carry, xs, length=length, reverse=reverse)
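Because jax.device_put maps over arbitrary pytrees, the wrapper above can transfer the whole (index, rng_key, carry) tuple in a single call. A minimal standalone sketch of the same pattern:

import jax
from jax import lax, random
import jax.numpy as jnp

init = {"w": jnp.zeros(3)}
carry = jax.device_put((0, random.PRNGKey(0), init))  # one transfer for the tuple

def body(c, x):
    i, key, state = c
    return (i + 1, key, state), x

final_carry, ys = lax.scan(body, carry, jnp.arange(4.0))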
Example #10
 def rvs(self, *args, **kwargs):
     rng = kwargs.pop('random_state')
     if rng is None:
         rng = self.random_state
      # assert that rng is a PRNGKey and not an mtrand.RandomState object from numpy.
     assert _is_prng_key(rng)
     kwargs['random_state'] = onp.random.RandomState(rng)
     sample = super(jax_discrete, self).rvs(*args, **kwargs)
     return device_put(sample)
Example #11
def main(argv):
    del argv
    rng = random.PRNGKey(0)
    rng, key = random.split(rng)

    train_ds = tfds.load('binarized_mnist', split=tfds.Split.TRAIN)
    train_ds = train_ds.map(prepare_image)
    train_ds = train_ds.cache()
    train_ds = train_ds.repeat()
    train_ds = train_ds.shuffle(50000)
    train_ds = train_ds.batch(FLAGS.batch_size)
    train_ds = tfds.as_numpy(train_ds)

    test_ds = tfds.load('binarized_mnist', split=tfds.Split.TEST)
    test_ds = test_ds.map(prepare_image).batch(10000)
    test_ds = np.array(list(test_ds)[0])
    test_ds = jax.device_put(test_ds)

    model_def = VAE.partial(latents=FLAGS.latents)
    _, params = model_def.init_by_shape(key, [(FLAGS.batch_size, 784)],
                                        z_rng=random.PRNGKey(0))
    vae = nn.Model(model_def, params)

    optimizer = optim.Adam(learning_rate=FLAGS.learning_rate).create(vae)
    optimizer = jax.device_put(optimizer)

    rng, z_key, eval_rng = random.split(rng, 3)
    z = random.normal(z_key, (64, FLAGS.latents))

    steps_per_epoch = 50000 // FLAGS.batch_size

    for epoch in range(FLAGS.num_epochs):
        for _ in range(steps_per_epoch):
            batch = next(train_ds)
            rng, key = random.split(rng)
            optimizer = train_step(optimizer, batch, key)

        metrics, comparison, sample = eval(optimizer.target, test_ds, z,
                                           eval_rng)
        save_image(comparison, f'results/reconstruction_{epoch}.png', nrow=8)
        save_image(sample, f'results/sample_{epoch}.png', nrow=8)

        print('eval epoch: {}, loss: {:.4f}, BCE: {:.4f}, KLD: {:.4f}'.format(
            epoch + 1, metrics['loss'], metrics['bce'], metrics['kld']))
Example #12
def main(argv):
    del argv
    rng = random.PRNGKey(0)
    rng, key = random.split(rng)

    ds_builder = tfds.builder('binarized_mnist')
    ds_builder.download_and_prepare()
    train_ds = ds_builder.as_dataset(split=tfds.Split.TRAIN)
    train_ds = train_ds.map(prepare_image)
    train_ds = train_ds.cache()
    train_ds = train_ds.repeat()
    train_ds = train_ds.shuffle(50000)
    train_ds = train_ds.batch(FLAGS.batch_size)
    train_ds = iter(tfds.as_numpy(train_ds))

    test_ds = ds_builder.as_dataset(split=tfds.Split.TEST)
    test_ds = test_ds.map(prepare_image).batch(10000)
    test_ds = np.array(list(test_ds)[0])
    test_ds = jax.device_put(test_ds)

    init_data = jnp.ones((FLAGS.batch_size, 784), jnp.float32)
    params = model().init(key, init_data, rng)['param']

    optimizer = optim.Adam(learning_rate=FLAGS.learning_rate).create(params)
    optimizer = jax.device_put(optimizer)

    rng, z_key, eval_rng = random.split(rng, 3)
    z = random.normal(z_key, (64, FLAGS.latents))

    steps_per_epoch = 50000 // FLAGS.batch_size

    for epoch in range(FLAGS.num_epochs):
        for _ in range(steps_per_epoch):
            batch = next(train_ds)
            rng, key = random.split(rng)
            optimizer = train_step(optimizer, batch, key)

        metrics, comparison, sample = eval(optimizer.target, test_ds, z,
                                           eval_rng)
        save_image(comparison, f'results/reconstruction_{epoch}.png', nrow=8)
        save_image(sample, f'results/sample_{epoch}.png', nrow=8)

        print('eval epoch: {}, loss: {:.4f}, BCE: {:.4f}, KLD: {:.4f}'.format(
            epoch + 1, metrics['loss'], metrics['bce'], metrics['kld']))
Example #13
def inverse_ptt_params(left_index, right_index, leaf_index):
    num_nodes = len(left_index)
    n = (num_nodes + 1) // 2

    # NOTE: keep in mind, nodes are serialized in DFS order, but for whatever
    # reason the right node is visited first, so this whole thing is backwards
    # from what you might expect.

    # compute leaf permutation
    leaf_permutation = np.zeros(n, np.int32)
    min_leaf_index = np.zeros([num_nodes], np.int32)
    max_leaf_index = np.zeros([num_nodes], np.int32)
    k = 0  # leaf node number
    for i in range(num_nodes):
        if leaf_index[i] >= 0:
            leaf_permutation[k] = leaf_index[i]
            min_leaf_index[i] = k
            max_leaf_index[i] = k
            k += 1
    assert k == n

    # figure out subtree spans for every node
    for i in range(num_nodes - 1, -1, -1):
        if leaf_index[i] < 0:
            min_leaf_index[i] = min_leaf_index[right_index[i]]
            max_leaf_index[i] = max_leaf_index[left_index[i]]
            assert min_leaf_index[i] < max_leaf_index[i]

    # now just compute the indexes we need to compute values for internal nodes
    max_leaf = np.zeros(n - 1, np.int32)
    min_leaf = np.zeros(n - 1, np.int32)
    left_min_leaf = np.zeros(n - 1, np.int32)

    k = 0  # internal node number
    for i in range(num_nodes):
        if leaf_index[i] >= 0:
            continue
        max_leaf[k] = max_leaf_index[i] + 1
        min_leaf[k] = min_leaf_index[i]
        left_min_leaf[k] = min_leaf_index[left_index[i]]
        k += 1

    return PttArgs(jax.device_put(leaf_permutation), jax.device_put(max_leaf),
                   jax.device_put(min_leaf), jax.device_put(left_min_leaf))
Example #14
def _testop(spec, tmpdir):
    img = jax.device_put(np.random.uniform(size=IMG_SHAPE))
    op = make_operation(spec)
    assert op(img).shape == (op.output_size(IMG_SHAPE),)

    with open(tmpdir / "op.pkl", "wb") as fi:
        joblib.dump(op, fi) 
    pkl_op = type(op).fromfile(tmpdir / "op.pkl")
    assert type(pkl_op(img)) == type(img)
    return op
Example #15
def main():
    loss_obj = hk.transform(loss_fn, apply_rng=True)
    # Initial parameter values are typically random. In JAX you need a key in order
    # to generate random numbers and so Haiku requires you to pass one in.
    rng = PRNGSequence(42)

    # `init` runs your function, so we need an example input. Typically you can
    # pass "dummy" inputs (e.g. ones of the same shape and dtype) since
    # initialization is not usually data dependent.
    shape = [([1000], float)]
    adam = optim.Adam(learning_rate=0.1)
    partial = Net.partial()
    _, params = partial.init_by_shape(next(rng), shape)
    net = nn.Model(partial, params)

    optimizer = jax.device_put(adam.create(net))
    _, params = partial.init_by_shape(next(rng), shape)  # HERE
    net = net.replace(params=params)
    optimizer = jax.device_put(adam.create(net))  # HERE
Example #16
def _host_to_device_funcs():
  """Generates host-to-device transfer functions."""
  return [
      # (function name, is an explicit transfer?, function)
      ("host_to_device_jax_device_put", True,
       lambda: jax.device_put(np.ones(10))),
      ("host_to_device_jax_jit", False, lambda: jax.jit(lambda x: x)
       (np.ones(1))),
      ("host_to_device_jnp_one", False, lambda: jnp.ones(1)),
  ]
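A sketch (an assumption, not part of the original test) of exercising these functions under jax.transfer_guard, which logs implicit host-to-device copies and so separates the explicit device_put entry from the implicit ones:

import jax

with jax.transfer_guard("log"):
    for name, explicit, fn in _host_to_device_funcs():
        fn()  # implicit transfers (explicit == False) get logged here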
Example #17
  def _res_tf_to_jax(res_tf: TfVal):
    res_tf, _ = jax2tf_internal._tfval_to_tensor_jax_dtype(res_tf)
    if isinstance(res_tf, tf.Tensor) and res_tf.dtype in dlpack.SUPPORTED_DTYPES:
      res_tf_platform = tf.DeviceSpec.from_string(res_tf.backing_device).device_type
      res_jax_platform = res_tf_platform.lower()
      if res_jax_platform in _DLPACK_PLATFORMS:
        res_dlpack = tf.experimental.dlpack.to_dlpack(res_tf)
        return jax.dlpack.from_dlpack(res_dlpack)

    return jax.device_put(np.asarray(res_tf))
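The helper above prefers a zero-copy DLPack handoff and only falls back to a host round-trip through np.asarray plus jax.device_put. A minimal sketch of the same fallback pattern, assuming TensorFlow is available:

import numpy as np
import tensorflow as tf
import jax

def tf_to_jax(t):
    try:
        # zero-copy when the dtype/device combination supports DLPack
        return jax.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(t))
    except Exception:
        # fallback: copy through the host
        return jax.device_put(np.asarray(t))

arr = tf_to_jax(tf.constant([1.0, 2.0, 3.0]))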
Example #18
    def test_device_put_and_get(self):
        x = onp.arange(12.).reshape((3, 4)).astype("float32")
        dx = device_put(x)
        assert isinstance(dx, DeviceArray)
        x2 = device_get(dx)
        assert isinstance(x2, onp.ndarray)
        assert onp.all(x == x2)

        y = [x, (2 * x, 3 * x)]
        dy = device_put(y)
        y2 = device_get(dy)
        assert isinstance(y2, list)
        assert isinstance(y2[0], onp.ndarray)
        assert onp.all(y2[0] == x)
        assert isinstance(y2[1], tuple)
        assert isinstance(y2[1][0], onp.ndarray)
        assert onp.all(y2[1][0] == 2 * x)
        assert isinstance(y2[1][1], onp.ndarray)
        assert onp.all(y2[1][1] == 3 * x)
Example #19
    def load_pickle(inname, metric=None, device=jax.devices()[0]):
        loaded_som = pickle.load(open(inname, 'rb'))
        if metric is None:
            print(
                'WARNING: loading a JAX SOM object without its metric results in the cdist metric being used '
                'because jitted functions cannot be pickled. If you want to use another metric, please '
                'specify it in the SOM.load_pickle function')
            metric = jax_cdist
        loaded_som.device = device
        loaded_som.metric = metric

        # We need to manually choose which arrays to put on device, otherwise
        # every numpy item would end up as a JAX object.
        loaded_som.centroids = jax.device_put(loaded_som.centroids,
                                              device=device)
        loaded_som.locations = jax.device_put(loaded_som.locations,
                                              device=device)
        loaded_som.distance_mat = jax.device_put(loaded_som.distance_mat,
                                                 device=device)
        return loaded_som
Example #20
  def testBasics(self):
    client = jax.lib.xla_bridge.get_backend()
    _ = client.heap_profile()

    a = jax.device_put(1)
    _ = client.heap_profile()

    # Heap profiler doesn't crash with deleted buffer
    a.delete()
    _ = client.heap_profile()
Example #21
    def test_transpose(self):
        devices = self.get_devices()

        x = jnp.ones((2, 3))
        self.assert_uncommitted_to_device(x, devices[0])

        y = lax.transpose(x, (1, 0))
        self.assert_uncommitted_to_device(y, devices[0])
        z = lax.transpose(jax.device_put(x, devices[2]), (1, 0))
        self.assert_committed_to_device(z, devices[2])
Example #22
 def init_kernel(init_params,
                 num_warmup,
                 adapt_state_size=None,
                 inverse_mass_matrix=None,
                 dense_mass=False,
                 model_args=(),
                 model_kwargs=None,
                 rng_key=random.PRNGKey(0)):
     nonlocal wa_steps
     wa_steps = num_warmup
     pe_fn = potential_fn
     if potential_fn_gen:
         if pe_fn is not None:
             raise ValueError(
                 'Only one of `potential_fn` or `potential_fn_gen` must be provided.'
             )
         else:
             kwargs = {} if model_kwargs is None else model_kwargs
             pe_fn = potential_fn_gen(*model_args, **kwargs)
     rng_key_sa, rng_key_zs, rng_key_z = random.split(rng_key, 3)
     z = init_params
     z_flat, unravel_fn = ravel_pytree(z)
     if inverse_mass_matrix is None:
         inverse_mass_matrix = jnp.identity(
             z_flat.shape[-1]) if dense_mass else jnp.ones(z_flat.shape[-1])
     inv_mass_matrix_sqrt = jnp.linalg.cholesky(inverse_mass_matrix) if dense_mass \
         else jnp.sqrt(inverse_mass_matrix)
     if adapt_state_size is None:
         # XXX: heuristic choice
         adapt_state_size = 2 * z_flat.shape[-1]
     else:
         assert adapt_state_size > 1, 'adapt_state_size should be greater than 1.'
     # NB: mean is init_params
     zs = z_flat + _sample_proposal(inv_mass_matrix_sqrt, rng_key_zs,
                                    (adapt_state_size, ))
     # compute potential energies
     pes = lax.map(lambda z: pe_fn(unravel_fn(z)), zs)
     if dense_mass:
         cov = jnp.cov(zs, rowvar=False, bias=True)
         if cov.shape == ():  # JAX returns scalar for 1D input
             cov = cov.reshape((1, 1))
         cholesky = jnp.linalg.cholesky(cov)
         # if cholesky is NaN, we use the scale from `sample_proposal` here
         inv_mass_matrix_sqrt = jnp.where(jnp.any(jnp.isnan(cholesky)),
                                          inv_mass_matrix_sqrt, cholesky)
     else:
         inv_mass_matrix_sqrt = jnp.std(zs, 0)
     adapt_state = SAAdaptState(zs, pes, jnp.mean(zs, 0),
                                inv_mass_matrix_sqrt)
     k = random.categorical(rng_key_z, jnp.zeros(zs.shape[0]))
     z = unravel_fn(zs[k])
     pe = pes[k]
     sa_state = SAState(jnp.array(0), z, pe, jnp.zeros(()), jnp.zeros(()),
                        jnp.array(False), adapt_state, rng_key_sa)
     return device_put(sa_state)
Example #23
    def __make_transition_matrices(self):
        nu = np.arange(1, self.nu_max + 1)
        mu = np.arange(4., 45., 1.)
        p = mu[..., None] / (nu[None] + mu[..., None])

        j = np.tril(nu[None].repeat(self.nu_max, 0))
        bnm = binom(nu[..., None], np.tril(j - 1))
        p1 = p[..., None]**np.tril((nu[..., None] - j + 1))
        p2 = (1 - p[..., None])**np.tril(j - 1)
        pi = np.tril(bnm * p1 * p2)
        pi = np.concatenate([pi, 1 - pi.sum(-1, keepdims=True)], -1)

        # phase transition matrix for different models m, f_t| f_{t+1}
        p_ff = []
        for i in range(self.nu_max):
            vp = p[..., i:i + 1].repeat(i + 1, -1)
            tmp = np.concatenate([
                vdiag(vp) + voffd(1 - vp[..., :-1]),
                np.zeros((p.shape[0], i + 1, self.nu_max - i - 1))
            ], -1)
            tmp = np.concatenate([tmp, 1 - tmp.sum(-1, keepdims=True)], -1)
            tmp = np.concatenate([
                tmp,
                np.ones((p.shape[0], self.nu_max - i - 1, self.nu_max + 1)) /
                (self.nu_max + 1)
            ], -2)
            tmp = np.concatenate([tmp, pi[:, i:i + 1]], -2)
            p_ff.append(tmp)

        self.p_mff = device_put(
            jnp.stack(p_ff, -3).reshape(-1, self.nu_max + 1, self.nu_max + 1),
            self.device)

        # state transition matrix f_t, j_t| j_{t+1}
        p_fjj = np.zeros((self.nu_max + 1, 2, 2))
        p_fjj[-1, :, 1] = 1.
        p_fjj[:-1, :, 0] = 1.

        p_jcc = np.stack([np.eye(2), (np.ones((2, 2)) - np.eye(2))], 0)

        self.p_fcc = einsum('jcz,fj->fcz', device_put(p_jcc, self.device),
                            device_put(p_fjj[:, 0], self.device))
Example #24
def jit_dispatch_without_transfer(state):
  # We pick a realistic input: 224 is a usual image size for classification
  # and 128 a TPU-friendly batch size.
  imgs = np.ones((128, 224, 224), np.float32)
  imgs = jax.device_put(imgs)

  f = jax.api.jit(lambda x: x+1)
  f(imgs)

  while state:
    f(imgs)
Example #25
def dbp_timedomain(y, h, c):

    y = device_put(y)
    h = device_put(h)
    c = device_put(c)

    steps = c.shape[0]

    D = jit(
        vmap(lambda y, h: xop.conv1d_fft_oa(y, h, mode='SAME'),
             in_axes=1,
             out_axes=1))
    # D = jit(vmap(lambda y,h: xop.conv1d_lax(y, h), in_axes=1, out_axes=1)) # often too slow for long h
    N = jit(lambda y, c: y * jnp.exp(1j * (abs(y)**2 @ c)))

    for i in range(steps):
        y = D(y, h[i])
        y = N(y, c[i])

    return y
Example #26
 def test_device_put(self):
     with jax._src.config.jax_array(True):
         numpy_array = np.array([1, 2, 3])
         arr = jax.device_put(numpy_array, jax.devices()[0])
         self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)
         self.assertArraysEqual(arr, numpy_array)
         self.assertEqual(arr._committed, True)
         for i in arr.addressable_shards:
             self.assertArraysEqual(i.data, numpy_array)
             self.assertEqual(i.device, jax.devices()[0])
             self.assertEqual(i.index, (slice(None), ))
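On newer JAX versions, device_put also accepts a Sharding in place of a device; a minimal sketch (assuming the public jax.sharding module is available) that yields the same SingleDeviceSharding the test above checks for:

import jax
import numpy as np
from jax.sharding import SingleDeviceSharding

x = np.array([1, 2, 3])
arr = jax.device_put(x, SingleDeviceSharding(jax.devices()[0]))
assert arr.sharding.device_set == {jax.devices()[0]}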
Example #27
def lms_cpane(signal, w_init, data=None, train=None, lr=1e-4, beta=0.7, const=comm.const("16QAM", norm=True), device=cpus[0]):
    const = comm.const("16QAM", norm=True)

    if train is None:
        train = np.full((signal.shape[0],), False)
        data = np.full((signal.shape[0],), 0, dtype=const.dtype)

    dims = signal.shape[-1]

    params_lms = (w_init, lr)
    params_cpane = tuple(map(lambda x: np.tile(x, dims), [1e-5 * (1.+1j), 1e-2 * (1.+1j), 0j, 1j, 0j, beta])) + (const,)
    params = (params_lms, params_cpane)
    inputs = (signal, data, train)

    params = device_put(params, device)
    inputs = device_put(inputs, device)

    _, ret = scan(step_lms_cpane, params, inputs)

    return ret
Example #28
def prepare_tensor(g, data, name):
    """Convert the data to ID tensor and check its ID type and context.

    If the data is already in tensor type, raise an error if its ID type
    and context do not match the graph's.
    Otherwise, convert it to tensor type of the graph's ID type and
    ctx and return.

    Parameters
    ----------
    g : DGLHeteroGraph
        Graph.
    data : int, iterable of int, tensor
        Data.
    name : str
        Name of the data.

    Returns
    -------
    Tensor
        Data in tensor object.
    """
    if F.backend_name == "jax":
        import jax
        from jax import numpy as jnp
        if isinstance(data, jnp.ndarray) and hasattr(data, 'device_buffer'):
            if data.device_buffer.device() != g.device:

                data = jax.device_put(data, g.device).astype(data.dtype)

    if F.is_tensor(data):
        if F.dtype(data) != g.idtype or F.context(data) != g.device:
            raise DGLError(
                'Expect argument "{}" to have data type {} and device '
                'context {}. But got {} and {}.'.format(
                    name, g.idtype, g.device, F.dtype(data), F.context(data)))
        ret = data
    else:
        data = F.tensor(data)
        if (not (F.ndim(data) > 0 and F.shape(data)[0] == 0)
                and  # empty tensor
                F.dtype(data) not in (F.int32, F.int64)):
            raise DGLError(
                'Expect argument "{}" to have data type int32 or int64,'
                ' but got {}.'.format(name, F.dtype(data)))
        ret = F.copy_to(F.astype(data, g.idtype), g.device)

    if F.ndim(ret) == 0:
        ret = F.unsqueeze(ret, 0)
    if F.ndim(ret) > 1:
        raise DGLError(
            'Expect a 1-D tensor for argument "{}". But got {}.'.format(
                name, ret))
    return ret
Example #29
    def __set_prior(self, a=6., b=32.):
        prior_mf = self.p_mff[:, -1]
        M = prior_mf.shape[0]
        prior_m = np.ones(M) / M
        prior_m = prior_m.reshape(-1, self.nu_max)
        prior_m = np.concatenate([
            np.zeros_like(prior_m[:, :self.nu_min]), prior_m[:, self.nu_min:]
        ], -1).reshape(-1)
        prior_m /= prior_m.sum()
        prior_fm = (prior_mf * device_put(prior_m[:, None], self.device)).T
        prior_c = device_put(jnp.ones(2) / 2, self.device)

        pars = device_put(
            jnp.array([[[a, b, 1, 1], [b, a, 1, 1]],
                       [[b, a, 1, 1], [a, b, 1, 1]],
                       [[1, 1, 1000., 1], [1, 1, 1, 1000.]]]),
            self.device)[None].repeat(self.N, 0)

        probs = einsum('c,fm->cfm', prior_c, prior_fm)[None].repeat(self.N, 0)
        self.prior = (probs, pars)
Example #30
def make_dataset(points_per_class, classes, revolutions=4):
  np.random.seed(0)

  N = points_per_class
  C = classes
  pi = np.pi

  X = np.zeros((N * C, 2))
  y = np.zeros((N * C, C))

  for j in range(C):
    ix = range(N * j, N * (j + 1))
    r = np.linspace(0., 1, N) # radius
    omega = 2 * pi / C
    theta_max = revolutions * pi
    t = np.linspace(omega * j,omega * j + theta_max, N) + np.random.randn(N) * 0.2 # theta
    X[ix] = np.c_[r*np.cos(t), r*np.sin(t)]
    y[ix, j] = 1

  return jax.device_put(X), jax.device_put(y)
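Hypothetical usage of make_dataset: the spiral features and one-hot labels come back as device-resident JAX arrays, ready for a jitted training step.

X, y = make_dataset(points_per_class=100, classes=3)
assert X.shape == (300, 2) and y.shape == (300, 3)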