Example #1
  def testInfeed(self, partition_input):
    if jax.local_device_count() % 2 != 0:
      raise SkipTest

    shape = (jax.local_device_count() * 2, 4)
    # Run computation across all devices so we know which devices to feed.
    parts = P(jax.local_device_count(), 1)
    in_parts = parts if partition_input else None
    infeed_shapes = (jax.ShapedArray(shape, np.float32),
                     jax.ShapedArray((1,), np.float32))
    infeed_parts = (parts, None)

    @partial(sharded_jit, in_parts=in_parts, out_parts=None)
    def f(x):
      token = lax.create_token(x)
      (y, z), token = lax.infeed(token, infeed_shapes, partitions=infeed_parts)
      return x @ y.T + z[jnp.newaxis]

    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    y = x + 1
    shard_size = shape[0] // jax.local_device_count()
    y_shards = [y[i:i+shard_size] for i in range(0, shape[0], shard_size)]
    z = jnp.array([3.], dtype=np.float32)

    assert len(jax.local_devices()) == len(y_shards)
    for device, y_shard in zip(jax.local_devices(), y_shards):
      device.transfer_to_infeed((y_shard, z))
    # Transfer data to infeed before executing the function. For GPUs, the
    # execution of the compiled function is blocking, so transferring data
    # to infeed before executing ensures that the execution does not deadlock
    # waiting for the infeed data.
    result = f(x)

    expected = x @ y.T + z[jnp.newaxis]
    self.assertAllClose(result, expected, check_dtypes=False)
Example #2
 def f(x):
   token = lax.create_token(x)
   (y,), token = lax.infeed(
       token, shape=(jax.ShapedArray((3, 4), jnp.float32),))
   (z,), _ = lax.infeed(
       token, shape=(jax.ShapedArray((3, 1, 1), jnp.float32),))
   return x + y + z
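A hedged driver sketch for the fragment above (the concrete values, the use of the first local device, and jitting with jax.jit are assumptions; infeed support depends on the backend). Each lax.infeed call in f consumes one host transfer, and each transfer is a flat tuple matching the declared shapes, as in Example #1.

import numpy as np
import jax
import jax.numpy as jnp

x = jnp.ones((3, 4), jnp.float32)
y = np.arange(12, dtype=np.float32).reshape((3, 4))   # first infeed: (3, 4)
z = np.ones((3, 1, 1), np.float32)                     # second infeed: (3, 1, 1)

device = jax.local_devices()[0]
device.transfer_to_infeed((y,))   # consumed by the first lax.infeed
device.transfer_to_infeed((z,))   # consumed by the second lax.infeed
result = jax.jit(f)(x)            # computes x + y + z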
Example #3
    def testInfeed(self, partition_input):
        if jax.local_device_count() % 2 != 0:
            raise SkipTest

        shape = (jax.local_device_count() * 2, 4)
        # Run computation across all devices so we know which devices to feed.
        parts = P(jax.local_device_count(), 1)
        in_parts = parts if partition_input else None
        infeed_shapes = (jax.ShapedArray(shape, np.float32),
                         jax.ShapedArray((1, ), np.float32))
        infeed_parts = (parts, None)

        @partial(sharded_jit, in_parts=in_parts, out_parts=None)
        def f(x):
            token = lax.create_token(x)
            (y, z), token = lax.infeed(token,
                                       infeed_shapes,
                                       partitions=infeed_parts)
            return x @ y.T + z

        x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
        y = x + 1
        shard_size = shape[0] // jax.local_device_count()
        y_shards = [
            y[i:i + shard_size] for i in range(0, shape[0], shard_size)
        ]
        z = jnp.array([3.], dtype=np.float32)

        result = f(x)
        assert len(jax.local_devices()) == len(y_shards)
        for device, y_shard in zip(jax.local_devices(), y_shards):
            device.transfer_to_infeed((y_shard, z))

        expected = x @ y.T + z
        self.assertAllClose(result, expected, check_dtypes=False)
Example #4
 def host_loop_eval_step(model, state, metrics):
     token = lax.create_token(metrics['samples'])
     batch, token = lax.infeed(
         token,
         shape=(jax.ShapedArray(eval_input_shape, model_dtype),
                jax.ShapedArray((device_eval_batch_size, ), jnp.int32)))
     metrics = eval_step(model, state, batch, metrics, image_format,
                         space_to_depth)
     return metrics
Example #5
    def f_for_jit(x):
      token = lax.create_token(x)
      (y,), token = lax.infeed(
          token, shape=(jax.ShapedArray(x.shape, np.float32),))
      (z,), token = lax.infeed(
          token, shape=(jax.ShapedArray(x.shape, np.float32),))
      (w,), token = lax.infeed(
          token, shape=(jax.ShapedArray(x.shape, np.float32),))

      return x + y + z + w
Example #6
 def device_train_loop_body(args):
     optimizer, state, metrics, token, step, epoch = args
     (images, labels), token = lax.infeed(
         token,
         shape=(jax.ShapedArray(train_input_shape, model_dtype),
                jax.ShapedArray((device_batch_size, ), jnp.int32)))
     batch = {'image': images, 'label': labels}
     optimizer, state, metrics = train_step(optimizer, state, batch,
                                            metrics, learning_rate_fn)
     step += 1
     return optimizer, state, metrics, token, step, epoch
Example #7
 def host_loop_train_step(optimizer, state, metrics):
     token = lax.create_token(optimizer.state[0].step)
     batch, token = lax.infeed(token,
                               shape=(jax.ShapedArray(
                                   train_input_shape, model_dtype),
                                      jax.ShapedArray((device_batch_size, ),
                                                      jnp.int32)))
     optimizer, state, metrics = train_step(optimizer, state, batch,
                                            metrics, learning_rate_fn,
                                            image_format, space_to_depth)
     return optimizer, state, metrics
Example #8
 def device_train_loop_body(args):
     optimizer, state, metrics, token, step, loop = args
     batch, token = lax.infeed(token,
                               shape=(jax.ShapedArray(
                                   train_input_shape, model_dtype),
                                      jax.ShapedArray((device_batch_size, ),
                                                      jnp.int32)))
     optimizer, state, metrics = train_step(optimizer, state, batch,
                                            metrics, learning_rate_fn,
                                            image_format, space_to_depth)
     step += 1
     return optimizer, state, metrics, token, step, loop
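A hedged sketch of how a loop body like the one above is typically driven (num_steps, per_device_batches and the use of lax.fori_loop are assumptions, not part of the original code): the body runs in an on-device loop under pmap, while the host keeps each device's infeed queue supplied with one (images, labels) tuple per step.

def device_train_loop(optimizer, state, metrics, step, loop):
    token = lax.create_token(step)
    args = (optimizer, state, metrics, token, step, loop)
    args = lax.fori_loop(0, num_steps,
                         lambda i, a: device_train_loop_body(a), args)
    optimizer, state, metrics, _, step, _ = args
    return optimizer, state, metrics, step

p_train_loop = jax.pmap(device_train_loop, axis_name='batch')

# After launching p_train_loop(...) on per-device-replicated arguments
# (dispatch is asynchronous), the host feeds the queues:
for _ in range(num_steps):
    for device, (images, labels) in zip(jax.local_devices(),
                                        per_device_batches()):
        device.transfer_to_infeed((images, labels))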
Example #9
    def scan_fn(broadcast_in, init, *args):
        xs = jax.tree_multimap(transpose_to_front, in_axes, args)

        def body_fn(c, xs, init_mode=False):
            # inject constants
            xs = jax.tree_multimap(
                lambda ax, arg, x: (arg if ax is broadcast else x), in_axes,
                args, xs)
            broadcast_out, c, ys = fn(broadcast_in, c, *xs)

            if init_mode:
                ys = jax.tree_multimap(
                    lambda ax, y: (y if ax is broadcast else ()), out_axes, ys)
                return broadcast_out, ys
            else:
                ys = jax.tree_multimap(
                    lambda ax, y: (() if ax is broadcast else y), out_axes, ys)
                return c, ys

        broadcast_body = functools.partial(body_fn, init_mode=True)

        carry_pvals = jax.tree_map(
            lambda x: pe.PartialVal.unknown(
                jax.ShapedArray(jnp.shape(x), jnp.result_type(x))), init)
        scan_pvals = jax.tree_map(
            lambda x: pe.PartialVal.unknown(
                jax.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x))), xs)
        input_pvals = (carry_pvals, scan_pvals)
        in_pvals, in_tree = jax.tree_flatten(input_pvals)
        f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
            lu.wrap_init(broadcast_body), in_tree)
        _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)

        out_flat = []
        for pv, const in out_pvals:
            if pv is not None:
                raise ValueError(
                    'broadcasted variable has a data dependency on the scan body.'
                )
            out_flat.append(const)
        broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)

        c, ys = lax.scan(body_fn, init, xs, length=length, reverse=reverse)
        ys = jax.tree_multimap(transpose_from_front, out_axes, ys)
        ys = jax.tree_multimap(
            lambda ax, const, y: (const if ax is broadcast else y), out_axes,
            constants_out, ys)
        return broadcast_in, c, ys
Example #10
 def device_train_loop_body(args):
     """On-device loop body."""
     optimizer, dropout_rngs, metrics, token, step, epoch = args
     # Ordering input data from infeed requires threading a symbolic token
     # through the computation.
     input_data, token = lax.infeed(token,
                                    shape=tuple([
                                        jax.ShapedArray(s, jnp.int32)
                                        for s in device_train_input_shape
                                    ]))
     # Rebuild input dict from infeed data tuple.
     batch = {k: v for k, v in zip(train_keys, input_data)}
     # Run the train_step function and return the loop state.
     optimizer, metrics, dropout_rngs = train_lib.train_step(
         optimizer,
         batch,
         metrics,
         dropout_rngs,
         train_config,
         learning_rate_fn,
         num_microbatches=CFG.microbatches,
         label_smoothing=CFG.label_smoothing,
         z_loss=CFG.z_loss)
     step += 1
     return optimizer, dropout_rngs, metrics, token, step, epoch
Example #11
def _replicate(x, devices=None):
    x = jax.numpy.array(x)
    if devices is None:
        devices = jax.local_devices()
    aval = jax.ShapedArray((len(devices), ) + x.shape, x.dtype)
    buffers = [jax.interpreters.xla.device_put(x, device=d) for d in devices]
    return jax.pxla.ShardedDeviceArray(aval, buffers)
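A hedged usage sketch for _replicate above (the params value is an assumption): it places one copy of the array on every local device, stacked along a new leading device axis, which is the layout pmap expects for replicated arguments.

import numpy as np
import jax

params = np.zeros((4, 4), np.float32)
replicated = _replicate(params)
# One copy per local device, stacked along a new leading axis.
assert replicated.shape == (jax.local_device_count(),) + params.shape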
Example #12
def partial_eval_by_shape(fn, input_spec, *args, **kwargs):
  """Lazily evaluate a function by using the shapes of the inputs.

  This function is similar to `jax.eval_shape` with the key difference that
  function outputs that can be computed without a concrete value of the
  inputs are returned as is instead of only the shape. See for example
  `module.init_by_shape` where this functionality is used to initialize a
  model without using input data or computation.

  Args:
    fn: the function to be lazily evaluated.
    input_spec: an iterable of shapes or (shape, dtype) tuples specifying the
      shape and type of the inputs. If unspecified the dtype is float32.
    *args: other arguments passed to the module's apply function
    **kwargs: keyword arguments passed to the module's apply function
  Returns:
    A pair consisting of the model output and an instance of Model
  """
  # output cannot be returned in lazy_create because jax.eval_shape will only
  # return the shape and dtype.
  # TODO(mattjj,jheek): use a public JAX API
  f = lambda *inputs: fn(*inputs, *args, **kwargs)
  input_structs = [_parse_spec(spec) for spec in input_spec]
  inputs_flat, in_tree = jax.tree_flatten(input_structs)
  f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
  in_pvals = [pe.PartialVal.unknown(jax.ShapedArray(x.shape, x.dtype))
              for x in inputs_flat]

  if _is_omnistaging:
    _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
  else:
    _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals, stage_out=True)
  out_flat = [const if pv is None else jax.ShapeDtypeStruct(pv.shape, pv.dtype)
              for pv, const in out_pvals]
  return jax.tree_unflatten(out_tree(), out_flat)
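A hedged usage sketch (the dense_like function and the input spec are assumptions, not part of the original code): outputs that depend on the input values come back as jax.ShapeDtypeStructs, while outputs computable from shapes alone are returned concretely.

import jax.numpy as jnp

def dense_like(x):
  kernel = jnp.zeros((x.shape[-1], 8), jnp.float32)  # depends only on x's shape
  return x @ kernel, kernel.shape

out, kshape = partial_eval_by_shape(dense_like, [((2, 3), jnp.float32)])
# out is a jax.ShapeDtypeStruct with shape (2, 8); kshape is the concrete (3, 8).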
Example #13
 def f(x):
     token = lax.create_token(x)
     y, token = lax.infeed(token,
                           shape=jax.ShapedArray((3, 4), jnp.float32))
     token = lax.outfeed(token, y + np.float32(1))
     return x - 1 if config.omnistaging_enabled else lax.tie_in(
         token, x - 1)
Example #14
 def device_train_loop_body(args):
   """Device training loop body."""
   (optimizer, total_loss, lm_loss, sentence_loss, new_dropout_rng, token,
    step, epoch, num_steps_per_epoch) = args
   device_batch_size = FLAGS.train_batch_size // jax.device_count()
   input_shape = [device_batch_size, FLAGS.max_seq_length]
   input_shape_pred = [device_batch_size, FLAGS.max_predictions_per_seq]
   (input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids,
    masked_lm_weights, next_sentence_labels), token = lax.infeed(
        token,
        shape=(jax.ShapedArray(input_shape, jnp.int32),
               jax.ShapedArray(input_shape, jnp.int32),
               jax.ShapedArray(input_shape, jnp.int32),
               jax.ShapedArray(input_shape_pred, jnp.int32),
               jax.ShapedArray(input_shape_pred, jnp.int32),
               jax.ShapedArray(input_shape_pred, jnp.float32),
               jax.ShapedArray([device_batch_size, 1], jnp.int32)))
   inputs = [input_ids, input_mask, segment_ids, masked_lm_positions]
   labels = [masked_lm_ids, masked_lm_weights, next_sentence_labels]
   optimizer, total_loss, lm_loss, sentence_loss, new_dropout_rng = train_step(
       optimizer,
       inputs,
       labels,
       learning_rate_fn,
       dropout_rng=new_dropout_rng)
   step += 1
   return (optimizer, total_loss, lm_loss, sentence_loss,
           new_dropout_rng, token, step, epoch, num_steps_per_epoch)
Example #15
 def f_for_pjit(x):
     token = lax.create_token(x)
     # A replicated infeed
     (y, ), token = lax.infeed(token,
                               shape=(jax.ShapedArray(
                                   x.shape, np.float32), ),
                               partitions=(None, ))
     # An infeed sharded on first axis
     (z, ), token = lax.infeed(token,
                               shape=(jax.ShapedArray(
                                   x.shape, np.float32), ),
                               partitions=(P(nr_devices, 1), ))
     # An infeed sharded on second axis
     (w, ), token = lax.infeed(token,
                               shape=(jax.ShapedArray(
                                   x.shape, np.float32), ),
                               partitions=(P(1, nr_devices), ))
     return x + y + z + w
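A hedged host-side driver for the fragment above, following the per-device transfer pattern of Example #1 (the array sizes, shard sizes, and device order are assumptions): a replicated infeed receives the full array on every device, while a sharded infeed receives only each device's shard.

import numpy as np
import jax

devices = jax.local_devices()
nr_devices = len(devices)
shape = (nr_devices * 3, nr_devices * 5)
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
y, z, w = x + 1, x + 2, x + 3

for i, d in enumerate(devices):
    d.transfer_to_infeed((y,))                         # replicated: full array
    d.transfer_to_infeed((z[3 * i:3 * (i + 1), :],))   # shard of the first axis
    d.transfer_to_infeed((w[:, 5 * i:5 * (i + 1)],))   # shard of the second axis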
Example #16
  def testInfeedPytree(self):

    x = np.float32(1.5)
    y = np.reshape(np.arange(12, dtype=np.int16), (3, 4))
    to_infeed = dict(a=x, b=y)
    to_infeed_shape = dict(a=jax.ShapedArray((), dtype=np.float32),
                           b=jax.ShapedArray((3, 4), dtype=np.int16))
    @jax.jit
    def f(x):
      token = lax.create_token(x)
      res, token = lax.infeed(token, shape=to_infeed_shape)
      return res

    device = jax.local_devices()[0]
    # We must transfer the flattened data, as a tuple!!!
    flat_to_infeed, _ = jax.tree_flatten(to_infeed)
    device.transfer_to_infeed(tuple(flat_to_infeed))
    self.assertAllClose(f(x), to_infeed)
Example #17
def _replicate(x, devices=None):
  x = jax.numpy.array(x)
  if devices is None:
    # match the default device assignments used in pmap:
    # for single-host, that's the XLA default device assignment
    # for multi-host, it's the order of jax.local_devices()
    if jax.host_count() == 1:
      devices = [d for d in xb.get_backend().get_default_device_assignment(
          jax.device_count()) if d.host_id == jax.host_id()]
    else:
      devices = jax.local_devices()
  aval = jax.ShapedArray((len(devices),) + x.shape, x.dtype)
  buffers = [jax.interpreters.xla.device_put(x, device=d) for d in devices]
  return jax.pxla.ShardedDeviceArray(aval, buffers)
Example #18
 def device_train_loop_body(args):
     optimizer, dropout_rngs, metrics, token, step, epoch = args
     input_data, token = lax.infeed(token,
                                    shape=tuple([
                                        jax.ShapedArray(
                                            device_train_input_shape,
                                            jnp.int32) for _ in train_keys
                                    ]))
     batch = {k: v for k, v in zip(train_keys, input_data)}
     optimizer, metrics, dropout_rngs = train_step(optimizer,
                                                   batch,
                                                   metrics,
                                                   learning_rate_fn,
                                                   dropout_rng=dropout_rngs)
     step += 1
     return optimizer, dropout_rngs, metrics, token, step, epoch
Example #19
def abstract_single_value(value):
    if isinstance(value, jnp.ndarray):
        value = jax.ShapedArray(np.shape(value), np.result_type(value))
        return pe.PartialVal.unknown(value)
    else:
        return value
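A hedged usage sketch: arrays are abstracted to unknown PartialVals (keeping only shape and dtype) for partial evaluation, while other Python values such as static flags pass through unchanged.

import jax.numpy as jnp

args = [jnp.ones((2, 2)), True, 3]
pvals = [abstract_single_value(a) for a in args]
# pvals[0] is pe.PartialVal.unknown(ShapedArray((2, 2), float32));
# pvals[1] and pvals[2] are the original Python values.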
Example #20
    def _compute_out_shapes(self, ins, outs):
        """
        Compute the shapes of outputs based on those of the inputs.

        Parameters
        ----------
        ins : dict
            Dict of input metadata containing input shapes.
        outs : dict
            Dict of output metadata that will be updated with shape information.
        """
        need_shape = []
        for name, ometa in outs.items():
            try:
                ometa['shape']
            except KeyError:
                need_shape.append(name)

        args = []
        static_argnums = []
        for i, (name, meta) in enumerate(ins.items()):
            if 'is_option' in meta and meta['is_option']:
                if 'default' in meta:
                    val = meta['default']
                elif 'values' in meta:
                    val = meta['values'][0]
                else:
                    val = None
                args.append(val)
                static_argnums.append(i)
                continue
            if meta['val'] is not None:
                args.append(meta['val'])
            else:
                try:
                    shp = meta['shape']
                except KeyError:
                    if 'resid' not in meta:  # this is an input, not a state
                        raise RuntimeError(
                            f"Can't determine shape of input '{name}'.")
                else:
                    if jax is not None:
                        shp = None if shp is None else _shape2tuple(shp)
                        args.append(jax.ShapedArray(shp, dtype=np.float64))

        # compute shapes as a check against shapes in metadata (if any)
        if jax is not None:
            try:
                # must replace numpy with jax numpy when making jaxpr.
                with jax_context(self._f.__globals__):
                    v = jax.make_jaxpr(self._f, static_argnums)(*args)
            except Exception as err:
                if need_shape:
                    raise RuntimeError(
                        f"Failed to determine the output shapes "
                        f"based on the input shapes. The error was: {err}.  To "
                        "avoid this error, add return value metadata that "
                        "specifies the shapes of the return values to the function."
                    )
                warnings.warn(
                    "Failed to determine the output shapes based on the input "
                    "shapes in order to check the provided metadata values. The"
                    f" error was: {err}.")
            else:
                for val, name in zip(v.out_avals, outs):
                    oldshape = outs[name].get('shape')
                    if oldshape is not None and _shape2tuple(
                            oldshape) != val.shape:
                        raise RuntimeError(
                            f"shape from metadata for return value "
                            f"'{name}' of {oldshape} doesn't match computed "
                            f"shape of {val.shape}.")
                    outs[name]['shape'] = val.shape
                need_shape = []

        if need_shape:  # output shapes weren't provided by user or by jax
            shape = self._output_defaults['shape']
            warnings.warn(
                f"Return values {need_shape} have unspecified shape so are assumed to "
                f"have shape {shape}.")
            for name in need_shape:
                outs[name]['shape'] = shape
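A related, simpler technique as a hedged sketch (not part of the code above): when no static options are involved, jax.eval_shape computes output shapes directly from jax.ShapeDtypeStruct inputs without executing the function.

import jax
import jax.numpy as jnp

out = jax.eval_shape(lambda a, b: a @ b,
                     jax.ShapeDtypeStruct((2, 3), jnp.float32),
                     jax.ShapeDtypeStruct((3, 4), jnp.float32))
print(out.shape)  # (2, 4)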
Example #21
    def inner(scope_fn, repack_fn, variable_groups_xs, rng_groups_xs):
        # split rngs
        split_fn = lambda rng: random.split(rng, length)
        broadcast_rngs_xs = []
        scan_rngs_xs = []
        for rng_groups in rng_groups_xs:
            broadcast_rngs_xs.append(
                tuple(rng_group
                      for rng_group, split in zip(rng_groups, rng_splits)
                      if not split))
            scan_rngs_xs.append(
                tuple(
                    jax.tree_map(split_fn, rng_group)
                    for rng_group, split in zip(rng_groups, rng_splits)
                    if split))

        def body(carry, xs, init_mode=False):
            carry_vars_xs, c = carry
            scan_vars_xs, scan_rngs_xs, x = xs
            variable_groups_xs = combine(scan_vars_xs, carry_vars_xs,
                                         broadcast_vars_xs)
            rng_groups_xs = []
            for broadcast_rngs, scan_rngs in zip(broadcast_rngs_xs,
                                                 scan_rngs_xs):
                rng_groups_xs.append(broadcast_rngs + scan_rngs)
            scope = scope_fn(variable_groups_xs, rng_groups_xs)
            carry, y = fn(scope, c, x)
            out_vars = repack_fn(scope)
            scan_vars_xs, carry_vars_out_xs, broadcast_vars_out_xs = split(
                out_vars, 1)

            # TODO(jheek) more informative error check
            def check_shapes(c_in, c_out):
                if not isinstance(c_in, jnp.ndarray) or not isinstance(
                        c_out, jnp.ndarray):
                    return
                if jnp.shape(c_in) != jnp.shape(c_out) or jnp.dtype(
                        c_in) != jnp.dtype(c_out):
                    raise ValueError()

            try:
                jax.tree_multimap(check_shapes, carry_vars_xs,
                                  carry_vars_out_xs)
            except ValueError:
                raise ValueError(
                    'carry variables must have the same shape and dtype before and after scan.'
                )

            if init_mode:
                return broadcast_vars_out_xs
            else:
                return (carry_vars_out_xs, carry), (scan_vars_xs, y)

        broadcast_body = functools.partial(body, init_mode=True)

        scan_vars_xs, carry_vars_xs, broadcast_vars_xs = split(
            variable_groups_xs, 0)
        carry0 = (carry_vars_xs, init_carry)
        xxs = (scan_vars_xs, scan_rngs_xs, xs)

        # use partial evaluation to find the variables that are broadcasted out
        # an error is thrown if a broadcasted output has a dependency on any scan variables
        carry_pvals = jax.tree_map(
            lambda x: pe.PartialVal.unknown(jax.ShapedArray(x.shape, x.dtype)),
            carry0)
        scan_pvals = jax.tree_map(
            lambda x: pe.PartialVal.unknown(
                jax.ShapedArray(x.shape[1:], x.dtype)), xxs)
        input_pvals = (carry_pvals, scan_pvals)
        in_pvals, in_tree = jax.tree_flatten(input_pvals)
        f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
            lu.wrap_init(broadcast_body), in_tree)

        _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
        # _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals, stage_out=True)

        out_flat = []
        for pv, const in out_pvals:
            if pv is not None:
                raise ValueError(
                    'broadcasted variable has a data dependency on the scan body.'
                )
            out_flat.append(const)

        (carry_vars_xs, carry), (scan_vars_xs, ys) = lax.scan(body,
                                                              carry0,
                                                              xxs,
                                                              length=length,
                                                              reverse=reverse)

        broadcast_vars_xs = jax.tree_unflatten(out_tree(), out_flat)

        out_vars_xs = combine(carry_vars_xs, scan_vars_xs, broadcast_vars_xs)
        return (carry, ys), out_vars_xs
Example #22
 def _replicate(x):
     """Replicate an object on each device."""
     x = jnp.array(x)
     aval = jax.ShapedArray((len(devices), ) + x.shape, x.dtype)
     buffers = [jax.interpreters.xla.device_put(x, d) for d in devices]
     return jax.pxla.ShardedDeviceArray(aval, buffers)
Example #23
 def f(x):
     token = lax.create_token(x)
     y, token = lax.infeed(token,
                           shape=jax.ShapedArray((3, 4), np.float32))
     token = lax.outfeed(token, y + onp.float32(1))
     return lax.tie_in(token, x - 1)
Example #24
 def doubler(_, token):
   y, token = lax.infeed(
       token, shape=jax.ShapedArray((3, 4), jnp.float32))
   return lax.outfeed(token, y * np.float32(2))
Example #25
 def f(x):
   token = lax.create_token(x)
   y, token = lax.infeed(
       token, shape=jax.ShapedArray((3, 4), jnp.float32))
   token = lax.outfeed(token, y + np.float32(1))
   return x - 1
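A hedged driver sketch for the fragment above, based on the pattern used in older JAX infeed/outfeed tests (the concrete values are assumptions, and the exact outfeed-read call differs across jaxlib versions): the jitted function runs in a background thread so the host can service both queues without deadlocking.

import threading
import numpy as np
import jax
from jax.lib import xla_client

x = np.float32(7.5)
y = np.arange(12, dtype=np.float32).reshape((3, 4))
device = jax.local_devices()[0]

# Launch the computation, then feed the infeed and read the outfeed.
execution = threading.Thread(target=lambda: jax.jit(f)(x))
execution.start()
device.transfer_to_infeed((y,))
out, = device.transfer_from_outfeed(
    xla_client.shape_from_pyval((y,)).with_major_to_minor_layout_if_absent())
execution.join()
# out should equal y + 1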