Example #1
  def testInfeed(self, partition_input):
    if jax.local_device_count() % 2 != 0:
      raise SkipTest("Test requires an even number of local devices.")

    shape = (jax.local_device_count() * 2, 4)
    # Run computation across all devices so we know which devices to feed.
    parts = P(jax.local_device_count(), 1)
    in_parts = parts if partition_input else None
    infeed_shapes = (jax.ShapedArray(shape, np.float32),
                     jax.ShapedArray((1,), np.float32))
    infeed_parts = (parts, None)

    @partial(sharded_jit, in_parts=in_parts, out_parts=None)
    def f(x):
      token = lax.create_token(x)
      (y, z), token = lax.infeed(token, infeed_shapes, partitions=infeed_parts)
      return x @ y.T + z[jnp.newaxis]

    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    y = x + 1
    shard_size = shape[0] // jax.local_device_count()
    y_shards = [y[i:i+shard_size] for i in range(0, shape[0], shard_size)]
    z = jnp.array([3.], dtype=np.float32)

    assert len(jax.local_devices()) == len(y_shards)
    # Transfer data to infeed before executing the function. For GPUs, the
    # execution of the compiled function is blocking, so transferring data
    # to infeed before executing ensures that the execution does not deadlock
    # waiting for the infeed data.
    for device, y_shard in zip(jax.local_devices(), y_shards):
      device.transfer_to_infeed((y_shard, z))
    result = f(x)

    expected = x @ y.T + z[jnp.newaxis]
    self.assertAllClose(result, expected, check_dtypes=False)
Example #2
    def testInfeed(self, partition_input):
        if jax.local_device_count() % 2 != 0:
            raise SkipTest("Test requires an even number of local devices.")

        shape = (jax.local_device_count() * 2, 4)
        # Run computation across all devices so we know which devices to feed.
        parts = P(jax.local_device_count(), 1)
        in_parts = parts if partition_input else None
        infeed_shapes = (jax.ShapedArray(shape, np.float32),
                         jax.ShapedArray((1, ), np.float32))
        infeed_parts = (parts, None)

        @partial(sharded_jit, in_parts=in_parts, out_parts=None)
        def f(x):
            token = lax.create_token(x)
            (y, z), token = lax.infeed(token,
                                       infeed_shapes,
                                       partitions=infeed_parts)
            return x @ y.T + z

        x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
        y = x + 1
        shard_size = shape[0] // jax.local_device_count()
        y_shards = [
            y[i:i + shard_size] for i in range(0, shape[0], shard_size)
        ]
        z = jnp.array([3.], dtype=np.float32)

        result = f(x)
        assert len(jax.local_devices()) == len(y_shards)
        for device, y_shard in zip(jax.local_devices(), y_shards):
            device.transfer_to_infeed((y_shard, z))

        expected = x @ y.T + z
        self.assertAllClose(result, expected, check_dtypes=False)
Example #3
    def test_get_from_first_device(self):
        sharded = {
            'a':
            jax.device_put_sharded(
                list(jnp.arange(16).reshape([jax.local_device_count(), 4])),
                jax.local_devices()),
            'b':
            jax.device_put_sharded(
                list(jnp.arange(8).reshape([jax.local_device_count(), 2])),
                jax.local_devices(),
            ),
        }

        want = {
            'a': jnp.arange(4),
            'b': jnp.arange(2),
        }

        # Get zeroth device content as DeviceArray.
        device_arrays = utils.get_from_first_device(sharded, as_numpy=False)
        jax.tree_map(lambda x: self.assertIsInstance(x, jax.xla.DeviceArray),
                     device_arrays)
        jax.tree_map(np.testing.assert_array_equal, want, device_arrays)

        # Get the zeroth device content as numpy arrays.
        numpy_arrays = utils.get_from_first_device(sharded, as_numpy=True)
        jax.tree_map(lambda x: self.assertIsInstance(x, np.ndarray),
                     numpy_arrays)
        jax.tree_map(np.testing.assert_array_equal, want, numpy_arrays)
Example #4
        def make_initial_state(key):
            """Creates the initial training state, replicating shared params across devices."""
            num_devices = jax.device_count()
            # critic stuff
            # model params
            key, sub_key = jax.random.split(key)
            shared_params, ensemble_params = networks.q_ensemble_init(
                ensemble_size, sub_key)
            # replicated_shared_params = jax.tree_map(
            #     lambda x: jnp.array([x] * num_devices), shared_params)
            replicated_shared_params = jax.device_put_replicated(
                shared_params, jax.local_devices())

            # optim params
            _, shared_params_optim_state, ensemble_params_optim_state = ensemble_utils.build_ensemble_optimizer(
                ensemble_size, shared_params, ensemble_params, optax.adam,
                {'learning_rate': q_lr})
            # replicated_shared_params_optim_state = jax.tree_map(
            #     lambda x: jnp.array([x] * num_devices), shared_params_optim_state)
            replicated_shared_params_optim_state = jax.device_put_replicated(
                shared_params_optim_state, jax.local_devices())

            # policy stuff
            key, sub_key = jax.random.split(key)
            policy_params = networks.policy_network.init(sub_key)
            policy_optimizer_state = policy_optimizer.init(policy_params)

            # replicated_policy_params = jax.tree_map(
            #     lambda x: jnp.array([x] * num_devices), policy_params)
            # replicated_policy_optimizer_state = jax.tree_map(
            #     lambda x: jnp.array([x] * num_devices), policy_optimizer_state)
            replicated_policy_params = jax.device_put_replicated(
                policy_params, jax.local_devices())
            replicated_policy_optimizer_state = jax.device_put_replicated(
                policy_optimizer_state, jax.local_devices())

            state = TrainingState(
                replicated_policy_optimizer_state=
                replicated_policy_optimizer_state,
                replicated_shared_q_optim_state=
                replicated_shared_params_optim_state,
                ensemble_q_optim_state=ensemble_params_optim_state,
                replicated_policy_params=replicated_policy_params,
                replicated_shared_q_params=replicated_shared_params,
                ensemble_q_params=ensemble_params,
                target_replicated_shared_q_params=replicated_shared_params,
                target_ensemble_q_params=ensemble_params,
                key=key,
            )

            # entropy stuff
            if adaptive_entropy_coefficient:
                state = state._replace(
                    alpha_optimizer_state=alpha_optimizer_state,
                    alpha_params=log_alpha)

            # jax.tree_map(lambda t: print(t.shape), replicated_shared_params_optim_state)

            return state
Example #5
 def test_remote_transfer(self):
     if jax.device_count() < 2:
         raise unittest.SkipTest(
             "Remote transfer requires at least 2 devices")
     dev_a, dev_b = jax.local_devices()[:2]
     if "libtpu" in jax.local_devices()[0].client.platform_version:
         raise unittest.SkipTest("Test does not yet work on cloud TPU")
     send_buf = jax.device_put(np.ones((32, )), dev_a)
     shapes = [send_buf.xla_shape()]
     (tag, recv_buf), = dev_b.client.make_cross_host_receive_buffers(
         shapes, dev_b)
     status, dispatched = send_buf.copy_to_remote_device(tag)
     self.assertIsNone(status)
     self.assertTrue(dispatched)
     self.assertArraysEqual(send_buf, recv_buf)
Example #6
  def test_pmap_update_nested(self):
    local_device_count = jax.local_device_count()
    state = running_statistics.init_state({
        'a': specs.Array((5,), jnp.float32),
        'b': specs.Array((2,), jnp.float32)
    })

    x = {
        'a': (jnp.arange(15 * local_device_count,
                         dtype=jnp.float32)).reshape(local_device_count, 3, 5),
        'b': (jnp.arange(6 * local_device_count,
                         dtype=jnp.float32)).reshape(local_device_count, 3, 2),
    }

    devices = jax.local_devices()
    state = jax.device_put_replicated(state, devices)
    pmap_axis_name = 'i'
    state = jax.pmap(
        functools.partial(update_and_validate, pmap_axis_name=pmap_axis_name),
        pmap_axis_name)(state, x)
    state = jax.pmap(
        functools.partial(update_and_validate, pmap_axis_name=pmap_axis_name),
        pmap_axis_name)(state, x)
    normalized = jax.pmap(running_statistics.normalize)(x, state)

    mean = tree.map_structure(lambda x: jnp.mean(x, axis=(0, 1)), normalized)
    std = tree.map_structure(lambda x: jnp.std(x, axis=(0, 1)), normalized)
    tree.map_structure(
        lambda x: self.assert_allclose(x, jnp.zeros_like(x)), mean)
    tree.map_structure(
        lambda x: self.assert_allclose(x, jnp.ones_like(x)), std)
Example #7
    def sample_with_prompt(self, prompt, rng=None):
        """Draws prompt-guided samples from the model.

    # TODO(gandreea): We could handle variable length prompts by assuming the
    #   input prompt to be a list and padding with the out_of_prompt_token.

    Args:
      prompt: Iterable over equal-length sequences to use as input for sampling.
        The prompt is assumed to start with the BOS token.
      rng: A jax.random.PRNGKey object.

    Returns:
      An array of shape (len(prompt), self._length) containing sequences. If
        variable-length, the sequences are right-padded with the EOS token.
    """
        if rng is None:
            self._sample_rng, rng = jax.random.split(self._sample_rng)
        length = self._length + 1
        prompt = common_utils.shard(prompt)
        cache = jax_utils.replicate(
            self._cache_def.initialize_cache((prompt.shape[1], length)))
        samples = self._p_sample_step(
            prompt=prompt,
            model=self._optimizer.target,
            cache=cache,
            rng=jax.random.split(rng, num=len(jax.local_devices())),
        )

        # Remove the BOS token from the sampled sequences.
        samples = samples[:, :, 1:]

        # Undo pmap batching
        samples = jnp.reshape(samples, [-1, self._length])
        return samples
Example #8
    def test_device_mismatch(self):
        devices = jax.devices()
        if len(devices) < 8:
            raise unittest.SkipTest("Test requires 8 global devices.")
        mesh_devices = np.array([[devices[0], devices[2]],
                                 [devices[3], devices[1]],
                                 [devices[4], devices[6]],
                                 [devices[7], devices[5]]])
        global_mesh = Mesh(mesh_devices, ('x', 'y'))
        global_input_shape = (8, 2)
        mesh_axes = ['x', 'y']
        global_input_data = np.arange(
            prod(global_input_shape)).reshape(global_input_shape)
        indices = get_shard_indices(global_input_shape, global_mesh, mesh_axes)

        dbs = [
            jax.device_put(global_input_data[indices[d]], d)
            for d in jax.local_devices()
        ]

        with self.assertRaisesRegex(
                ValueError,
                'The `global_mesh.local_devices` and `device_buffers` device order'
        ):
            GlobalDeviceArray(global_input_shape, global_mesh, mesh_axes, dbs)
Example #9
 def f(x):
   if n_devices > 1 and fastmath.is_backend(fastmath.Backend.JAX):
     return jax.device_put_replicated(x, jax.local_devices())
   elif n_devices > 1:
     return jnp.broadcast_to(x, (n_devices,) + jnp.asarray(x).shape)
   else:
     return x
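A small usage sketch, not from the original source, showing the shape produced by the replication branch above; here jax.device_put_replicated is called directly and `params` is a hypothetical pytree:

import jax
import jax.numpy as jnp

params = {'w': jnp.ones((3, 4)), 'b': jnp.zeros((4,))}
# Each leaf gains a leading device axis of length jax.local_device_count().
replicated = jax.device_put_replicated(params, jax.local_devices())
assert replicated['w'].shape == (jax.local_device_count(), 3, 4)
assert replicated['b'].shape == (jax.local_device_count(), 4)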
Example #10
 def __init__(self, optimizer_def, devices=None, axis_name='batch'):
   super().__init__(optimizer_def.hyper_params)
   if devices is None:
     devices = jax.local_devices()
   self.optimizer_def = optimizer_def
   self.devices = devices
   self.axis_name = axis_name
Example #11
def create_device_mesh(mesh_shape: Sequence[int],
                       contiguous_submeshes: bool = False) -> np.ndarray:
  """Creates a performant device mesh for jax.experimental.maps.mesh.

  Args:
    mesh_shape: shape of logical mesh, ordered by increasing network-intensity
      e.g. [replica, data, mdl] where mdl has the most network communication
      requirements.
    contiguous_submeshes: if True, this function will attempt to create a mesh
      where each process's local devices form a contiguous submesh. This is
      required when passing non-GlobalDeviceArrays to `pjit` (see the
      "Multi-process platforms" note of the [pjit
      documentation](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html)
      for more information on this constraint). A ValueError will be raised if
      this function can't produce a suitable mesh.

  Returns:
    A np.ndarray of jax global devices with mesh_shape as its shape that can be
    fed into jax.experimental.maps.mesh with good collective performance.

  """
  process_0_devices = jax.local_devices(process_index=0)
  global_devices = jax.devices()
  device_kind = global_devices[-1].device_kind
  return _create_device_mesh(process_0_devices, global_devices, device_kind,
                             mesh_shape, contiguous_submeshes)
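A hedged usage sketch, not from the original source: the returned device array is meant to be wrapped in a mesh context, as described in the docstring above. The sketch uses the public jax.experimental.mesh_utils and maps APIs, whose exact module locations vary across JAX versions:

import jax
from jax.experimental import mesh_utils
from jax.experimental.maps import Mesh  # jax.sharding.Mesh on newer versions

# Logical mesh ordered by increasing network intensity, e.g. (replica, data).
devices = mesh_utils.create_device_mesh((1, jax.device_count()))
with Mesh(devices, ('replica', 'data')):
    pass  # pjit/xmap calls issued here resolve named axis resources against this mesh.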
Example #12
    def make_initial_state(key):
      """Creates the initial training state with replicated policy parameters."""
      # policy stuff
      key, sub_key = jax.random.split(key)
      policy_params = networks.policy_network.init(sub_key)
      policy_optimizer_state = policy_optimizer.init(policy_params)

      devices = jax.local_devices()
      replicated_policy_params = jax.device_put_replicated(
          policy_params, devices)
      replicated_optim_state = jax.device_put_replicated(
          policy_optimizer_state, devices)

      if use_img_encoder:
        """
        Load pretrained img_encoder_params and do:
        replicated_img_encoder_params = jax.device_put_replicated(
            img_encoder_params, devices)
        """
        class EncoderTrainingState(NamedTuple):
          encoder_params: hk.Params
        img_encoder_params = {}
        replicated_img_encoder_params = img_encoder_params
        raise NotImplementedError('Need to load a checkpoint.')
      else:
        img_encoder_params = {}
        replicated_img_encoder_params = img_encoder_params

      state = TrainingState(
          policy_optimizer_state=replicated_optim_state,
          policy_params=replicated_policy_params,
          key=key,
          img_encoder_params=replicated_img_encoder_params)
      return state
Example #13
def compute_updates_for_dp(state, graph, labels, subgraphs, node_indices,
                           adjacency_normalization):
    """Computes gradients for a single batch for differentially private training."""
    def subgraph_loss(params, graph, node_labels, subgraph_indices):
        """Compute loss over this subgraph at the root node."""
        subgraph = make_subgraph_from_indices(
            graph,
            subgraph_indices,
            add_reverse_edges=False,
            adjacency_normalization=adjacency_normalization)
        subgraph_preds = state.apply_fn(params, subgraph).nodes
        node_preds = subgraph_preds[0, :]
        return compute_loss(node_preds, node_labels)

    # Reshape leading axes for multiple devices.
    node_labels = reshape_before_pmap(labels[node_indices])
    subgraph_indices = reshape_before_pmap(subgraphs[node_indices])

    # Compute per-example gradients.
    per_example_gradient_fn = jax.vmap(jax.grad(subgraph_loss),
                                       in_axes=(None, None, 0, 0))
    per_example_gradient_fn = jax.pmap(per_example_gradient_fn,
                                       axis_name='devices',
                                       in_axes=(None, None, 0, 0),
                                       devices=jax.local_devices())
    grads = per_example_gradient_fn(state.params, graph, node_labels,
                                    subgraph_indices)

    # Undo reshape.
    grads = jax.tree_map(reshape_after_pmap, grads)

    # Normalize gradients by batch size.
    return jax.tree_map(lambda grad: grad / grad.shape[0], grads)
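A self-contained toy sketch, not from the original source, of the per-example gradient pattern used above (vmap of grad over examples, wrapped in pmap over local devices); the quadratic loss and parameter pytree are hypothetical stand-ins:

import jax
import jax.numpy as jnp

def example_loss(params, x, y):
    # Loss for a single example.
    pred = x @ params['w'] + params['b']
    return jnp.mean((pred - y) ** 2)

# Gradients w.r.t. params, vectorized over examples, then mapped over devices.
per_example_grad_fn = jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0))
per_example_grad_fn = jax.pmap(per_example_grad_fn,
                               axis_name='devices',
                               in_axes=(None, 0, 0),
                               devices=jax.local_devices())

n_dev = jax.local_device_count()
params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}
x = jnp.ones((n_dev, 4, 3))    # [devices, per-device batch, features]
y = jnp.zeros((n_dev, 4, 2))   # [devices, per-device batch, targets]
grads = per_example_grad_fn(params, x, y)
# Each gradient leaf has shape [devices, per-device batch, *param shape].
assert grads['w'].shape == (n_dev, 4, 3, 2)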
Example #14
def _device_to_device_funcs():
  """Generates device-to-device transfer functions."""
  if len(jax.local_devices()) < 2:
    # device-to-device tests require at least 2 devices.
    return []

  with jax.transfer_guard_host_to_device("allow"):
    device_arrays = [jnp.ones(1) for _ in range(2)]
  return [
      # (function name, is an explicit transfer?, function)
      ("device_to_device_jax_device_put", True,
       lambda: jax.device_put(device_arrays[0], device=jax.local_devices()[1])),
      ("device_to_device_jax_jit", False,
       lambda: jax.jit(lambda x: x, device=jax.local_devices()[1])
       (device_arrays[1])),
  ]
Example #15
def create_device_mesh(mesh_shape: Sequence[int]) -> np.ndarray:
  """Creates a performant device mesh for jax.experimental.maps.mesh.

  Args:
    mesh_shape: shape of logical mesh, ordered by increasing network-intensity
      e.g. [replica, data, mdl] where mdl has the most network communication
      requirements.

  Returns:
    A np.ndarray of jax devices with mesh_shape as its shape that can be fed
    into jax.experimental.maps.mesh with good collective performance.
  """
  local_jax_devices_from_process_0 = jax.local_devices(process_index=0)
  jax_devices = jax.devices()
  device_kind = jax_devices[-1].device_kind
  # TODO(zhangqiaorjc): Handle TPU versions other than v4 more generally.
  if device_kind == _TPU_V3:
    device_mesh = np.asarray(jax_devices).reshape(mesh_shape)
    if mesh_shape[-1] == 8:
      logging.info('Re-order TPUv3 device mesh for better performance.')
      perm = np.array([0, 1, 2, 3, 6, 7, 4, 5])
      device_mesh = device_mesh[:, :, perm]
    return device_mesh
  elif device_kind == _TPU_V4:
    physical_mesh = _jax_devices_order_normalized(
        local_jax_devices_from_process_0, jax_devices)
    device_mesh, assignment = _create_device_mesh_for_tpu_v4(
        physical_mesh, mesh_shape)
    logging.info('_create_device_mesh_for_tpu_v4 assignment: %s', assignment)
    return device_mesh
  else:
    device_mesh = np.asarray(jax_devices).reshape(mesh_shape)
    return device_mesh
Example #16
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    logging.info('JAX process: %d / %d', jax.process_index(),
                 jax.process_count())
    logging.info('JAX local devices: %r', jax.local_devices())

    # Add a note so that we can tell which task is which JAX host.
    # (Depending on the platform task 0 is not guaranteed to be host 0)
    platform.work_unit().set_task_status(
        f'process_index: {jax.process_index()}, '
        f'process_count: {jax.process_count()}')
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.workdir, 'workdir')

    if FLAGS.mode == 'train':
        train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
    else:
        predict.predict_and_evaluate(FLAGS.config, FLAGS.workdir,
                                     FLAGS.ckpt_path)
Example #17
def make_hmc_update_eval_fns(net, train_set, test_set, likelihood_fn,
                             prior_fn):
    """Make update and eval functions for HMC training."""
    n_devices = len(jax.local_devices())

    def log_prob_and_grad_fn(params):
        params_p = jax.pmap(lambda _: params)(jnp.arange(n_devices))
        log_prob, _, grad = nn_loss.pmap_get_loss_acc_grad(
            net, params_p, likelihood_fn, prior_fn, train_set)
        return -log_prob[0], jax.tree_map(lambda g: -g[0], grad)

    def log_prob_and_acc(params, dataset):
        params_p = jax.pmap(lambda _: params)(jnp.arange(n_devices))
        log_prob, acc = nn_loss.pmap_get_loss_and_acc(net, params_p,
                                                      likelihood_fn, prior_fn,
                                                      dataset)
        return -log_prob[0], acc[0]

    hmc_update = hmc.make_adaptive_hmc_update(log_prob_and_grad_fn)

    def update(params, log_prob, state_grad, key, step_size, trajectory_len):
        params, log_prob, state_grad, step_size, accept_prob = hmc_update(
            params, log_prob, state_grad, key, step_size, trajectory_len)
        key, = jax.random.split(key, 1)
        return params, log_prob, state_grad, step_size, key, accept_prob

    def evaluate(params):
        test_log_prob, test_acc = log_prob_and_acc(params, test_set)
        train_log_prob, train_acc = log_prob_and_acc(params, train_set)
        return test_log_prob, test_acc, train_log_prob, train_acc

    return update, evaluate, log_prob_and_grad_fn
Example #18
def _multi_device_put(x, devices=None):
    """Memory efficient multi-device replication / broadcast in JAX.

  JAX uses a ShardedDeviceArray class that holds a list of device buffers
  on separate devices for use with pmap'd computations.  Sharded arrays
  are explicitly used to eliminate unnecessary inter-device transfer of
  memory buffers between use in pmap'd computations.  The JAX API currently
  does not have a multi-device 'put' function that copies a buffer onto
  N devices in a memory-efficient fashion, so we implement our own here.

  Args:
    x: jax DeviceArray or numpy ndarray to be replicated.
    devices: a jax.devices() list or subset thereof of devices to
      replicate onto.  Should match the list passed to any pmaps
      ingesting the replicated array.

  Returns:
    A ShardedDeviceArray with
    dtype = x.dtype and shape = (n_devices,) + x.shape
    that's backed by replicated device_buffers on each local device.
  """
    # Convert _FilledConstants that don't have device_buffer, etc.
    if type(x) != jax.xla.DeviceArray:  # pylint: disable=unidiomatic-typecheck
        x = np.array(x)
    # Calculate the abstract shape of the replicated array.
    if not devices:
        devices = jax.local_devices()
    n_devices = len(devices)
    x_aval = jax.xla.abstractify(x)
    broadcast_x_aval = jax.abstract_arrays.ShapedArray(
        (n_devices, ) + x_aval.shape, x_aval.dtype)
    # Create copies of the underlying device buffer for each local device.
    broadcast_buffers = [jax.device_put(x, dv).device_buffer for dv in devices]
    return jax.pxla.ShardedDeviceArray(broadcast_x_aval, broadcast_buffers)
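A brief usage sketch, not from the original source, of how the replicated array above is typically consumed by a pmap'd computation (this assumes a JAX version where the ShardedDeviceArray internals used above still exist):

import jax
import jax.numpy as jnp

x = jnp.arange(8.0)
replicated = _multi_device_put(x)                  # shape (n_local_devices, 8)
doubled = jax.pmap(lambda v: v * 2.0)(replicated)  # the replicated buffers feed pmap directly
assert doubled.shape == (jax.local_device_count(), 8)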
Example #19
    def testXMapMeshCollectives(self):
        local_devices = list(jax.local_devices())
        if len(local_devices) < 4:
            raise SkipTest("Test requires at least 4 local devices")

        def f(a, b):
            return lax.psum(a * 2, 'a'), b * 4

        devices = np.array(local_devices[:4]).reshape((2, 2))
        with mesh(devices, ('x', 'y')):
            fm = xmap(f,
                      in_axes=[A({
                          'a': 0,
                          'b': 1
                      }), A({'c': 0})],
                      out_axes=[A({'b': 0}), A({'c': 0})],
                      schedule=[
                          ('a', 'x'),
                          ('b', 'y'),
                          ('c', 'x'),
                          ('a', 'vectorize'),
                          ('b', 'vectorize'),
                      ])
            ashape = (16, 8, 5)
            a = jnp.arange(np.prod(ashape)).reshape(ashape)
            bshape = (2, 7)
            b = jnp.arange(np.prod(bshape)).reshape(bshape)
            c, d = fm(a, b)
            self.assertAllClose(c, (a * 2).sum(0))
            self.assertAllClose(d, b * 4)
Example #20
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    FLAGS.log_dir = FLAGS.workdir
    FLAGS.stderrthreshold = 'info'
    logging.get_absl_handler().start_logging_to_file()

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    logging.info('JAX host: %d / %d', jax.host_id(), jax.host_count())
    logging.info('JAX local devices: %r', jax.local_devices())

    # Add a note so that we can tell which task is which JAX host.
    # (Depending on the platform task 0 is not guaranteed to be host 0)
    platform.work_unit().set_task_status(
        f'host_id: {jax.host_id()}, host_count: {jax.host_count()}')
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.workdir, 'workdir')

    if FLAGS.sample:
        sample.save_images(sample.generate_sample(FLAGS.config, FLAGS.workdir),
                           'sample.png')
    else:
        train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
Example #21
  def testInfeedThenOutfeedInALoop(self):
    hcb.stop_outfeed_receiver()

    def doubler(_, token):
      y, token = lax.infeed(
          token, shape=jax.ShapedArray((3, 4), jnp.float32))
      return lax.outfeed(token, y * np.float32(2))

    @jax.jit
    def f(n):
      token = lax.create_token(n)
      token = lax.fori_loop(0, n, doubler, token)
      return n

    device = jax.local_devices()[0]
    n = 10
    execution = threading.Thread(target=lambda: f(n))
    execution.start()
    for _ in range(n):
      x = np.random.randn(3, 4).astype(np.float32)
      device.transfer_to_infeed((x,))
      y, = device.transfer_from_outfeed(xla_client.shape_from_pyval((x,))
                                        .with_major_to_minor_layout_if_absent())
      self.assertAllClose(y, x * np.float32(2))
    execution.join()
Example #22
    def test_complex_dtype(self):
        if jax.local_devices()[0].platform == "tpu":
            self.skipTest("Complex dtype not supported by TPU")
        # This just makes sure we can call the initializers in accordance to the
        # API and get the right shapes and dtypes out.
        inits = [
            initializers.Constant(42. + 1j * 1729.),
            initializers.RandomNormal(),
            initializers.RandomNormal(2.0),
            initializers.RandomNormal(2. - 3j),
            initializers.TruncatedNormal(),
            initializers.TruncatedNormal(2.),
            initializers.TruncatedNormal(2., 1. - 1j),

            # Users are supposed to be able to use these.
            jnp.zeros,
            jnp.ones,
        ]

        shape = (5, 13, 17)

        dtype = jnp.complex64
        for init in inits:
            generated = init(shape, dtype)
            self.assertEqual(generated.shape, shape)
            self.assertEqual(generated.dtype, dtype)
Example #23
  def testOutfeed(self):
    devices = np.array(jax.local_devices())
    nr_devices = len(devices)
    shape = (nr_devices * 3, nr_devices * 5)

    def f(x):
      token = lax.create_token(x)
      token = lax.outfeed(token, x, partitions=(None,))
      token = lax.outfeed(token, x, partitions=(P(nr_devices, 1),))
      token = lax.outfeed(token, x, partitions=(P(1, nr_devices),))
      return x

    x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)

    def dispatch():
      with mesh(devices, ['d']):
        logging.info('Making pjit call')
        pjit(f, in_axis_resources=(P('d'),), out_axis_resources=P('d'))(x)
    execution = threading.Thread(target=dispatch)
    execution.start()

    def check_outfeed(d, x):
      y, = d.transfer_from_outfeed(
          xla_client.shape_from_pyval((x,)).with_major_to_minor_layout_if_absent())
      self.assertAllClose(x, y, check_dtypes=True)

    logging.info('Transferring from outfeed for the pjit call')
    for didx, d in enumerate(devices):
      # Transfer the whole array from all devices for replicated.
      check_outfeed(d, x)
      # For sharded outfeed, the results are sliced.
      check_outfeed(d, x[3 * didx:3 * didx + 3, :])
      check_outfeed(d, x[:, 5 * didx:5 * didx + 5])

    execution.join()
Example #24
def _replicate(x, devices=None):
    x = jax.numpy.array(x)
    if devices is None:
        devices = jax.local_devices()
    aval = jax.ShapedArray((len(devices), ) + x.shape, x.dtype)
    buffers = [jax.interpreters.xla.device_put(x, device=d) for d in devices]
    return jax.pxla.ShardedDeviceArray(aval, buffers)
Example #25
def double_buffer(ds: Iterable[T]) -> Generator[T, None, None]:
    """Keeps at least two batches on the accelerator.

  The current GPU allocator design reuses previous allocations. For a training
  loop this means batches will (typically) occupy the same region of memory as
  the previous batch. An issue with this is that it means we cannot overlap a
  host->device copy for the next batch until the previous step has finished and
  the previous batch has been freed.

  By double buffering we ensure that there are always two batches on the device.
  This means that a given batch waits on the N-2'th step to finish and free,
  meaning that it can allocate and copy the next batch to the accelerator in
  parallel with the N-1'th step being executed.

  Args:
    ds: Iterable of batches of numpy arrays.

  Yields:
    Batches of sharded device arrays.
  """
    batch = None
    devices = jax.local_devices()
    for next_batch in ds:
        assert next_batch is not None
        next_batch = _device_put_sharded(next_batch, devices)
        if batch is not None:
            yield batch
        batch = next_batch
    if batch is not None:
        yield batch
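A minimal usage sketch, not from the original source, assuming the double_buffer helper above (its private _device_put_sharded helper is not shown); load_batches is a hypothetical host-side input pipeline:

import jax
import numpy as np

def load_batches():
    # Hypothetical pipeline yielding numpy batches already shaped
    # [num_local_devices, per_device_batch, ...].
    for _ in range(10):
        yield np.ones((jax.local_device_count(), 32, 8), np.float32)

for batch in double_buffer(load_batches()):
    # `batch` already lives on the accelerators; a pmap'd train step would run
    # here while the host->device copy of the next batch overlaps with compute.
    del batch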
Example #26
 def inference_speed_memory(self, batch_size, seq_length):
     # input_ids = np.random.randint(0, self.vocab_size, (batch_size, seq_length))
     key = jax.random.PRNGKey(0)
     input_ids = jax.random.randint(key, (batch_size, seq_length), 0, self.vocab_size)
     @jax.jit
     def ref_step():
         out = self.model(input_ids=input_ids)
         return out[0]
     if jax.local_devices()[0].platform == 'gpu':
         nvml.nvmlInit()
         ref_step().block_until_ready()
         handle = nvml.nvmlDeviceGetHandleByIndex(0)
         meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
         max_bytes_in_use = meminfo.used
         memory = Memory(max_bytes_in_use)
         # shutdown nvml
         nvml.nvmlShutdown()
     else:
         memory = None
     timeit.repeat("ref_step().block_until_ready()", repeat=1, number=2, globals=locals())
     if self.jit:
         runtimes = timeit.repeat("ref_step().block_until_ready()", repeat=self.repeat, number=3, globals=locals())
     else:
         with jax.disable_jit():
             runtimes = timeit.repeat("ref_step().block_until_ready()", repeat=self.repeat, number=3, globals=locals())
     return float(np.min(runtimes)/3.0), memory
Example #27
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    utils.add_gfile_logger(_WORKDIR.value)

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    jax.config.update('jax_log_compiles', True)

    logging.info('JAX process: %d / %d', jax.process_index(),
                 jax.process_count())
    logging.info('JAX local devices: %r', jax.local_devices())
    jax_xla_backend = ('None' if FLAGS.jax_xla_backend is None else
                       FLAGS.jax_xla_backend)
    logging.info('Using JAX XLA backend %s', jax_xla_backend)

    logging.info('Config: %s', FLAGS.config)

    # Add a note so that we can tell which task is which JAX host.
    # (Depending on the platform task 0 is not guaranteed to be host 0)
    platform.work_unit().set_task_status(
        f'process_index: {jax.process_index()}, '
        f'process_count: {jax.process_count()}')
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         _WORKDIR.value, 'workdir')

    if FLAGS.config.trainer == 'train':
        train.train_and_evaluate(FLAGS.config, _WORKDIR.value)
    elif FLAGS.config.trainer == 'inference_time':
        inference_time.inference_time(FLAGS.config, _WORKDIR.value)
    else:
        raise app.UsageError(f'Unknown trainer: {FLAGS.config.trainer}')
Example #28
def _multi_device_put(x, devices=None):
  """Memory efficient multi-device replication / broadcast in JAX.

  JAX uses a ShardedDeviceArray class that holds a list of device buffers
  on separate devices for use with pmap'd computations.  Sharded arrays
  are explicitly used to eliminate unnecessary inter-device transfer of
  memory buffers between use in pmap'd computations.  The JAX API currently
  does not have a multi-device 'put' function that copies a buffer onto
  N devices in a memory-efficient fashion, so we implement our own here.

  Args:
    x: jax DeviceArray or numpy ndarray to be replicated.
    devices: a jax.devices() list or subset thereof of devices to
      replicate onto.  Should match the list passed to any pmaps
      ingesting the replicated array.

  Returns:
    A ShardedDeviceArray with
    dtype = x.dtype and shape = (n_devices,) + x.shape
    that's backed by replicated device_buffers on each local device.
  """
  # Calculate the abstract shape of the replicated array.
  if not devices:
    devices = jax.local_devices()
  # The code below is equivalent to:
  #   jax.api.device_put_sharded(len(devices) * [x], devices)
  # but it does one PCI transfer and later uses ICI.
  # TODO(lukaszkaiser): remove once JAX has a core function to do the same.
  aval = jax.core.unmapped_aval(len(devices), 0,
                                jax.core.raise_to_shaped(jax.core.get_aval(x)))
  buf, = jax.xla.device_put(x, devices[0])  # assuming single-buf repr
  rest_bufs = [buf.copy_to_device(d) for d in devices[1:]]
  return jax.pxla.ShardedDeviceArray(aval, [buf, *rest_bufs])
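For reference, and not from the original source: the public-API call that produces the same replicated result is jax.device_put_replicated, used elsewhere in these examples; whether it matches the single-transfer behaviour described above depends on the JAX version:

import jax
import jax.numpy as jnp

x = jnp.arange(8.0)
replicated = jax.device_put_replicated(x, jax.local_devices())
assert replicated.shape == (jax.local_device_count(), 8)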
Example #29
    def testBasic(self):
        local_devices = list(jax.local_devices())
        if len(local_devices) < 4:
            raise SkipTest("Test requires at least 4 local devices")

        def f(a, b):
            return a * 2, b * 4

        devices = np.array(local_devices[:4]).reshape((2, 2))
        with mesh(devices, ('x', 'y')):
            fm = xmap(f,
                      in_axes=[{
                          0: 'a',
                          1: 'b'
                      }, ['c', ...]],
                      out_axes=[{
                          0: 'a',
                          1: 'b'
                      }, ['c', ...]],
                      axis_resources={
                          'a': 'x',
                          'b': 'y',
                          'c': 'x'
                      })
            ashape = (16, 8, 5)
            a = jnp.arange(np.prod(ashape)).reshape(ashape)
            bshape = (2, 7)
            b = jnp.arange(np.prod(bshape)).reshape(bshape)
            c, d = fm(a, b)
            self.assertAllClose(c, a * 2)
            self.assertAllClose(d, b * 4)
Example #30
def make_cpu_tensor(shape, dtype=float):
    import jax
    from jax import numpy as jnp
    tiny = jnp.zeros((), dtype=dtype)
    tiny_cpu = jax.device_put(tiny, jax.local_devices(backend='cpu')[0])
    big_cpu = jnp.tile(tiny_cpu, shape)
    return big_cpu
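A short usage sketch, not from the original source; the point of the helper is that the tiled array stays committed to the CPU device rather than the default GPU/TPU backend:

from jax import numpy as jnp

host_zeros = make_cpu_tensor((1024, 1024), dtype=jnp.float32)
assert host_zeros.shape == (1024, 1024)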