def testInfeed(self, partition_input):
  if jax.local_device_count() % 2 != 0:
    raise SkipTest

  shape = (jax.local_device_count() * 2, 4)
  # Run computation across all devices so we know which devices to feed.
  parts = P(jax.local_device_count(), 1)
  in_parts = parts if partition_input else None
  infeed_shapes = (jax.ShapedArray(shape, np.float32),
                   jax.ShapedArray((1,), np.float32))
  infeed_parts = (parts, None)

  @partial(sharded_jit, in_parts=in_parts, out_parts=None)
  def f(x):
    token = lax.create_token(x)
    (y, z), token = lax.infeed(token, infeed_shapes, partitions=infeed_parts)
    return x @ y.T + z[jnp.newaxis]

  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  y = x + 1
  shard_size = shape[0] // jax.local_device_count()
  y_shards = [y[i:i + shard_size] for i in range(0, shape[0], shard_size)]
  z = jnp.array([3.], dtype=np.float32)

  assert len(jax.local_devices()) == len(y_shards)
  for device, y_shard in zip(jax.local_devices(), y_shards):
    device.transfer_to_infeed((y_shard, z))

  # Transfer data to infeed before executing the function. For GPUs, the
  # execution of the compiled function is blocking, so transferring data
  # to infeed before executing ensures that the execution does not deadlock
  # waiting for the infeed data.
  result = f(x)

  expected = x @ y.T + z[jnp.newaxis]
  self.assertAllClose(result, expected, check_dtypes=False)

def testInfeed(self, partition_input):
  if jax.local_device_count() % 2 != 0:
    raise SkipTest

  shape = (jax.local_device_count() * 2, 4)
  # Run computation across all devices so we know which devices to feed.
  parts = P(jax.local_device_count(), 1)
  in_parts = parts if partition_input else None
  infeed_shapes = (jax.ShapedArray(shape, np.float32),
                   jax.ShapedArray((1,), np.float32))
  infeed_parts = (parts, None)

  @partial(sharded_jit, in_parts=in_parts, out_parts=None)
  def f(x):
    token = lax.create_token(x)
    (y, z), token = lax.infeed(token, infeed_shapes, partitions=infeed_parts)
    return x @ y.T + z

  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  y = x + 1
  shard_size = shape[0] // jax.local_device_count()
  y_shards = [y[i:i + shard_size] for i in range(0, shape[0], shard_size)]
  z = jnp.array([3.], dtype=np.float32)

  result = f(x)

  assert len(jax.local_devices()) == len(y_shards)
  for device, y_shard in zip(jax.local_devices(), y_shards):
    device.transfer_to_infeed((y_shard, z))

  expected = x @ y.T + z
  self.assertAllClose(result, expected, check_dtypes=False)

def test_get_from_first_device(self):
  sharded = {
      'a': jax.device_put_sharded(
          list(jnp.arange(16).reshape([jax.local_device_count(), 4])),
          jax.local_devices()),
      'b': jax.device_put_sharded(
          list(jnp.arange(8).reshape([jax.local_device_count(), 2])),
          jax.local_devices(),
      ),
  }

  want = {
      'a': jnp.arange(4),
      'b': jnp.arange(2),
  }

  # Get zeroth device content as DeviceArray.
  device_arrays = utils.get_from_first_device(sharded, as_numpy=False)
  jax.tree_map(
      lambda x: self.assertIsInstance(x, jax.xla.DeviceArray), device_arrays)
  jax.tree_map(np.testing.assert_array_equal, want, device_arrays)

  # Get the zeroth device content as numpy arrays.
  numpy_arrays = utils.get_from_first_device(sharded, as_numpy=True)
  jax.tree_map(lambda x: self.assertIsInstance(x, np.ndarray), numpy_arrays)
  jax.tree_map(np.testing.assert_array_equal, want, numpy_arrays)

def make_initial_state(key):
  """Creates the initial TrainingState, replicating shared params across devices."""
  num_devices = jax.device_count()

  # critic stuff
  # model params
  key, sub_key = jax.random.split(key)
  shared_params, ensemble_params = networks.q_ensemble_init(
      ensemble_size, sub_key)
  # replicated_shared_params = jax.tree_map(
  #     lambda x: jnp.array([x] * num_devices), shared_params)
  replicated_shared_params = jax.device_put_replicated(
      shared_params, jax.local_devices())

  # optim params
  _, shared_params_optim_state, ensemble_params_optim_state = (
      ensemble_utils.build_ensemble_optimizer(
          ensemble_size, shared_params, ensemble_params, optax.adam,
          {'learning_rate': q_lr}))
  # replicated_shared_params_optim_state = jax.tree_map(
  #     lambda x: jnp.array([x] * num_devices), shared_params_optim_state)
  replicated_shared_params_optim_state = jax.device_put_replicated(
      shared_params_optim_state, jax.local_devices())

  # policy stuff
  key, sub_key = jax.random.split(key)
  policy_params = networks.policy_network.init(sub_key)
  policy_optimizer_state = policy_optimizer.init(policy_params)
  # replicated_policy_params = jax.tree_map(
  #     lambda x: jnp.array([x] * num_devices), policy_params)
  # replicated_policy_optimizer_state = jax.tree_map(
  #     lambda x: jnp.array([x] * num_devices), policy_optimizer_state)
  replicated_policy_params = jax.device_put_replicated(
      policy_params, jax.local_devices())
  replicated_policy_optimizer_state = jax.device_put_replicated(
      policy_optimizer_state, jax.local_devices())

  state = TrainingState(
      replicated_policy_optimizer_state=replicated_policy_optimizer_state,
      replicated_shared_q_optim_state=replicated_shared_params_optim_state,
      ensemble_q_optim_state=ensemble_params_optim_state,
      replicated_policy_params=replicated_policy_params,
      replicated_shared_q_params=replicated_shared_params,
      ensemble_q_params=ensemble_params,
      target_replicated_shared_q_params=replicated_shared_params,
      target_ensemble_q_params=ensemble_params,
      key=key,
  )

  # entropy stuff
  if adaptive_entropy_coefficient:
    state = state._replace(
        alpha_optimizer_state=alpha_optimizer_state,
        alpha_params=log_alpha)
  # jax.tree_map(
  #     lambda t: print(t.shape), replicated_shared_params_optim_state)
  return state

def test_remote_transfer(self):
  if jax.device_count() < 2:
    raise unittest.SkipTest("Remote transfer requires at least 2 devices")
  dev_a, dev_b = jax.local_devices()[:2]
  if "libtpu" in jax.local_devices()[0].client.platform_version:
    raise unittest.SkipTest("Test does not yet work on Cloud TPU")
  send_buf = jax.device_put(np.ones((32,)), dev_a)
  shapes = [send_buf.xla_shape()]
  (tag, recv_buf), = dev_b.client.make_cross_host_receive_buffers(
      shapes, dev_b)
  status, dispatched = send_buf.copy_to_remote_device(tag)
  self.assertIsNone(status)
  self.assertTrue(dispatched)
  self.assertArraysEqual(send_buf, recv_buf)

def test_pmap_update_nested(self):
  local_device_count = jax.local_device_count()
  state = running_statistics.init_state({
      'a': specs.Array((5,), jnp.float32),
      'b': specs.Array((2,), jnp.float32)
  })

  x = {
      'a': (jnp.arange(15 * local_device_count,
                       dtype=jnp.float32)).reshape(local_device_count, 3, 5),
      'b': (jnp.arange(6 * local_device_count,
                       dtype=jnp.float32)).reshape(local_device_count, 3, 2),
  }

  devices = jax.local_devices()
  state = jax.device_put_replicated(state, devices)

  pmap_axis_name = 'i'
  state = jax.pmap(
      functools.partial(update_and_validate, pmap_axis_name=pmap_axis_name),
      pmap_axis_name)(state, x)
  state = jax.pmap(
      functools.partial(update_and_validate, pmap_axis_name=pmap_axis_name),
      pmap_axis_name)(state, x)
  normalized = jax.pmap(running_statistics.normalize)(x, state)

  mean = tree.map_structure(lambda x: jnp.mean(x, axis=(0, 1)), normalized)
  std = tree.map_structure(lambda x: jnp.std(x, axis=(0, 1)), normalized)
  tree.map_structure(
      lambda x: self.assert_allclose(x, jnp.zeros_like(x)), mean)
  tree.map_structure(
      lambda x: self.assert_allclose(x, jnp.ones_like(x)), std)

def sample_with_prompt(self, prompt, rng=None):
  """Draws prompt-guided samples from the model.

  # TODO(gandreea): We could handle variable length prompts by assuming the
  #   input prompt to be a list and padding with the out_of_prompt_token.

  Args:
    prompt: Iterable over equal-length sequences to use as input for sampling.
      The prompt is assumed to start with the BOS token.
    rng: A jax.random.PRNGKey object.

  Returns:
    An array of shape (len(prompt), self._length) containing sequences. If
      variable-length, the sequences are right-padded with the EOS token.
  """
  if rng is None:
    self._sample_rng, rng = jax.random.split(self._sample_rng)
  length = self._length + 1
  prompt = common_utils.shard(prompt)
  cache = jax_utils.replicate(
      self._cache_def.initialize_cache((prompt.shape[1], length)))
  samples = self._p_sample_step(
      prompt=prompt,
      model=self._optimizer.target,
      cache=cache,
      rng=jax.random.split(rng, num=len(jax.local_devices())),
  )

  # Remove the BOS token from the sampled sequences.
  samples = samples[:, :, 1:]

  # Undo pmap batching.
  samples = jnp.reshape(samples, [-1, self._length])
  return samples

def test_device_mismatch(self):
  devices = jax.devices()
  if len(devices) < 8:
    raise unittest.SkipTest("Test requires 8 global devices.")
  mesh_devices = np.array([[devices[0], devices[2]],
                           [devices[3], devices[1]],
                           [devices[4], devices[6]],
                           [devices[7], devices[5]]])
  global_mesh = Mesh(mesh_devices, ('x', 'y'))
  global_input_shape = (8, 2)
  mesh_axes = ['x', 'y']
  global_input_data = np.arange(
      prod(global_input_shape)).reshape(global_input_shape)
  indices = get_shard_indices(global_input_shape, global_mesh, mesh_axes)

  dbs = [
      jax.device_put(global_input_data[indices[d]], d)
      for d in jax.local_devices()
  ]

  with self.assertRaisesRegex(
      ValueError,
      'The `global_mesh.local_devices` and `device_buffers` device order'):
    GlobalDeviceArray(global_input_shape, global_mesh, mesh_axes, dbs)

def f(x):
  if n_devices > 1 and fastmath.is_backend(fastmath.Backend.JAX):
    return jax.device_put_replicated(x, jax.local_devices())
  elif n_devices > 1:
    return jnp.broadcast_to(x, (n_devices,) + jnp.asarray(x).shape)
  else:
    return x

def __init__(self, optimizer_def, devices=None, axis_name='batch'):
  super().__init__(optimizer_def.hyper_params)
  if devices is None:
    devices = jax.local_devices()
  self.optimizer_def = optimizer_def
  self.devices = devices
  self.axis_name = axis_name

def create_device_mesh(mesh_shape: Sequence[int],
                       contiguous_submeshes: bool = False) -> np.ndarray:
  """Creates a performant device mesh for jax.experimental.maps.mesh.

  Args:
    mesh_shape: shape of logical mesh, ordered by increasing network-intensity
      e.g. [replica, data, mdl] where mdl has the most network communication
      requirements.
    contiguous_submeshes: if True, this function will attempt to create a mesh
      where each process's local devices form a contiguous submesh. This is
      required when passing non-GlobalDeviceArrays to `pjit` (see the
      "Multi-process platforms" note of the
      [pjit documentation](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html)
      for more information on this constraint). A ValueError will be raised if
      this function can't produce a suitable mesh.

  Returns:
    A np.ndarray of jax global devices with mesh_shape as its shape that can be
    fed into jax.experimental.maps.mesh with good collective performance.
  """
  process_0_devices = jax.local_devices(process_index=0)
  global_devices = jax.devices()
  device_kind = global_devices[-1].device_kind
  return _create_device_mesh(process_0_devices, global_devices, device_kind,
                             mesh_shape, contiguous_submeshes)

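# Usage sketch for create_device_mesh above (not from the original source): the
# returned array of devices is meant to back a Mesh context so that pjit/xmap
# calls inside it pick up the chosen axis layout. The axis names
# ('data', 'model') and the 2-way split are illustrative assumptions and
# require jax.device_count() to be even.
from jax.experimental import maps

device_mesh = create_device_mesh([jax.device_count() // 2, 2])
with maps.Mesh(device_mesh, ('data', 'model')):
  pass  # pjit- or xmap-compiled calls placed here use this mesh.
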
def make_initial_state(key):
  """Creates the initial TrainingState with policy params replicated across devices."""
  # policy stuff
  key, sub_key = jax.random.split(key)
  policy_params = networks.policy_network.init(sub_key)
  policy_optimizer_state = policy_optimizer.init(policy_params)

  devices = jax.local_devices()
  replicated_policy_params = jax.device_put_replicated(
      policy_params, devices)
  replicated_optim_state = jax.device_put_replicated(
      policy_optimizer_state, devices)

  if use_img_encoder:
    """
    Load pretrained img_encoder_params and do:
    replicated_img_encoder_params = jax.device_put_replicated(
        img_encoder_params, devices)
    """

    class EncoderTrainingState(NamedTuple):
      encoder_params: hk.Params

    img_encoder_params = {}
    replicated_img_encoder_params = img_encoder_params
    raise NotImplementedError('Need to load a checkpoint.')
  else:
    img_encoder_params = {}
    replicated_img_encoder_params = img_encoder_params

  state = TrainingState(
      policy_optimizer_state=replicated_optim_state,
      policy_params=replicated_policy_params,
      key=key,
      img_encoder_params=replicated_img_encoder_params)
  return state

def compute_updates_for_dp(state, graph, labels, subgraphs, node_indices,
                           adjacency_normalization):
  """Computes gradients for a single batch for differentially private training."""

  def subgraph_loss(params, graph, node_labels, subgraph_indices):
    """Compute loss over this subgraph at the root node."""
    subgraph = make_subgraph_from_indices(
        graph,
        subgraph_indices,
        add_reverse_edges=False,
        adjacency_normalization=adjacency_normalization)
    subgraph_preds = state.apply_fn(params, subgraph).nodes
    node_preds = subgraph_preds[0, :]
    return compute_loss(node_preds, node_labels)

  # Reshape leading axes for multiple devices.
  node_labels = reshape_before_pmap(labels[node_indices])
  subgraph_indices = reshape_before_pmap(subgraphs[node_indices])

  # Compute per-example gradients.
  per_example_gradient_fn = jax.vmap(
      jax.grad(subgraph_loss), in_axes=(None, None, 0, 0))
  per_example_gradient_fn = jax.pmap(
      per_example_gradient_fn,
      axis_name='devices',
      in_axes=(None, None, 0, 0),
      devices=jax.local_devices())
  grads = per_example_gradient_fn(state.params, graph, node_labels,
                                  subgraph_indices)

  # Undo reshape.
  grads = jax.tree_map(reshape_after_pmap, grads)

  # Normalize gradients by batch size.
  return jax.tree_map(lambda grad: grad / grad.shape[0], grads)

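# Minimal sketch of the per-example gradient pattern used above (not from the
# original source): jax.vmap over jax.grad yields one gradient per example, and
# in the function above jax.pmap additionally spreads the vmapped batch over
# local devices. The toy quadratic loss below is an illustrative assumption.
toy_loss = lambda w, x: jnp.sum((w * x) ** 2)
per_example_grad_fn = jax.vmap(jax.grad(toy_loss), in_axes=(None, 0))
w = jnp.ones((3,))
xs = jnp.arange(12.0).reshape(4, 3)  # 4 examples of dimension 3
grads = per_example_grad_fn(w, xs)   # shape (4, 3): one gradient row per example
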
def _device_to_device_funcs():
  """Generates device-to-device transfer functions."""
  if len(jax.local_devices()) < 2:
    # device-to-device tests require at least 2 devices.
    return []

  with jax.transfer_guard_host_to_device("allow"):
    device_arrays = [jnp.ones(1) for _ in range(2)]
  return [
      # (function name, is an explicit transfer?, function)
      ("device_to_device_jax_device_put", True,
       lambda: jax.device_put(device_arrays[0],
                              device=jax.local_devices()[1])),
      ("device_to_device_jax_jit", False,
       lambda: jax.jit(lambda x: x, device=jax.local_devices()[1])
       (device_arrays[1])),
  ]

def create_device_mesh(mesh_shape: Sequence[int]) -> np.ndarray:
  """Creates a performant device mesh for jax.experimental.maps.mesh.

  Args:
    mesh_shape: shape of logical mesh, ordered by increasing network-intensity
      e.g. [replica, data, mdl] where mdl has the most network communication
      requirements.

  Returns:
    A np.ndarray of jax devices with mesh_shape as its shape that can be fed
    into jax.experimental.maps.mesh with good collective performance.
  """
  local_jax_devices_from_process_0 = jax.local_devices(process_index=0)
  jax_devices = jax.devices()
  device_kind = jax_devices[-1].device_kind
  # TODO(zhangqiaorjc): Handle TPU versions other than v4 more generally.
  if device_kind == _TPU_V3:
    device_mesh = np.asarray(jax_devices).reshape(mesh_shape)
    if mesh_shape[-1] == 8:
      logging.info('Re-order TPUv3 device mesh for better performance.')
      perm = np.array([0, 1, 2, 3, 6, 7, 4, 5])
      device_mesh = device_mesh[:, :, perm]
    return device_mesh
  elif device_kind == _TPU_V4:
    physical_mesh = _jax_devices_order_normalized(
        local_jax_devices_from_process_0, jax_devices)
    device_mesh, assignment = _create_device_mesh_for_tpu_v4(
        physical_mesh, mesh_shape)
    logging.info('_create_device_mesh_for_tpu_v4 assignment: %s', assignment)
    return device_mesh
  else:
    device_mesh = np.asarray(jax_devices).reshape(mesh_shape)
    return device_mesh

def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')

  logging.info('JAX process: %d / %d', jax.process_index(),
               jax.process_count())
  logging.info('JAX local devices: %r', jax.local_devices())

  # Add a note so that we can tell which task is which JAX host.
  # (Depending on the platform task 0 is not guaranteed to be host 0)
  platform.work_unit().set_task_status(
      f'process_index: {jax.process_index()}, '
      f'process_count: {jax.process_count()}')
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       FLAGS.workdir, 'workdir')

  if FLAGS.mode == 'train':
    train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
  else:
    predict.predict_and_evaluate(FLAGS.config, FLAGS.workdir, FLAGS.ckpt_path)

def make_hmc_update_eval_fns(net, train_set, test_set, likelihood_fn,
                             prior_fn):
  """Make update and eval functions for HMC training."""
  n_devices = len(jax.local_devices())

  def log_prob_and_grad_fn(params):
    params_p = jax.pmap(lambda _: params)(jnp.arange(n_devices))
    log_prob, _, grad = nn_loss.pmap_get_loss_acc_grad(
        net, params_p, likelihood_fn, prior_fn, train_set)
    return -log_prob[0], jax.tree_map(lambda g: -g[0], grad)

  def log_prob_and_acc(params, dataset):
    params_p = jax.pmap(lambda _: params)(jnp.arange(n_devices))
    log_prob, acc = nn_loss.pmap_get_loss_and_acc(
        net, params_p, likelihood_fn, prior_fn, dataset)
    return -log_prob[0], acc[0]

  hmc_update = hmc.make_adaptive_hmc_update(log_prob_and_grad_fn)

  def update(params, log_prob, state_grad, key, step_size, trajectory_len):
    params, log_prob, state_grad, step_size, accept_prob = hmc_update(
        params, log_prob, state_grad, key, step_size, trajectory_len)
    key, = jax.random.split(key, 1)
    return params, log_prob, state_grad, step_size, key, accept_prob

  def evaluate(params):
    test_log_prob, test_acc = log_prob_and_acc(params, test_set)
    train_log_prob, train_acc = log_prob_and_acc(params, train_set)
    return test_log_prob, test_acc, train_log_prob, train_acc

  return update, evaluate, log_prob_and_grad_fn

def _multi_device_put(x, devices=None):
  """Memory efficient multi-device replication / broadcast in JAX.

  JAX uses a ShardedDeviceArray class that holds a list of device buffers
  on separate devices for use with pmap'd computations.  Sharded arrays
  are explicitly used to eliminate unnecessary inter-device transfer of
  memory buffers between use in pmap'd computations.  The JAX API currently
  does not have a multi-device 'put' function that copies a buffer onto N
  devices in a memory-efficient fashion, so we implement our own here.

  Args:
    x: jax DeviceArray or numpy ndarray to be replicated.
    devices: a jax.devices() list or subset thereof of devices to
      replicate onto.  Should match the list passed to any pmaps
      ingesting the replicated array.

  Returns:
    A ShardedDeviceArray with
    dtype = x.dtype and shape = (n_devices,) + x.shape
    that's backed by replicated device_buffers on each local device.
  """
  # Convert _FilledConstants that don't have device_buffer, etc.
  if type(x) != jax.xla.DeviceArray:  # pylint: disable=unidiomatic-typecheck
    x = np.array(x)
  # Calculate the abstract shape of the replicated array.
  if not devices:
    devices = jax.local_devices()
  n_devices = len(devices)
  x_aval = jax.xla.abstractify(x)
  broadcast_x_aval = jax.abstract_arrays.ShapedArray(
      (n_devices,) + x_aval.shape, x_aval.dtype)
  # Create copies of the underlying device buffer for each local device.
  broadcast_buffers = [
      jax.device_put(x, dv).device_buffer for dv in devices
  ]
  return jax.pxla.ShardedDeviceArray(broadcast_x_aval, broadcast_buffers)

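# Usage sketch for _multi_device_put above (not from the original source): the
# replicated result carries a leading device axis, so a pmapped function can
# consume it directly without re-transferring the buffer each step. The
# parameter array and the doubling step are illustrative assumptions.
params = np.ones((3, 3), np.float32)
replicated_params = _multi_device_put(params)  # shape (n_devices, 3, 3)
doubled = jax.pmap(lambda p: p * 2.0)(replicated_params)
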
def testXMapMeshCollectives(self):
  local_devices = list(jax.local_devices())
  if len(local_devices) < 4:
    raise SkipTest("Test requires at least 4 local devices")

  def f(a, b):
    return lax.psum(a * 2, 'a'), b * 4

  devices = np.array(local_devices[:4]).reshape((2, 2))
  with mesh(devices, ('x', 'y')):
    fm = xmap(f,
              in_axes=[A({'a': 0, 'b': 1}), A({'c': 0})],
              out_axes=[A({'b': 0}), A({'c': 0})],
              schedule=[
                  ('a', 'x'),
                  ('b', 'y'),
                  ('c', 'x'),
                  ('a', 'vectorize'),
                  ('b', 'vectorize'),
              ])
    ashape = (16, 8, 5)
    a = jnp.arange(np.prod(ashape)).reshape(ashape)
    bshape = (2, 7)
    b = jnp.arange(np.prod(bshape)).reshape(bshape)
    c, d = fm(a, b)
    self.assertAllClose(c, (a * 2).sum(0))
    self.assertAllClose(d, b * 4)

def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  FLAGS.log_dir = FLAGS.workdir
  FLAGS.stderrthreshold = 'info'
  logging.get_absl_handler().start_logging_to_file()

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')

  logging.info('JAX host: %d / %d', jax.host_id(), jax.host_count())
  logging.info('JAX local devices: %r', jax.local_devices())

  # Add a note so that we can tell which task is which JAX host.
  # (Depending on the platform task 0 is not guaranteed to be host 0)
  platform.work_unit().set_task_status(
      f'host_id: {jax.host_id()}, host_count: {jax.host_count()}')
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       FLAGS.workdir, 'workdir')

  if FLAGS.sample:
    sample.save_images(
        sample.generate_sample(FLAGS.config, FLAGS.workdir), 'sample.png')
  else:
    train.train_and_evaluate(FLAGS.config, FLAGS.workdir)

def testInfeedThenOutfeedInALoop(self):
  hcb.stop_outfeed_receiver()

  def doubler(_, token):
    y, token = lax.infeed(
        token, shape=jax.ShapedArray((3, 4), jnp.float32))
    return lax.outfeed(token, y * np.float32(2))

  @jax.jit
  def f(n):
    token = lax.create_token(n)
    token = lax.fori_loop(0, n, doubler, token)
    return n

  device = jax.local_devices()[0]
  n = 10
  execution = threading.Thread(target=lambda: f(n))
  execution.start()
  for _ in range(n):
    x = np.random.randn(3, 4).astype(np.float32)
    device.transfer_to_infeed((x,))
    y, = device.transfer_from_outfeed(
        xla_client.shape_from_pyval((x,))
        .with_major_to_minor_layout_if_absent())
    self.assertAllClose(y, x * np.float32(2))
  execution.join()

def test_complex_dtype(self):
  if jax.local_devices()[0].platform == "tpu":
    self.skipTest("Complex dtype not supported by TPU")

  # This just makes sure we can call the initializers in accordance with the
  # API and get the right shapes and dtypes out.
  inits = [
      initializers.Constant(42. + 1j * 1729.),
      initializers.RandomNormal(),
      initializers.RandomNormal(2.0),
      initializers.RandomNormal(2. - 3j),
      initializers.TruncatedNormal(),
      initializers.TruncatedNormal(2.),
      initializers.TruncatedNormal(2., 1. - 1j),
      # Users are supposed to be able to use these.
      jnp.zeros,
      jnp.ones,
  ]

  shape = (5, 13, 17)
  dtype = jnp.complex64
  for init in inits:
    generated = init(shape, dtype)
    self.assertEqual(generated.shape, shape)
    self.assertEqual(generated.dtype, dtype)

def testOutfeed(self):
  devices = np.array(jax.local_devices())
  nr_devices = len(devices)
  shape = (nr_devices * 3, nr_devices * 5)

  def f(x):
    token = lax.create_token(x)
    token = lax.outfeed(token, x, partitions=(None,))
    token = lax.outfeed(token, x, partitions=(P(nr_devices, 1),))
    token = lax.outfeed(token, x, partitions=(P(1, nr_devices),))
    return x

  x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)

  def dispatch():
    with mesh(devices, ['d']):
      logging.info('Making pjit call')
      pjit(f, in_axis_resources=(P('d'),), out_axis_resources=P('d'))(x)

  execution = threading.Thread(target=dispatch)
  execution.start()

  def check_outfeed(d, x):
    y, = d.transfer_from_outfeed(
        xla_client.shape_from_pyval((x,))
        .with_major_to_minor_layout_if_absent())
    self.assertAllClose(x, y, check_dtypes=True)

  logging.info('Transferring from outfeed for the pjit call')
  for didx, d in enumerate(devices):
    # Transfer the whole array from all devices for replicated.
    check_outfeed(d, x)
    # For sharded outfeed, the results are sliced.
    check_outfeed(d, x[3 * didx:3 * didx + 3, :])
    check_outfeed(d, x[:, 5 * didx:5 * didx + 5])

  execution.join()

def _replicate(x, devices=None):
  x = jax.numpy.array(x)
  if devices is None:
    devices = jax.local_devices()
  aval = jax.ShapedArray((len(devices),) + x.shape, x.dtype)
  buffers = [
      jax.interpreters.xla.device_put(x, device=d) for d in devices
  ]
  return jax.pxla.ShardedDeviceArray(aval, buffers)

def double_buffer(ds: Iterable[T]) -> Generator[T, None, None]:
  """Keeps at least two batches on the accelerator.

  The current GPU allocator design reuses previous allocations. For a training
  loop this means batches will (typically) occupy the same region of memory as
  the previous batch. An issue with this is that it means we cannot overlap a
  host->device copy for the next batch until the previous step has finished and
  the previous batch has been freed.

  By double buffering we ensure that there are always two batches on the
  device. This means that a given batch waits on the N-2'th step to finish and
  free, meaning that it can allocate and copy the next batch to the accelerator
  in parallel with the N-1'th step being executed.

  Args:
    ds: Iterable of batches of numpy arrays.

  Yields:
    Batches of sharded device arrays.
  """
  batch = None
  devices = jax.local_devices()
  for next_batch in ds:
    assert next_batch is not None
    next_batch = _device_put_sharded(next_batch, devices)
    if batch is not None:
      yield batch
    batch = next_batch
  if batch is not None:
    yield batch

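# Usage sketch for double_buffer above (not from the original source): wrap a
# host-side iterator of numpy batches so that while step N runs, batch N+1 is
# already being staged on the devices. The batch shape and the pmapped step
# are illustrative assumptions; each batch's leading axis must equal the
# number of local devices.
n_dev = jax.local_device_count()
numpy_batches = (np.ones((n_dev, 8, 4), np.float32) for _ in range(10))
step = jax.pmap(lambda x: x.mean())  # stand-in for a real pmapped update
for batch in double_buffer(numpy_batches):
  out = step(batch)
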
def inference_speed_memory(self, batch_size, seq_length):
  # input_ids = np.random.randint(0, self.vocab_size, (batch_size, seq_length))
  key = jax.random.PRNGKey(0)
  input_ids = jax.random.randint(
      key, (batch_size, seq_length), 0, self.vocab_size)

  @jax.jit
  def ref_step():
    out = self.model(input_ids=input_ids)
    return out[0]

  if jax.local_devices()[0].platform == 'gpu':
    nvml.nvmlInit()
    ref_step().block_until_ready()
    handle = nvml.nvmlDeviceGetHandleByIndex(0)
    meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
    max_bytes_in_use = meminfo.used
    memory = Memory(max_bytes_in_use)
    # Shutdown nvml.
    nvml.nvmlShutdown()
  else:
    memory = None

  # Warmup run before timing.
  timeit.repeat(
      "ref_step().block_until_ready()", repeat=1, number=2, globals=locals())

  if self.jit:
    runtimes = timeit.repeat(
        "ref_step().block_until_ready()",
        repeat=self.repeat,
        number=3,
        globals=locals())
  else:
    with jax.disable_jit():
      runtimes = timeit.repeat(
          "ref_step().block_until_ready()",
          repeat=self.repeat,
          number=3,
          globals=locals())

  return float(np.min(runtimes) / 3.0), memory

def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  utils.add_gfile_logger(_WORKDIR.value)

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')
  jax.config.update('jax_log_compiles', True)

  logging.info('JAX process: %d / %d', jax.process_index(),
               jax.process_count())
  logging.info('JAX local devices: %r', jax.local_devices())
  jax_xla_backend = ('None' if FLAGS.jax_xla_backend is None else
                     FLAGS.jax_xla_backend)
  logging.info('Using JAX XLA backend %s', jax_xla_backend)

  logging.info('Config: %s', FLAGS.config)

  # Add a note so that we can tell which task is which JAX host.
  # (Depending on the platform task 0 is not guaranteed to be host 0)
  platform.work_unit().set_task_status(
      f'process_index: {jax.process_index()}, '
      f'process_count: {jax.process_count()}')
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       _WORKDIR.value, 'workdir')

  if FLAGS.config.trainer == 'train':
    train.train_and_evaluate(FLAGS.config, _WORKDIR.value)
  elif FLAGS.config.trainer == 'inference_time':
    inference_time.inference_time(FLAGS.config, _WORKDIR.value)
  else:
    raise app.UsageError(f'Unknown trainer: {FLAGS.config.trainer}')

def _multi_device_put(x, devices=None):
  """Memory efficient multi-device replication / broadcast in JAX.

  JAX uses a ShardedDeviceArray class that holds a list of device buffers
  on separate devices for use with pmap'd computations.  Sharded arrays
  are explicitly used to eliminate unnecessary inter-device transfer of
  memory buffers between use in pmap'd computations.  The JAX API currently
  does not have a multi-device 'put' function that copies a buffer onto N
  devices in a memory-efficient fashion, so we implement our own here.

  Args:
    x: jax DeviceArray or numpy ndarray to be replicated.
    devices: a jax.devices() list or subset thereof of devices to
      replicate onto.  Should match the list passed to any pmaps
      ingesting the replicated array.

  Returns:
    A ShardedDeviceArray with
    dtype = x.dtype and shape = (n_devices,) + x.shape
    that's backed by replicated device_buffers on each local device.
  """
  # Calculate the abstract shape of the replicated array.
  if not devices:
    devices = jax.local_devices()
  # The code below is equivalent to:
  #   jax.api.device_put_sharded(len(devices) * [x], devices)
  # but it does one PCI transfer and later uses ICI.
  # TODO(lukaszkaiser): remove once JAX has a core function to do the same.
  aval = jax.core.unmapped_aval(
      len(devices), 0, jax.core.raise_to_shaped(jax.core.get_aval(x)))
  buf, = jax.xla.device_put(x, devices[0])  # assuming single-buf repr
  rest_bufs = [buf.copy_to_device(d) for d in devices[1:]]
  return jax.pxla.ShardedDeviceArray(aval, [buf, *rest_bufs])

def testBasic(self):
  local_devices = list(jax.local_devices())
  if len(local_devices) < 4:
    raise SkipTest("Test requires at least 4 local devices")

  def f(a, b):
    return a * 2, b * 4

  devices = np.array(local_devices[:4]).reshape((2, 2))
  with mesh(devices, ('x', 'y')):
    fm = xmap(f,
              in_axes=[{0: 'a', 1: 'b'}, ['c', ...]],
              out_axes=[{0: 'a', 1: 'b'}, ['c', ...]],
              axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
    ashape = (16, 8, 5)
    a = jnp.arange(np.prod(ashape)).reshape(ashape)
    bshape = (2, 7)
    b = jnp.arange(np.prod(bshape)).reshape(bshape)
    c, d = fm(a, b)
    self.assertAllClose(c, a * 2)
    self.assertAllClose(d, b * 4)

def make_cpu_tensor(shape, dtype=float):
  import jax
  from jax import numpy as jnp
  tiny = jnp.zeros((), dtype=dtype)
  tiny_cpu = jax.device_put(tiny, jax.local_devices(backend='cpu')[0])
  big_cpu = jnp.tile(tiny_cpu, shape)
  return big_cpu

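# Usage sketch for make_cpu_tensor above (not from the original source):
# because the scalar is committed to a CPU device before jnp.tile, the tiled
# result is allocated in host RAM rather than accelerator memory. The shape
# below is an illustrative assumption.
big = make_cpu_tensor((4096, 4096))
assert big.shape == (4096, 4096)  # resides on the CPU backend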