def testInfeed(self, partition_input):
    if jax.local_device_count() % 2 != 0:
        raise SkipTest

    shape = (jax.local_device_count() * 2, 4)
    # Run computation across all devices so we know which devices to feed.
    parts = P(jax.local_device_count(), 1)
    in_parts = parts if partition_input else None
    infeed_shapes = (jax.ShapedArray(shape, np.float32),
                     jax.ShapedArray((1,), np.float32))
    infeed_parts = (parts, None)

    @partial(sharded_jit, in_parts=in_parts, out_parts=None)
    def f(x):
        token = lax.create_token(x)
        (y, z), token = lax.infeed(token, infeed_shapes, partitions=infeed_parts)
        return x @ y.T + z[jnp.newaxis]

    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    y = x + 1
    shard_size = shape[0] // jax.local_device_count()
    y_shards = [y[i:i + shard_size] for i in range(0, shape[0], shard_size)]
    z = jnp.array([3.], dtype=np.float32)

    assert len(jax.local_devices()) == len(y_shards)
    for device, y_shard in zip(jax.local_devices(), y_shards):
        device.transfer_to_infeed((y_shard, z))
    # Transfer data to infeed before executing the function. For GPUs, the
    # execution of the compiled function is blocking, so transferring data
    # to infeed before executing ensures that the execution does not deadlock
    # waiting for the infeed data.
    result = f(x)

    expected = x @ y.T + z[jnp.newaxis]
    self.assertAllClose(result, expected, check_dtypes=False)
def f(x):
    token = lax.create_token(x)
    (y,), token = lax.infeed(
        token, shape=(jax.ShapedArray((3, 4), jnp.float32),))
    (z,), _ = lax.infeed(
        token, shape=(jax.ShapedArray((3, 1, 1), jnp.float32),))
    return x + y + z
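For reference, a minimal host-side driver for the two-infeed function above might look like the sketch below. It reuses f from the preceding snippet; the concrete array values and the use of jax.jit are assumptions for illustration. Each transfer_to_infeed call passes a tuple whose entries match, in program order, the jax.ShapedArray shapes declared by the corresponding lax.infeed.

# Hypothetical host-side driver for the two-infeed function above
# (a sketch, not part of the original snippet).
import numpy as np
import jax

x = np.float32(1.0)
y = np.arange(12, dtype=np.float32).reshape(3, 4)   # matches the (3, 4) infeed
z = np.random.randn(3, 1, 1).astype(np.float32)     # matches the (3, 1, 1) infeed

device = jax.local_devices()[0]
device.transfer_to_infeed((y,))   # consumed by the first lax.infeed
device.transfer_to_infeed((z,))   # consumed by the second lax.infeed

result = jax.jit(f)(x)            # blocks until both operands are dequeued
np.testing.assert_allclose(result, x + y + z)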
def testInfeed(self, partition_input):
    if jax.local_device_count() % 2 != 0:
        raise SkipTest

    shape = (jax.local_device_count() * 2, 4)
    # Run computation across all devices so we know which devices to feed.
    parts = P(jax.local_device_count(), 1)
    in_parts = parts if partition_input else None
    infeed_shapes = (jax.ShapedArray(shape, np.float32),
                     jax.ShapedArray((1,), np.float32))
    infeed_parts = (parts, None)

    @partial(sharded_jit, in_parts=in_parts, out_parts=None)
    def f(x):
        token = lax.create_token(x)
        (y, z), token = lax.infeed(token, infeed_shapes, partitions=infeed_parts)
        return x @ y.T + z

    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    y = x + 1
    shard_size = shape[0] // jax.local_device_count()
    y_shards = [y[i:i + shard_size] for i in range(0, shape[0], shard_size)]
    z = jnp.array([3.], dtype=np.float32)

    result = f(x)
    assert len(jax.local_devices()) == len(y_shards)
    for device, y_shard in zip(jax.local_devices(), y_shards):
        device.transfer_to_infeed((y_shard, z))

    expected = x @ y.T + z
    self.assertAllClose(result, expected, check_dtypes=False)
def host_loop_eval_step(model, state, metrics):
    token = lax.create_token(metrics['samples'])
    batch, token = lax.infeed(
        token,
        shape=(jax.ShapedArray(eval_input_shape, model_dtype),
               jax.ShapedArray((device_eval_batch_size,), jnp.int32)))
    metrics = eval_step(model, state, batch, metrics, image_format,
                        space_to_depth)
    return metrics
def f_for_jit(x):
    token = lax.create_token(x)
    (y,), token = lax.infeed(
        token, shape=(jax.ShapedArray(x.shape, np.float32),))
    (z,), token = lax.infeed(
        token, shape=(jax.ShapedArray(x.shape, np.float32),))
    (w,), token = lax.infeed(
        token, shape=(jax.ShapedArray(x.shape, np.float32),))
    return x + y + z + w
def device_train_loop_body(args):
    optimizer, state, metrics, token, step, epoch = args
    (images, labels), token = lax.infeed(
        token,
        shape=(jax.ShapedArray(train_input_shape, model_dtype),
               jax.ShapedArray((device_batch_size,), jnp.int32)))
    batch = {'image': images, 'label': labels}
    optimizer, state, metrics = train_step(optimizer, state, batch, metrics,
                                           learning_rate_fn)
    step += 1
    return optimizer, state, metrics, token, step, epoch
def host_loop_train_step(optimizer, state, metrics):
    token = lax.create_token(optimizer.state[0].step)
    batch, token = lax.infeed(
        token,
        shape=(jax.ShapedArray(train_input_shape, model_dtype),
               jax.ShapedArray((device_batch_size,), jnp.int32)))
    optimizer, state, metrics = train_step(optimizer, state, batch, metrics,
                                           learning_rate_fn, image_format,
                                           space_to_depth)
    return optimizer, state, metrics
def device_train_loop_body(args):
    optimizer, state, metrics, token, step, loop = args
    batch, token = lax.infeed(
        token,
        shape=(jax.ShapedArray(train_input_shape, model_dtype),
               jax.ShapedArray((device_batch_size,), jnp.int32)))
    optimizer, state, metrics = train_step(optimizer, state, batch, metrics,
                                           learning_rate_fn, image_format,
                                           space_to_depth)
    step += 1
    return optimizer, state, metrics, token, step, loop
def scan_fn(broadcast_in, init, *args):
    xs = jax.tree_multimap(transpose_to_front, in_axes, args)

    def body_fn(c, xs, init_mode=False):
        # inject constants
        xs = jax.tree_multimap(
            lambda ax, arg, x: (arg if ax is broadcast else x),
            in_axes, args, xs)
        broadcast_out, c, ys = fn(broadcast_in, c, *xs)

        if init_mode:
            ys = jax.tree_multimap(
                lambda ax, y: (y if ax is broadcast else ()), out_axes, ys)
            return broadcast_out, ys
        else:
            ys = jax.tree_multimap(
                lambda ax, y: (() if ax is broadcast else y), out_axes, ys)
            return c, ys

    broadcast_body = functools.partial(body_fn, init_mode=True)

    carry_pvals = jax.tree_map(
        lambda x: pe.PartialVal.unknown(
            jax.ShapedArray(jnp.shape(x), jnp.result_type(x))),
        init)
    scan_pvals = jax.tree_map(
        lambda x: pe.PartialVal.unknown(
            jax.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x))),
        xs)
    input_pvals = (carry_pvals, scan_pvals)
    in_pvals, in_tree = jax.tree_flatten(input_pvals)
    f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
        lu.wrap_init(broadcast_body), in_tree)
    _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)

    out_flat = []
    for pv, const in out_pvals:
        if pv is not None:
            raise ValueError(
                'broadcasted variable has a data dependency on the scan body.')
        out_flat.append(const)
    broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)

    c, ys = lax.scan(body_fn, init, xs, length=length, reverse=reverse)
    ys = jax.tree_multimap(transpose_from_front, out_axes, ys)
    ys = jax.tree_multimap(
        lambda ax, const, y: (const if ax is broadcast else y),
        out_axes, constants_out, ys)
    return broadcast_in, c, ys
def device_train_loop_body(args): """On-device loop body.""" optimizer, dropout_rngs, metrics, token, step, epoch = args # Ordering input data from infeed requires threading a symbolic token # through the computation. input_data, token = lax.infeed(token, shape=tuple([ jax.ShapedArray(s, jnp.int32) for s in device_train_input_shape ])) # Rebuild input dict from infeed data tuple. batch = {k: v for k, v in zip(train_keys, input_data)} # Run the train_step function and return the loop state. optimizer, metrics, dropout_rngs = train_lib.train_step( optimizer, batch, metrics, dropout_rngs, train_config, learning_rate_fn, num_microbatches=CFG.microbatches, label_smoothing=CFG.label_smoothing, z_loss=CFG.z_loss) step += 1 return optimizer, dropout_rngs, metrics, token, step, epoch
def _replicate(x, devices=None):
    x = jax.numpy.array(x)
    if devices is None:
        devices = jax.local_devices()
    aval = jax.ShapedArray((len(devices),) + x.shape, x.dtype)
    buffers = [jax.interpreters.xla.device_put(x, device=d) for d in devices]
    return jax.pxla.ShardedDeviceArray(aval, buffers)
def partial_eval_by_shape(fn, input_spec, *args, **kwargs):
    """Lazily evaluate a function by using the shapes of the inputs.

    This function is similar to `jax.eval_shape` with the key difference that
    function outputs that can be computed without a concrete value of the
    inputs are returned as is instead of only the shape. See for example
    `module.init_by_shape` where this functionality is used to initialize a
    model without using input data or computation.

    Args:
      fn: the function to be lazily evaluated.
      input_spec: an iterable of shapes or (shape, dtype) tuples specifying the
        shape and type of the inputs. If unspecified the dtype is float32.
      *args: other arguments passed to the module's apply function.
      **kwargs: keyword arguments passed to the module's apply function.

    Returns:
      A pair consisting of the model output and an instance of Model.
    """
    # output cannot be returned in lazy_create because jax.eval_shape will only
    # return the shape and dtype.
    # TODO(mattjj,jheek): use a public JAX API
    f = lambda *inputs: fn(*inputs, *args, **kwargs)
    input_structs = [_parse_spec(spec) for spec in input_spec]
    inputs_flat, in_tree = jax.tree_flatten(input_structs)
    f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
    in_pvals = [pe.PartialVal.unknown(jax.ShapedArray(x.shape, x.dtype))
                for x in inputs_flat]
    if _is_omnistaging:
        _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
    else:
        _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals, stage_out=True)
    out_flat = [const if pv is None else jax.ShapeDtypeStruct(pv.shape, pv.dtype)
                for pv, const in out_pvals]
    return jax.tree_unflatten(out_tree(), out_flat)
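A hypothetical use of partial_eval_by_shape, to make the concrete-versus-abstract output behaviour explicit; the fake_model function and its input spec are made up for illustration.

# Outputs that do not depend on concrete input values are returned as-is,
# while data-dependent outputs come back as jax.ShapeDtypeStruct.
import jax.numpy as jnp

def fake_model(x):
    return {'num_features': x.shape[-1],               # shape-only, stays concrete
            'output': jnp.dot(x, jnp.ones((4, 2)))}    # data-dependent

out = partial_eval_by_shape(fake_model, [((8, 4), jnp.float32)])
# out['num_features'] == 4
# out['output'] is a jax.ShapeDtypeStruct with shape (8, 2) and dtype float32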
def f(x):
    token = lax.create_token(x)
    y, token = lax.infeed(token, shape=jax.ShapedArray((3, 4), jnp.float32))
    token = lax.outfeed(token, y + np.float32(1))
    return x - 1 if config.omnistaging_enabled else lax.tie_in(token, x - 1)
def device_train_loop_body(args): """Device training loop body.""" (optimizer, total_loss, lm_loss, sentence_loss, new_dropout_rng, token, step, epoch, num_steps_per_epoch) = args device_batch_size = FLAGS.train_batch_size // jax.device_count() input_shape = [device_batch_size, FLAGS.max_seq_length] input_shape_pred = [device_batch_size, FLAGS.max_predictions_per_seq] (input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels), token = lax.infeed( token, shape=(jax.ShapedArray(input_shape, jnp.int32), jax.ShapedArray(input_shape, jnp.int32), jax.ShapedArray(input_shape, jnp.int32), jax.ShapedArray(input_shape_pred, jnp.int32), jax.ShapedArray(input_shape_pred, jnp.int32), jax.ShapedArray(input_shape_pred, jnp.float32), jax.ShapedArray([device_batch_size, 1], jnp.int32))) inputs = [input_ids, input_mask, segment_ids, masked_lm_positions] labels = [masked_lm_ids, masked_lm_weights, next_sentence_labels] optimizer, total_loss, lm_loss, sentence_loss, new_dropout_rng = train_step( optimizer, inputs, labels, learning_rate_fn, dropout_rng=new_dropout_rng) step += 1 return (optimizer, total_loss, lm_loss, sentence_loss, new_dropout_rng, token, step, epoch, num_steps_per_epoch)
def f_for_pjit(x):
    token = lax.create_token(x)
    # A replicated infeed
    (y,), token = lax.infeed(
        token,
        shape=(jax.ShapedArray(x.shape, np.float32),),
        partitions=(None,))
    # An infeed sharded on first axis
    (z,), token = lax.infeed(
        token,
        shape=(jax.ShapedArray(x.shape, np.float32),),
        partitions=(P(nr_devices, 1),))
    # An infeed sharded on second axis
    (w,), token = lax.infeed(
        token,
        shape=(jax.ShapedArray(x.shape, np.float32),),
        partitions=(P(1, nr_devices),))
    return x + y + z + w
def testInfeedPytree(self):
    x = np.float32(1.5)
    y = np.reshape(np.arange(12, dtype=np.int16), (3, 4))
    to_infeed = dict(a=x, b=y)
    to_infeed_shape = dict(a=jax.ShapedArray((), dtype=np.float32),
                           b=jax.ShapedArray((3, 4), dtype=np.int16))

    @jax.jit
    def f(x):
        token = lax.create_token(x)
        res, token = lax.infeed(token, shape=to_infeed_shape)
        return res

    device = jax.local_devices()[0]
    # We must transfer the flattened data, as a tuple!!!
    flat_to_infeed, _ = jax.tree_flatten(to_infeed)
    device.transfer_to_infeed(tuple(flat_to_infeed))
    self.assertAllClose(f(x), to_infeed)
def _replicate(x, devices=None):
    x = jax.numpy.array(x)
    if devices is None:
        # match the default device assignments used in pmap:
        # for single-host, that's the XLA default device assignment
        # for multi-host, it's the order of jax.local_devices()
        if jax.host_count() == 1:
            devices = [
                d for d in xb.get_backend().get_default_device_assignment(
                    jax.device_count()) if d.host_id == jax.host_id()
            ]
        else:
            devices = jax.local_devices()
    aval = jax.ShapedArray((len(devices),) + x.shape, x.dtype)
    buffers = [jax.interpreters.xla.device_put(x, device=d) for d in devices]
    return jax.pxla.ShardedDeviceArray(aval, buffers)
def device_train_loop_body(args):
    optimizer, dropout_rngs, metrics, token, step, epoch = args
    input_data, token = lax.infeed(
        token,
        shape=tuple([
            jax.ShapedArray(device_train_input_shape, jnp.int32)
            for _ in train_keys
        ]))
    batch = {k: v for k, v in zip(train_keys, input_data)}
    optimizer, metrics, dropout_rngs = train_step(
        optimizer, batch, metrics, learning_rate_fn, dropout_rng=dropout_rngs)
    step += 1
    return optimizer, dropout_rngs, metrics, token, step, epoch
def abstract_single_value(value):
    if isinstance(value, jnp.ndarray):
        value = jax.ShapedArray(np.shape(value), np.result_type(value))
        return pe.PartialVal.unknown(value)
    else:
        return value
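A small illustration of how abstract_single_value behaves (the values below are made up): array leaves are abstracted to unknown partial values carrying only shape and dtype, while non-array leaves pass through untouched.

import jax.numpy as jnp

unknown = abstract_single_value(jnp.ones((2, 3)))
# -> pe.PartialVal.unknown(jax.ShapedArray((2, 3), float32))
passthrough = abstract_single_value(3.0)
# -> 3.0, returned unchanged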
def _compute_out_shapes(self, ins, outs):
    """
    Compute the shapes of outputs based on those of the inputs.

    Parameters
    ----------
    ins : dict
        Dict of input metadata containing input shapes.
    outs : dict
        Dict of output metadata that will be updated with shape information.
    """
    need_shape = []
    for name, ometa in outs.items():
        try:
            ometa['shape']
        except KeyError:
            need_shape.append(name)

    args = []
    static_argnums = []
    for i, (name, meta) in enumerate(ins.items()):
        if 'is_option' in meta and meta['is_option']:
            if 'default' in meta:
                val = meta['default']
            elif 'values' in meta:
                val = meta['values'][0]
            else:
                val = None
            args.append(val)
            static_argnums.append(i)
            continue

        if meta['val'] is not None:
            args.append(meta['val'])
        else:
            try:
                shp = meta['shape']
            except KeyError:
                if 'resid' not in meta:  # this is an input, not a state
                    raise RuntimeError(
                        f"Can't determine shape of input '{name}'.")
            else:
                if jax is not None:
                    shp = None if shp is None else _shape2tuple(shp)
                    args.append(jax.ShapedArray(shp, dtype=np.float64))

    # compute shapes as a check against shapes in metadata (if any)
    if jax is not None:
        try:
            # must replace numpy with jax numpy when making jaxpr.
            with jax_context(self._f.__globals__):
                v = jax.make_jaxpr(self._f, static_argnums)(*args)
        except Exception as err:
            if need_shape:
                raise RuntimeError(
                    f"Failed to determine the output shapes "
                    f"based on the input shapes. The error was: {err}. To "
                    "avoid this error, add return value metadata that "
                    "specifies the shapes of the return values to the function.")
            warnings.warn(
                "Failed to determine the output shapes based on the input "
                "shapes in order to check the provided metadata values. The"
                f" error was: {err}.")
        else:
            for val, name in zip(v.out_avals, outs):
                oldshape = outs[name].get('shape')
                if oldshape is not None and _shape2tuple(oldshape) != val.shape:
                    raise RuntimeError(
                        f"shape from metadata for return value "
                        f"'{name}' of {oldshape} doesn't match computed "
                        f"shape of {val.shape}.")
                outs[name]['shape'] = val.shape
            need_shape = []

    if need_shape:  # output shapes weren't provided by user or by jax
        shape = self._output_defaults['shape']
        warnings.warn(
            f"Return values {need_shape} have unspecified shape so are assumed to "
            f"have shape {shape}.")
        for name in need_shape:
            outs[name]['shape'] = shape
def inner(scope_fn, repack_fn, variable_groups_xs, rng_groups_xs):
    # split rngs
    split_fn = lambda rng: random.split(rng, length)
    broadcast_rngs_xs = []
    scan_rngs_xs = []
    for rng_groups in rng_groups_xs:
        broadcast_rngs_xs.append(
            tuple(rng_group for rng_group, split in zip(rng_groups, rng_splits)
                  if not split))
        scan_rngs_xs.append(
            tuple(
                jax.tree_map(split_fn, rng_group)
                for rng_group, split in zip(rng_groups, rng_splits)
                if split))

    def body(carry, xs, init_mode=False):
        carry_vars_xs, c = carry
        scan_vars_xs, scan_rngs_xs, x = xs
        variable_groups_xs = combine(scan_vars_xs, carry_vars_xs,
                                     broadcast_vars_xs)
        rng_groups_xs = []
        for broadcast_rngs, scan_rngs in zip(broadcast_rngs_xs, scan_rngs_xs):
            rng_groups_xs.append(broadcast_rngs + scan_rngs)
        scope = scope_fn(variable_groups_xs, rng_groups_xs)
        carry, y = fn(scope, c, x)
        out_vars = repack_fn(scope)
        scan_vars_xs, carry_vars_out_xs, broadcast_vars_out_xs = split(
            out_vars, 1)

        # TODO(jheek) more informative error check
        def check_shapes(c_in, c_out):
            if not isinstance(c_in, jnp.ndarray) or not isinstance(
                    c_out, jnp.ndarray):
                return
            if jnp.shape(c_in) != jnp.shape(c_out) or jnp.dtype(
                    c_in) != jnp.dtype(c_out):
                raise ValueError()

        try:
            jax.tree_multimap(check_shapes, carry_vars_xs, carry_vars_out_xs)
        except ValueError:
            raise ValueError(
                'carry variables must have the same shape and dtype before and after scan.')

        if init_mode:
            return broadcast_vars_out_xs
        else:
            return (carry_vars_out_xs, carry), (scan_vars_xs, y)

    broadcast_body = functools.partial(body, init_mode=True)

    scan_vars_xs, carry_vars_xs, broadcast_vars_xs = split(variable_groups_xs, 0)
    carry0 = (carry_vars_xs, init_carry)
    xxs = (scan_vars_xs, scan_rngs_xs, xs)

    # use partial evaluation to find the variables that are broadcasted out
    # an error is thrown if a broadcasted output has a dependency on any scan variables
    carry_pvals = jax.tree_map(
        lambda x: pe.PartialVal.unknown(jax.ShapedArray(x.shape, x.dtype)),
        carry0)
    scan_pvals = jax.tree_map(
        lambda x: pe.PartialVal.unknown(jax.ShapedArray(x.shape[1:], x.dtype)),
        xxs)
    input_pvals = (carry_pvals, scan_pvals)
    in_pvals, in_tree = jax.tree_flatten(input_pvals)
    f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
        lu.wrap_init(broadcast_body), in_tree)
    _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
    # _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals, stage_out=True)

    out_flat = []
    for pv, const in out_pvals:
        if pv is not None:
            raise ValueError(
                'broadcasted variable has a data dependency on the scan body.')
        out_flat.append(const)

    (carry_vars_xs, carry), (scan_vars_xs, ys) = lax.scan(
        body, carry0, xxs, length=length, reverse=reverse)
    broadcast_vars_xs = jax.tree_unflatten(out_tree(), out_flat)
    out_vars_xs = combine(carry_vars_xs, scan_vars_xs, broadcast_vars_xs)
    return (carry, ys), out_vars_xs
def _replicate(x): """Replicate an object on each device.""" x = jnp.array(x) aval = jax.ShapedArray((len(devices), ) + x.shape, x.dtype) buffers = [jax.interpreters.xla.device_put(x, d) for d in devices] return jax.pxla.ShardedDeviceArray(aval, buffers)
def f(x):
    token = lax.create_token(x)
    y, token = lax.infeed(token, shape=jax.ShapedArray((3, 4), np.float32))
    token = lax.outfeed(token, y + onp.float32(1))
    return lax.tie_in(token, x - 1)
def doubler(_, token):
    y, token = lax.infeed(token, shape=jax.ShapedArray((3, 4), jnp.float32))
    return lax.outfeed(token, y * np.float32(2))
def f(x):
    token = lax.create_token(x)
    y, token = lax.infeed(token, shape=jax.ShapedArray((3, 4), jnp.float32))
    token = lax.outfeed(token, y + np.float32(1))
    return x - 1
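A hypothetical host-side driver for the infeed/outfeed function above. Because the compiled call may block until its infeed operand arrives, it is launched on a background thread and the data is fed afterwards; draining the outfeed back on the host is runtime- and version-specific and is left out of this sketch.

import threading
import numpy as np
import jax

x = np.float32(7.5)
y = np.random.randn(3, 4).astype(np.float32)

execution = threading.Thread(target=lambda: jax.jit(f)(x))
execution.start()
jax.local_devices()[0].transfer_to_infeed((y,))   # unblocks lax.infeed
execution.join()
# The device outfeed now holds y + 1, to be consumed by a host-side listener.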