def test_broadcast(self):
    if jax.device_count() < 3:
        self.skipTest("test requires 3 devices")
    devices = self.get_devices()

    z = 1 + jnp.ones((2, 3))
    self.assert_uncommitted_to_device(z, devices[0])
    y = jax.device_put(1, devices[2]) + jnp.ones((2, 3))
    self.assert_committed_to_device(y, devices[2])
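# A minimal standalone sketch (not part of the test above) of the distinction the
# test exercises: arrays produced by ordinary jnp operations are uncommitted and
# follow JAX's default placement, while jax.device_put with an explicit device
# returns a value committed to that device, and operations involving it stay there.
import jax
import jax.numpy as jnp

z = 1 + jnp.ones((2, 3))                                     # uncommitted
y = jax.device_put(1, jax.devices()[0]) + jnp.ones((2, 3))   # committed to devices()[0]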
def jit_simple_many_args_dispatch(n, state):
    args = [jax.device_put(i) for i in range(n)]
    f = jax.jit(lambda xs: functools.reduce(operator.add, xs))
    x = f(args)
    x.block_until_ready()

    while state:
        x = f(args)
        x.block_until_ready()
def jit_big_matmul(state):
    x = np.random.uniform(size=(100, 100)).astype(np.float32)
    x = jax.device_put(x)
    f = jax.jit(lambda x: jnp.dot(x, x))
    f(x).block_until_ready()

    while state:
        f(x).block_until_ready()
def _update(i, af_state, af_inp):
    if trainable:
        af_inp = af_inp if isinstance(af_inp, tuple) else (af_inp,)
        af_inp = (af_inp + (0.,))[:2]
    else:
        af_inp = af_inp[0] if isinstance(af_inp, tuple) else af_inp
    af_inp = jax.device_put(af_inp)
    af_state, af_out = stop_gradient(update(i, af_state, af_inp))
    return af_state, af_out
def jit_simple_pruned_args_dispatch(n, state):
    args = [jax.device_put(i) for i in range(n)]
    f = jax.jit(lambda *xs: xs[0] + 1)
    x = f(*args)
    x.block_until_ready()

    while state:
        x = f(*args)
        x.block_until_ready()
def dbp_direct(y, H, c):
    y = device_put(y)
    H = device_put(H)
    c = device_put(c)
    steps = c.shape[0]
    fft = lambda x: jnp.fft.fft(x, axis=0)
    ifft = lambda x: jnp.fft.ifft(x, axis=0)
    D = jit(lambda y, H: ifft(fft(y) * H))
    N = jit(lambda y, c: y * jnp.exp(1j * (abs(y)**2 @ c)))
    for i in range(steps):
        y = D(y, H[i])
        y = N(y, c[i])
    return y
def unroll_and_push(self, frame_count: int, params: hk.Params):
    """Run one unroll and send trajectory to learner."""
    params = jax.device_put(params)
    self._rng_key, subkey = jax.random.split(self._rng_key)
    act_out = self.unroll(
        rng_key=subkey,
        frame_count=frame_count,
        params=params,
        unroll_length=self._unroll_length)
    self._learner.enqueue_traj(act_out)
def cb(cb_inp):
    self.assertLen(cb_inp, 4)
    dbs = []
    for inp in cb_inp:
        index, devices = inp
        self.assertLen(devices, 2)
        array = global_input_data[index]
        dbs.extend([jax.device_put(array, device) for device in devices])
    return dbs
def scan_wrapper(
    f,
    init,
    xs,
    length,
    reverse,
    rng_key=None,
    substitute_stack=[],
    enum=False,
    history=1,
    first_available_dim=None,
):
    if length is None:
        length = tree_flatten(xs)[0][0].shape[0]

    if enum and history > 0:
        return scan_enum(
            f,
            init,
            xs,
            length,
            reverse,
            rng_key,
            substitute_stack,
            history,
            first_available_dim,
        )

    def body_fn(wrapped_carry, x):
        i, rng_key, carry = wrapped_carry
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

        with handlers.block():
            # we need to tell unconstrained messenger in potential energy computation
            # that only the item at time `i` is needed when transforming
            fn = handlers.infer_config(
                f, config_fn=lambda msg: {"_scan_current_index": i})

            seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn
            for subs_type, subs_map in substitute_stack:
                subs_fn = partial(_subs_wrapper, subs_map, i, length)
                if subs_type == "condition":
                    seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
                elif subs_type == "substitute":
                    seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

            with handlers.trace() as trace:
                carry, y = seeded_fn(carry, x)

        return (i + 1, rng_key, carry), (PytreeTrace(trace), y)

    wrapped_carry = device_put((0, rng_key, init))
    return lax.scan(body_fn, wrapped_carry, xs, length=length, reverse=reverse)
def rvs(self, *args, **kwargs):
    rng = kwargs.pop('random_state')
    if rng is None:
        rng = self.random_state
    # assert that rng is PRNGKey and not mtrand.RandomState object from numpy.
    assert _is_prng_key(rng)
    kwargs['random_state'] = onp.random.RandomState(rng)
    sample = super(jax_discrete, self).rvs(*args, **kwargs)
    return device_put(sample)
def main(argv):
    del argv

    rng = random.PRNGKey(0)
    rng, key = random.split(rng)

    train_ds = tfds.load('binarized_mnist', split=tfds.Split.TRAIN)
    train_ds = train_ds.map(prepare_image)
    train_ds = train_ds.cache()
    train_ds = train_ds.repeat()
    train_ds = train_ds.shuffle(50000)
    train_ds = train_ds.batch(FLAGS.batch_size)
    train_ds = tfds.as_numpy(train_ds)

    test_ds = tfds.load('binarized_mnist', split=tfds.Split.TEST)
    test_ds = test_ds.map(prepare_image).batch(10000)
    test_ds = np.array(list(test_ds)[0])
    test_ds = jax.device_put(test_ds)

    model_def = VAE.partial(latents=FLAGS.latents)
    _, params = model_def.init_by_shape(
        key, [(FLAGS.batch_size, 784)], z_rng=random.PRNGKey(0))
    vae = nn.Model(model_def, params)

    optimizer = optim.Adam(learning_rate=FLAGS.learning_rate).create(vae)
    optimizer = jax.device_put(optimizer)

    rng, z_key, eval_rng = random.split(rng, 3)
    z = random.normal(z_key, (64, FLAGS.latents))

    steps_per_epoch = 50000 // FLAGS.batch_size

    for epoch in range(FLAGS.num_epochs):
        for _ in range(steps_per_epoch):
            batch = next(train_ds)
            rng, key = random.split(rng)
            optimizer = train_step(optimizer, batch, key)

        metrics, comparison, sample = eval(optimizer.target, test_ds, z, eval_rng)
        save_image(comparison, f'results/reconstruction_{epoch}.png', nrow=8)
        save_image(sample, f'results/sample_{epoch}.png', nrow=8)
        print('eval epoch: {}, loss: {:.4f}, BCE: {:.4f}, KLD: {:.4f}'.format(
            epoch + 1, metrics['loss'], metrics['bce'], metrics['kld']))
def main(argv):
    del argv

    rng = random.PRNGKey(0)
    rng, key = random.split(rng)

    ds_builder = tfds.builder('binarized_mnist')
    ds_builder.download_and_prepare()
    train_ds = ds_builder.as_dataset(split=tfds.Split.TRAIN)
    train_ds = train_ds.map(prepare_image)
    train_ds = train_ds.cache()
    train_ds = train_ds.repeat()
    train_ds = train_ds.shuffle(50000)
    train_ds = train_ds.batch(FLAGS.batch_size)
    train_ds = iter(tfds.as_numpy(train_ds))

    test_ds = ds_builder.as_dataset(split=tfds.Split.TEST)
    test_ds = test_ds.map(prepare_image).batch(10000)
    test_ds = np.array(list(test_ds)[0])
    test_ds = jax.device_put(test_ds)

    init_data = jnp.ones((FLAGS.batch_size, 784), jnp.float32)
    params = model().init(key, init_data, rng)['param']

    optimizer = optim.Adam(learning_rate=FLAGS.learning_rate).create(params)
    optimizer = jax.device_put(optimizer)

    rng, z_key, eval_rng = random.split(rng, 3)
    z = random.normal(z_key, (64, FLAGS.latents))

    steps_per_epoch = 50000 // FLAGS.batch_size

    for epoch in range(FLAGS.num_epochs):
        for _ in range(steps_per_epoch):
            batch = next(train_ds)
            rng, key = random.split(rng)
            optimizer = train_step(optimizer, batch, key)

        metrics, comparison, sample = eval(optimizer.target, test_ds, z, eval_rng)
        save_image(comparison, f'results/reconstruction_{epoch}.png', nrow=8)
        save_image(sample, f'results/sample_{epoch}.png', nrow=8)
        print('eval epoch: {}, loss: {:.4f}, BCE: {:.4f}, KLD: {:.4f}'.format(
            epoch + 1, metrics['loss'], metrics['bce'], metrics['kld']))
def inverse_ptt_params(left_index, right_index, leaf_index):
    num_nodes = len(left_index)
    n = (num_nodes + 1) // 2

    # NOTE: keep in mind, nodes are serialized in dfs order, but for whatever
    # reason visiting the right node first, so this whole thing is backwards
    # from what you might expect.

    # compute leaf permutation
    leaf_permutation = np.zeros(n, np.int32)
    # (dtype `int` used here; the deprecated `np.int` alias was removed in NumPy 1.24)
    min_leaf_index = np.zeros([num_nodes], int)
    max_leaf_index = np.zeros([num_nodes], int)
    k = 0  # leaf node number
    for i in range(num_nodes):
        if leaf_index[i] >= 0:
            leaf_permutation[k] = leaf_index[i]
            min_leaf_index[i] = k
            max_leaf_index[i] = k
            k += 1
    assert k == n

    # figure out subtree spans for every node
    for i in range(num_nodes - 1, -1, -1):
        if leaf_index[i] < 0:
            min_leaf_index[i] = min_leaf_index[right_index[i]]
            max_leaf_index[i] = max_leaf_index[left_index[i]]
            assert min_leaf_index[i] < max_leaf_index[i]

    # now just compute the indexes we need to compute values for internal nodes
    max_leaf = np.zeros(n - 1, np.int32)
    min_leaf = np.zeros(n - 1, np.int32)
    left_min_leaf = np.zeros(n - 1, np.int32)
    k = 0  # internal node number
    for i in range(num_nodes):
        if leaf_index[i] >= 0:
            continue
        max_leaf[k] = max_leaf_index[i] + 1
        min_leaf[k] = min_leaf_index[i]
        left_min_leaf[k] = min_leaf_index[left_index[i]]
        k += 1

    return PttArgs(
        jax.device_put(leaf_permutation),
        jax.device_put(max_leaf),
        jax.device_put(min_leaf),
        jax.device_put(left_min_leaf))
def _testop(spec, tmpdir):
    img = jax.device_put(np.random.uniform(size=IMG_SHAPE))
    op = make_operation(spec)
    assert op(img).shape == (op.output_size(IMG_SHAPE),)
    with open(tmpdir / "op.pkl", "wb") as fi:
        joblib.dump(op, fi)
    pkl_op = type(op).fromfile(tmpdir / "op.pkl")
    assert type(pkl_op(img)) == type(img)
    return op
def main():
    loss_obj = hk.transform(loss_fn, apply_rng=True)

    # Initial parameter values are typically random. In JAX you need a key in order
    # to generate random numbers and so Haiku requires you to pass one in.
    rng = PRNGSequence(42)

    # `init` runs your function, as such we need an example input. Typically you can
    # pass "dummy" inputs (e.g. ones of the same shape and dtype) since initialization
    # is not usually data dependent.
    shape = [([1000], float)]

    adam = optim.Adam(learning_rate=0.1)
    partial = Net.partial()

    _, params = partial.init_by_shape(next(rng), shape)
    net = nn.Model(partial, params)
    optimizer = jax.device_put(adam.create(net))

    _, params = partial.init_by_shape(next(rng), shape)  # HERE
    net = net.replace(params=params)
    optimizer = jax.device_put(adam.create(net))  # HERE
def _host_to_device_funcs():
    """Generates host-to-device transfer functions."""
    return [
        # (function name, is an explicit transfer?, function)
        ("host_to_device_jax_device_put", True,
         lambda: jax.device_put(np.ones(10))),
        ("host_to_device_jax_jit", False,
         lambda: jax.jit(lambda x: x)(np.ones(1))),
        ("host_to_device_jnp_one", False, lambda: jnp.ones(1)),
    ]
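# Hypothetical usage of the helper above (not from the original source): call each
# transfer function once and report whether the transfer is explicit. Every returned
# value is a JAX array, so .shape is available on all of them.
for name, explicit, fn in _host_to_device_funcs():
    out = fn()
    print(name, "explicit" if explicit else "implicit", out.shape)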
def _res_tf_to_jax(res_tf: TfVal):
    res_tf, _ = jax2tf_internal._tfval_to_tensor_jax_dtype(res_tf)
    if isinstance(res_tf, tf.Tensor) and res_tf.dtype in dlpack.SUPPORTED_DTYPES:
        res_tf_platform = tf.DeviceSpec.from_string(res_tf.backing_device).device_type
        res_jax_platform = res_tf_platform.lower()
        if res_jax_platform in _DLPACK_PLATFORMS:
            res_dlpack = tf.experimental.dlpack.to_dlpack(res_tf)
            return jax.dlpack.from_dlpack(res_dlpack)

    return jax.device_put(np.asarray(res_tf))
def test_device_put_and_get(self):
    x = onp.arange(12.).reshape((3, 4)).astype("float32")
    dx = device_put(x)
    assert isinstance(dx, DeviceArray)
    x2 = device_get(dx)
    assert isinstance(x2, onp.ndarray)
    assert onp.all(x == x2)

    y = [x, (2 * x, 3 * x)]
    dy = device_put(y)
    y2 = device_get(dy)
    assert isinstance(y2, list)
    assert isinstance(y2[0], onp.ndarray)
    assert onp.all(y2[0] == x)
    assert isinstance(y2[1], tuple)
    assert isinstance(y2[1][0], onp.ndarray)
    assert onp.all(y2[1][0] == 2 * x)
    assert isinstance(y2[1][1], onp.ndarray)
    assert onp.all(y2[1][1] == 3 * x)
def load_pickle(inname, metric=None, device=jax.devices()[0]):
    loaded_som = pickle.load(open(inname, 'rb'))
    if metric is None:
        print(
            'WARNING : loading a JAX SOM object without its metric results in the cdist metric being used '
            'because jitted functions can not be pickled. If you want to use another metric, please '
            'specify it in the SOM.load_pickle function')
        metric = jax_cdist
    loaded_som.device = device
    loaded_som.metric = metric

    # We need to manually choose which arrays are to be put, otherwise every numpy item ends up a jax object.
    loaded_som.centroids = jax.device_put(loaded_som.centroids, device=device)
    loaded_som.locations = jax.device_put(loaded_som.locations, device=device)
    loaded_som.distance_mat = jax.device_put(loaded_som.distance_mat, device=device)
    return loaded_som
def testBasics(self):
    client = jax.lib.xla_bridge.get_backend()
    _ = client.heap_profile()

    a = jax.device_put(1)
    _ = client.heap_profile()

    # Heap profiler doesn't crash with deleted buffer
    a.delete()
    _ = client.heap_profile()
def test_transpose(self):
    devices = self.get_devices()

    x = jnp.ones((2, 3))
    self.assert_uncommitted_to_device(x, devices[0])

    y = lax.transpose(x, (1, 0))
    self.assert_uncommitted_to_device(y, devices[0])
    z = lax.transpose(jax.device_put(x, devices[2]), (1, 0))
    self.assert_committed_to_device(z, devices[2])
def init_kernel(init_params,
                num_warmup,
                adapt_state_size=None,
                inverse_mass_matrix=None,
                dense_mass=False,
                model_args=(),
                model_kwargs=None,
                rng_key=random.PRNGKey(0)):
    nonlocal wa_steps
    wa_steps = num_warmup
    pe_fn = potential_fn
    if potential_fn_gen:
        if pe_fn is not None:
            raise ValueError(
                'Only one of `potential_fn` or `potential_fn_gen` must be provided.')
        else:
            kwargs = {} if model_kwargs is None else model_kwargs
            pe_fn = potential_fn_gen(*model_args, **kwargs)

    rng_key_sa, rng_key_zs, rng_key_z = random.split(rng_key, 3)
    z = init_params
    z_flat, unravel_fn = ravel_pytree(z)
    if inverse_mass_matrix is None:
        inverse_mass_matrix = jnp.identity(z_flat.shape[-1]) if dense_mass \
            else jnp.ones(z_flat.shape[-1])
    inv_mass_matrix_sqrt = jnp.linalg.cholesky(inverse_mass_matrix) if dense_mass \
        else jnp.sqrt(inverse_mass_matrix)

    if adapt_state_size is None:
        # XXX: heuristic choice
        adapt_state_size = 2 * z_flat.shape[-1]
    else:
        assert adapt_state_size > 1, 'adapt_state_size should be greater than 1.'

    # NB: mean is init_params
    zs = z_flat + _sample_proposal(inv_mass_matrix_sqrt, rng_key_zs,
                                   (adapt_state_size,))
    # compute potential energies
    pes = lax.map(lambda z: pe_fn(unravel_fn(z)), zs)
    if dense_mass:
        cov = jnp.cov(zs, rowvar=False, bias=True)
        if cov.shape == ():  # JAX returns scalar for 1D input
            cov = cov.reshape((1, 1))
        cholesky = jnp.linalg.cholesky(cov)
        # if cholesky is NaN, we use the scale from `sample_proposal` here
        inv_mass_matrix_sqrt = jnp.where(jnp.any(jnp.isnan(cholesky)),
                                         inv_mass_matrix_sqrt,
                                         cholesky)
    else:
        inv_mass_matrix_sqrt = jnp.std(zs, 0)
    adapt_state = SAAdaptState(zs, pes, jnp.mean(zs, 0), inv_mass_matrix_sqrt)

    k = random.categorical(rng_key_z, jnp.zeros(zs.shape[0]))
    z = unravel_fn(zs[k])
    pe = pes[k]
    sa_state = SAState(jnp.array(0), z, pe, jnp.zeros(()), jnp.zeros(()),
                       jnp.array(False), adapt_state, rng_key_sa)
    return device_put(sa_state)
def __make_transition_matrices(self):
    nu = np.arange(1, self.nu_max + 1)
    mu = np.arange(4., 45., 1.)

    p = mu[..., None] / (nu[None] + mu[..., None])
    j = np.tril(nu[None].repeat(self.nu_max, 0))
    bnm = binom(nu[..., None], np.tril(j - 1))
    p1 = p[..., None]**np.tril((nu[..., None] - j + 1))
    p2 = (1 - p[..., None])**np.tril(j - 1)
    pi = np.tril(bnm * p1 * p2)
    pi = np.concatenate([pi, 1 - pi.sum(-1, keepdims=True)], -1)

    # phase transition matrix for different models m, f_t| f_{t+1}
    p_ff = []
    for i in range(self.nu_max):
        vp = p[..., i:i + 1].repeat(i + 1, -1)
        tmp = np.concatenate([
            vdiag(vp) + voffd(1 - vp[..., :-1]),
            np.zeros((p.shape[0], i + 1, self.nu_max - i - 1))
        ], -1)
        tmp = np.concatenate([tmp, 1 - tmp.sum(-1, keepdims=True)], -1)
        tmp = np.concatenate([
            tmp,
            np.ones((p.shape[0], self.nu_max - i - 1, self.nu_max + 1)) /
            (self.nu_max + 1)
        ], -2)
        tmp = np.concatenate([tmp, pi[:, i:i + 1]], -2)
        p_ff.append(tmp)

    self.p_mff = device_put(
        jnp.stack(p_ff, -3).reshape(-1, self.nu_max + 1, self.nu_max + 1),
        self.device)

    # state transition matrix f_t, j_t| j_{t+1}
    p_fjj = np.zeros((self.nu_max + 1, 2, 2))
    p_fjj[-1, :, 1] = 1.
    p_fjj[:-1, :, 0] = 1.

    p_jcc = np.stack([np.eye(2), (np.ones((2, 2)) - np.eye(2))], 0)

    self.p_fcc = einsum('jcz,fj->fcz',
                        device_put(p_jcc, self.device),
                        device_put(p_fjj[:, 0], self.device))
def jit_dispatch_without_transfer(state):
    # We pick up a realistic input. 224 is usual for classification and 128 a
    # TPU-friendly batch size.
    imgs = np.ones((128, 224, 224), np.float32)
    imgs = jax.device_put(imgs)

    f = jax.api.jit(lambda x: x + 1)
    f(imgs)

    while state:
        f(imgs)
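# A rough standalone timing sketch (assumed setup, not part of the benchmark suite
# above) of the same pattern: device_put the input once so the timed loop measures
# dispatch and compute only, warm up the jitted function so compilation is excluded,
# and synchronize with block_until_ready(). Uses the current jax.jit entry point.
import time
import numpy as np
import jax

imgs = jax.device_put(np.ones((128, 224, 224), np.float32))
f = jax.jit(lambda x: x + 1)
f(imgs).block_until_ready()  # warm-up / compile outside the timed region
start = time.perf_counter()
for _ in range(100):
    f(imgs).block_until_ready()
print((time.perf_counter() - start) / 100, "s per call")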
def dbp_timedomain(y, h, c):
    y = device_put(y)
    h = device_put(h)
    c = device_put(c)
    steps = c.shape[0]
    D = jit(vmap(lambda y, h: xop.conv1d_fft_oa(y, h, mode='SAME'),
                 in_axes=1, out_axes=1))
    # D = jit(vmap(lambda y, h: xop.conv1d_lax(y, h), in_axes=1, out_axes=1))  # often too slow for long h
    N = jit(lambda y, c: y * jnp.exp(1j * (abs(y)**2 @ c)))
    for i in range(steps):
        y = D(y, h[i])
        y = N(y, c[i])
    return y
def test_device_put(self):
    with jax._src.config.jax_array(True):
        numpy_array = np.array([1, 2, 3])
        arr = jax.device_put(numpy_array, jax.devices()[0])
        self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)
        self.assertArraysEqual(arr, numpy_array)
        self.assertEqual(arr._committed, True)

        for i in arr.addressable_shards:
            self.assertArraysEqual(i.data, numpy_array)
            self.assertEqual(i.device, jax.devices()[0])
            self.assertEqual(i.index, (slice(None),))
def lms_cpane(signal, w_init, data=None, train=None, lr=1e-4, beta=0.7,
              const=comm.const("16QAM", norm=True), device=cpus[0]):
    const = comm.const("16QAM", norm=True)
    if train is None:
        train = np.full((signal.shape[0],), False)
        data = np.full((signal.shape[0],), 0, dtype=const.dtype)
    dims = signal.shape[-1]
    params_lms = (w_init, lr)
    params_cpane = tuple(
        map(lambda x: np.tile(x, dims),
            [1e-5 * (1. + 1j), 1e-2 * (1. + 1j), 0j, 1j, 0j, beta])) + (const,)
    params = (params_lms, params_cpane)
    inputs = (signal, data, train)

    params = device_put(params, device)
    inputs = device_put(inputs, device)

    _, ret = scan(step_lms_cpane, params, inputs)
    return ret
def prepare_tensor(g, data, name):
    """Convert the data to ID tensor and check its ID type and context.

    If the data is already in tensor type, raise error if its ID type
    and context does not match the graph's. Otherwise, convert it to
    tensor type of the graph's ID type and ctx and return.

    Parameters
    ----------
    g : DGLHeteroGraph
        Graph.
    data : int, iterable of int, tensor
        Data.
    name : str
        Name of the data.

    Returns
    -------
    Tensor
        Data in tensor object.
    """
    if F.backend_name == "jax":
        import jax
        from jax import numpy as jnp
        if isinstance(data, jnp.ndarray) and hasattr(data, 'device_buffer'):
            if data.device_buffer.device() != g.device:
                data = jax.device_put(data, g.device).astype(data.dtype)

    if F.is_tensor(data):
        if F.dtype(data) != g.idtype or F.context(data) != g.device:
            raise DGLError(
                'Expect argument "{}" to have data type {} and device '
                'context {}. But got {} and {}.'.format(
                    name, g.idtype, g.device, F.dtype(data), F.context(data)))
        ret = data
    else:
        data = F.tensor(data)
        if (not (F.ndim(data) > 0 and F.shape(data)[0] == 0) and  # empty tensor
                F.dtype(data) not in (F.int32, F.int64)):
            raise DGLError(
                'Expect argument "{}" to have data type int32 or int64,'
                ' but got {}.'.format(name, F.dtype(data)))
        ret = F.copy_to(F.astype(data, g.idtype), g.device)

    if F.ndim(ret) == 0:
        ret = F.unsqueeze(ret, 0)
    if F.ndim(ret) > 1:
        raise DGLError(
            'Expect a 1-D tensor for argument "{}". But got {}.'.format(
                name, ret))
    return ret
def __set_prior(self, a=6., b=32.):
    prior_mf = self.p_mff[:, -1]
    M = prior_mf.shape[0]
    prior_m = np.ones(M) / M
    prior_m = prior_m.reshape(-1, self.nu_max)
    prior_m = np.concatenate([
        np.zeros_like(prior_m[:, :self.nu_min]),
        prior_m[:, self.nu_min:]
    ], -1).reshape(-1)
    prior_m /= prior_m.sum()

    prior_fm = (prior_mf * device_put(prior_m[:, None], self.device)).T
    prior_c = device_put(jnp.ones(2) / 2, self.device)

    pars = device_put(
        jnp.array([[[a, b, 1, 1], [b, a, 1, 1]],
                   [[b, a, 1, 1], [a, b, 1, 1]],
                   [[1, 1, 1000., 1], [1, 1, 1, 1000.]]]),
        self.device)[None].repeat(self.N, 0)
    probs = einsum('c,fm->cfm', prior_c, prior_fm)[None].repeat(self.N, 0)

    self.prior = (probs, pars)
def make_dataset(points_per_class, classes, revolutions=4):
    np.random.seed(0)
    N = points_per_class
    C = classes
    pi = np.pi
    X = np.zeros((N * C, 2))
    y = np.zeros((N * C, C))
    for j in range(C):
        ix = range(N * j, N * (j + 1))
        r = np.linspace(0., 1, N)  # radius
        omega = 2 * pi / C
        theta_max = revolutions * pi
        t = np.linspace(omega * j, omega * j + theta_max, N) + np.random.randn(N) * 0.2  # theta
        X[ix] = np.c_[r * np.cos(t), r * np.sin(t)]
        y[ix, j] = 1
    return jax.device_put(X), jax.device_put(y)
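# Example call with hypothetical sizes: a 3-class spiral with 100 points per class.
# Both returned arrays already live on the default device via jax.device_put above.
X, y = make_dataset(points_per_class=100, classes=3)
print(X.shape, y.shape)  # (300, 2) (300, 3)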