def f(x):
    """Add two infeed-provided arrays to `x`.

    Threads an ordering token through two consecutive `lax.infeed` calls so
    the host-enqueued payloads are consumed in a deterministic order.
    """
    ordering_token = lax.create_token(x)
    # First payload: a (3, 4) float32 array.
    (first,), ordering_token = lax.infeed(
        ordering_token,
        shape=(jax.ShapedArray((3, 4), jnp.float32),))
    # Second payload: a (3, 1, 1) float32 array; the resulting token is
    # unused because nothing else needs ordering after this point.
    (second,), _ = lax.infeed(
        ordering_token,
        shape=(jax.ShapedArray((3, 1, 1), jnp.float32),))
    return x + first + second
def f_for_jit(x):
    """Sum `x` with three infeed-provided arrays of the same shape.

    All three payloads share the shape/dtype of `x` (as float32); the token
    is threaded through each call so they are dequeued in order.
    """
    token = lax.create_token(x)
    payload_spec = (jax.ShapedArray(x.shape, np.float32),)
    acc = x
    # Consume three payloads in order, accumulating left-to-right so the
    # addition order matches x + y + z + w.
    for _ in range(3):
        (value,), token = lax.infeed(token, shape=payload_spec)
        acc = acc + value
    return acc
def f(x):
    """Infeed a (3, 4) array, outfeed it incremented, and return x - 1.

    Under pre-omnistaging JAX the result is tied to the outfeed token so the
    side-effecting outfeed is not dead-code-eliminated; with omnistaging
    enabled the effect ordering is tracked automatically and no tie is needed.
    """
    token = lax.create_token(x)
    y, token = lax.infeed(token, shape=jax.ShapedArray((3, 4), jnp.float32))
    token = lax.outfeed(token, y + np.float32(1))
    if config.omnistaging_enabled:
        return x - 1
    # Legacy path: keep the outfeed alive by tying the return value to it.
    return lax.tie_in(token, x - 1)
def device_train_loop_body(args):
    """On-device loop body: infeed one batch and run a single train step."""
    optimizer, dropout_rngs, metrics, token, step, epoch = args
    # Ordering input data from infeed requires threading a symbolic token
    # through the computation.
    infeed_specs = tuple(
        jax.ShapedArray(s, jnp.int32) for s in device_train_input_shape)
    infeed_values, token = lax.infeed(token, shape=infeed_specs)
    # Rebuild the input dict from the positional infeed tuple.
    batch = dict(zip(train_keys, infeed_values))
    # Run the train_step function and carry the updated loop state forward.
    optimizer, metrics, dropout_rngs = train_lib.train_step(
        optimizer,
        batch,
        metrics,
        dropout_rngs,
        train_config,
        learning_rate_fn,
        num_microbatches=CFG.microbatches,
        label_smoothing=CFG.label_smoothing,
        z_loss=CFG.z_loss)
    step += 1
    return optimizer, dropout_rngs, metrics, token, step, epoch
def device_train_loop_body(args):
    """Device training loop body.

    Dequeues one pre-training batch from the host via infeed and runs a
    single `train_step`, returning the updated loop carry.

    Args:
      args: loop carry tuple of (optimizer, total_loss, lm_loss,
        sentence_loss, new_dropout_rng, token, step, epoch,
        num_steps_per_epoch).

    Returns:
      The same 9-tuple with updated optimizer/losses/rng/token and step + 1.
    """
    (optimizer, total_loss, lm_loss, sentence_loss, new_dropout_rng, token,
     step, epoch, num_steps_per_epoch) = args
    # Per-device slice of the global batch.
    device_batch_size = FLAGS.train_batch_size // jax.device_count()
    input_shape = [device_batch_size, FLAGS.max_seq_length]
    input_shape_pred = [device_batch_size, FLAGS.max_predictions_per_seq]
    # NOTE: the unpacking order below must match the order of the shape
    # tuple exactly — both follow the host-side enqueue order.
    (input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids,
     masked_lm_weights, next_sentence_labels), token = lax.infeed(
         token,
         shape=(jax.ShapedArray(input_shape, jnp.int32),
                jax.ShapedArray(input_shape, jnp.int32),
                jax.ShapedArray(input_shape, jnp.int32),
                jax.ShapedArray(input_shape_pred, jnp.int32),
                jax.ShapedArray(input_shape_pred, jnp.int32),
                jax.ShapedArray(input_shape_pred, jnp.float32),
                jax.ShapedArray([device_batch_size, 1], jnp.int32)))
    inputs = [input_ids, input_mask, segment_ids, masked_lm_positions]
    labels = [masked_lm_ids, masked_lm_weights, next_sentence_labels]
    optimizer, total_loss, lm_loss, sentence_loss, new_dropout_rng = train_step(
        optimizer, inputs, labels, learning_rate_fn,
        dropout_rng=new_dropout_rng)
    step += 1
    return (optimizer, total_loss, lm_loss, sentence_loss, new_dropout_rng,
            token, step, epoch, num_steps_per_epoch)
def f_for_pjit(x):
    """Sum `x` with three infeed payloads using different partitionings."""
    token = lax.create_token(x)
    # All three payloads share the same abstract value; only the partition
    # spec differs between the infeed calls.
    aval = jax.ShapedArray(x.shape, np.float32)
    # A replicated infeed.
    (replicated,), token = lax.infeed(
        token, shape=(aval,), partitions=(None,))
    # An infeed sharded on the first axis.
    (row_sharded,), token = lax.infeed(
        token, shape=(aval,), partitions=(P(nr_devices, 1),))
    # An infeed sharded on the second axis.
    (col_sharded,), token = lax.infeed(
        token, shape=(aval,), partitions=(P(1, nr_devices),))
    return x + replicated + row_sharded + col_sharded
def host_loop_eval_step(model, state, metrics):
    """Dequeue one eval batch via infeed and fold it into `metrics`."""
    # The token orders this infeed relative to the host's enqueue calls.
    infeed_token = lax.create_token(metrics['samples'])
    eval_batch, infeed_token = lax.infeed(
        infeed_token,
        shape=(jax.ShapedArray(eval_input_shape, model_dtype),
               jax.ShapedArray((device_eval_batch_size,), jnp.int32)))
    return eval_step(model, state, eval_batch, metrics, image_format,
                     space_to_depth)
def device_train_loop_body(args):
    """One on-device training iteration driven by infeed."""
    optimizer, state, metrics, token, step, epoch = args
    # Dequeue the (images, labels) pair for this step.
    (images, labels), token = lax.infeed(
        token,
        shape=(jax.ShapedArray(train_input_shape, model_dtype),
               jax.ShapedArray((device_batch_size,), jnp.int32)))
    batch = {'image': images, 'label': labels}
    optimizer, state, metrics = train_step(optimizer, state, batch, metrics,
                                           learning_rate_fn)
    return optimizer, state, metrics, token, step + 1, epoch
def host_loop_train_step(optimizer, state, metrics):
    """Run a single training step on a batch delivered through infeed."""
    # Anchor the token on the optimizer step counter for ordering.
    token = lax.create_token(optimizer.state[0].step)
    batch, token = lax.infeed(
        token,
        shape=(jax.ShapedArray(train_input_shape, model_dtype),
               jax.ShapedArray((device_batch_size,), jnp.int32)))
    optimizer, state, metrics = train_step(optimizer, state, batch, metrics,
                                           learning_rate_fn, image_format,
                                           space_to_depth)
    return optimizer, state, metrics
def device_train_loop_body(args):
    """On-device loop body: infeed one batch, train once, bump the step."""
    optimizer, state, metrics, token, step, loop = args
    batch, token = lax.infeed(
        token,
        shape=(jax.ShapedArray(train_input_shape, model_dtype),
               jax.ShapedArray((device_batch_size,), jnp.int32)))
    optimizer, state, metrics = train_step(optimizer, state, batch, metrics,
                                           learning_rate_fn, image_format,
                                           space_to_depth)
    return optimizer, state, metrics, token, step + 1, loop
def device_train_loop_body(args):
    """On-device loop body: infeed one keyed batch and run `train_step`."""
    optimizer, dropout_rngs, metrics, token, step, epoch = args
    # One identically-shaped int32 array is enqueued per train key.
    per_key_spec = jax.ShapedArray(device_train_input_shape, jnp.int32)
    infeed_values, token = lax.infeed(
        token, shape=(per_key_spec,) * len(train_keys))
    # Reassemble the batch dict in key order.
    batch = dict(zip(train_keys, infeed_values))
    optimizer, metrics, dropout_rngs = train_step(
        optimizer, batch, metrics, learning_rate_fn,
        dropout_rng=dropout_rngs)
    return optimizer, dropout_rngs, metrics, token, step + 1, epoch
def f(x):
    """Infeed two partitioned operands and combine them with `x`."""
    token = lax.create_token(x)
    # Shapes and partition specs are supplied by the enclosing scope.
    (lhs, bias), token = lax.infeed(
        token, infeed_shapes, partitions=infeed_parts)
    return x @ lhs.T + bias
def doubler(_, token):
    """Read a (3, 4) float32 array from infeed and outfeed it doubled."""
    value, token = lax.infeed(
        token, shape=jax.ShapedArray((3, 4), jnp.float32))
    doubled = value * np.float32(2)
    return lax.outfeed(token, doubled)
def f(x):
    """Infeed a (3, 4) array, outfeed it plus one, and return x - 1."""
    token = lax.create_token(x)
    payload, token = lax.infeed(
        token, shape=jax.ShapedArray((3, 4), jnp.float32))
    # Send the incremented payload back to the host; the token return value
    # is intentionally unused past this point.
    token = lax.outfeed(token, payload + np.float32(1))
    return x - 1
def f(x):
    """Return a single payload read from the device infeed queue."""
    token = lax.create_token(x)
    # `to_infeed_shape` is defined by the enclosing scope.
    payload, token = lax.infeed(token, shape=to_infeed_shape)
    return payload
def f(x):
    """Infeed a (3, 4) array, outfeed it plus one, and return x - 1.

    Uses the legacy `lax.tie_in` so the side-effecting outfeed is kept
    alive in pre-omnistaging JAX tracing.
    """
    token = lax.create_token(x)
    payload, token = lax.infeed(
        token, shape=jax.ShapedArray((3, 4), np.float32))
    token = lax.outfeed(token, payload + onp.float32(1))
    # Tie the result to the outfeed token so the outfeed is not pruned.
    return lax.tie_in(token, x - 1)