예제 #1
0
 def f(x):
   """Pull two host-fed arrays via infeed and add both to x."""
   # The token serializes the two infeed reads so they dequeue in order.
   tok = lax.create_token(x)
   spec_y = (jax.ShapedArray((3, 4), jnp.float32),)
   (y,), tok = lax.infeed(tok, shape=spec_y)
   spec_z = (jax.ShapedArray((3, 1, 1), jnp.float32),)
   (z,), _ = lax.infeed(tok, shape=spec_z)
   return x + y + z
예제 #2
0
파일: pjit_test.py 프로젝트: rsepassi/jax
    def f_for_jit(x):
      """Dequeue three infeed arrays shaped like x and sum them with x."""
      tok = lax.create_token(x)
      # Each infeed read consumes and returns the ordering token.
      spec = (jax.ShapedArray(x.shape, np.float32),)
      fed = []
      for _ in range(3):
        (value,), tok = lax.infeed(tok, shape=spec)
        fed.append(value)
      y, z, w = fed
      return x + y + z + w
예제 #3
0
 def f(x):
     """Infeed a (3, 4) array, outfeed it plus one, and return x - 1.

     Under omnistaging the outfeed is staged automatically; otherwise the
     result must be tied to the outfeed token so it is not pruned.
     """
     tok = lax.create_token(x)
     y, tok = lax.infeed(
         tok, shape=jax.ShapedArray((3, 4), jnp.float32))
     tok = lax.outfeed(tok, y + np.float32(1))
     if config.omnistaging_enabled:
         return x - 1
     return lax.tie_in(tok, x - 1)
예제 #4
0
 def device_train_loop_body(args):
     """On-device loop body.

     Threads the symbolic infeed token through the computation so the
     host-enqueued batches are consumed in order, runs one
     ``train_lib.train_step``, and returns the updated loop carry.
     """
     optimizer, dropout_rngs, metrics, token, step, epoch = args
     # One int32 spec per entry of device_train_input_shape, in order.
     feed_spec = tuple(
         jax.ShapedArray(s, jnp.int32) for s in device_train_input_shape)
     input_data, token = lax.infeed(token, shape=feed_spec)
     # Reassemble the feature dict the train step expects.
     batch = dict(zip(train_keys, input_data))
     optimizer, metrics, dropout_rngs = train_lib.train_step(
         optimizer,
         batch,
         metrics,
         dropout_rngs,
         train_config,
         learning_rate_fn,
         num_microbatches=CFG.microbatches,
         label_smoothing=CFG.label_smoothing,
         z_loss=CFG.z_loss)
     return optimizer, dropout_rngs, metrics, token, step + 1, epoch
예제 #5
0
 def device_train_loop_body(args):
   """Device training loop body.

   Infeeds one BERT-style pre-training batch (token-level inputs plus
   masked-LM and next-sentence labels) and runs a single ``train_step``,
   carrying the infeed token through the loop state to order the reads.
   """
   (optimizer, total_loss, lm_loss, sentence_loss, new_dropout_rng, token,
    step, epoch, num_steps_per_epoch) = args
   device_batch_size = FLAGS.train_batch_size // jax.device_count()
   input_shape = [device_batch_size, FLAGS.max_seq_length]
   input_shape_pred = [device_batch_size, FLAGS.max_predictions_per_seq]
   # Three token-level int32 inputs, two prediction-level int32 arrays,
   # one prediction-level float32 weight array, then the per-example
   # next-sentence label; the order must match the host enqueue order.
   feed_spec = tuple(
       [jax.ShapedArray(input_shape, jnp.int32)] * 3 +
       [jax.ShapedArray(input_shape_pred, jnp.int32)] * 2 +
       [jax.ShapedArray(input_shape_pred, jnp.float32),
        jax.ShapedArray([device_batch_size, 1], jnp.int32)])
   (input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids,
    masked_lm_weights, next_sentence_labels), token = lax.infeed(
        token, shape=feed_spec)
   inputs = [input_ids, input_mask, segment_ids, masked_lm_positions]
   labels = [masked_lm_ids, masked_lm_weights, next_sentence_labels]
   optimizer, total_loss, lm_loss, sentence_loss, new_dropout_rng = train_step(
       optimizer,
       inputs,
       labels,
       learning_rate_fn,
       dropout_rng=new_dropout_rng)
   return (optimizer, total_loss, lm_loss, sentence_loss, new_dropout_rng,
           token, step + 1, epoch, num_steps_per_epoch)
예제 #6
0
 def f_for_pjit(x):
     """Infeed three copies of x's shape under different partitionings.

     Reads one replicated array, one sharded along the first axis, and
     one sharded along the second axis, then sums all of them with x.
     """
     tok = lax.create_token(x)
     spec = (jax.ShapedArray(x.shape, np.float32), )
     # Replicated, row-sharded, and column-sharded partition specs.
     partitionings = [(None, ), (P(nr_devices, 1), ), (P(1, nr_devices), )]
     fed = []
     for parts in partitionings:
         (value, ), tok = lax.infeed(tok, shape=spec, partitions=parts)
         fed.append(value)
     y, z, w = fed
     return x + y + z + w
예제 #7
0
 def host_loop_eval_step(model, state, metrics):
     """Dequeue one eval batch from infeed and fold it into metrics."""
     tok = lax.create_token(metrics['samples'])
     # Images plus per-example int32 labels, as enqueued by the host.
     feed_spec = (jax.ShapedArray(eval_input_shape, model_dtype),
                  jax.ShapedArray((device_eval_batch_size, ), jnp.int32))
     batch, tok = lax.infeed(tok, shape=feed_spec)
     return eval_step(model, state, batch, metrics, image_format,
                      space_to_depth)
예제 #8
0
 def device_train_loop_body(args):
     """Run one on-device training step on a batch pulled from infeed."""
     optimizer, state, metrics, token, step, epoch = args
     # Image tensor plus per-example int32 labels.
     feed_spec = (jax.ShapedArray(train_input_shape, model_dtype),
                  jax.ShapedArray((device_batch_size, ), jnp.int32))
     (images, labels), token = lax.infeed(token, shape=feed_spec)
     optimizer, state, metrics = train_step(
         optimizer, state, {'image': images, 'label': labels},
         metrics, learning_rate_fn)
     return optimizer, state, metrics, token, step + 1, epoch
예제 #9
0
 def host_loop_train_step(optimizer, state, metrics):
     """Run one host-driven training step on a batch pulled from infeed."""
     tok = lax.create_token(optimizer.state[0].step)
     feed_spec = (jax.ShapedArray(train_input_shape, model_dtype),
                  jax.ShapedArray((device_batch_size, ), jnp.int32))
     batch, tok = lax.infeed(tok, shape=feed_spec)
     # train_step returns the updated (optimizer, state, metrics) triple.
     return train_step(optimizer, state, batch, metrics, learning_rate_fn,
                       image_format, space_to_depth)
예제 #10
0
 def device_train_loop_body(args):
     """One on-device training iteration; infeed supplies the batch."""
     optimizer, state, metrics, token, step, loop = args
     feed_spec = (jax.ShapedArray(train_input_shape, model_dtype),
                  jax.ShapedArray((device_batch_size, ), jnp.int32))
     batch, token = lax.infeed(token, shape=feed_spec)
     optimizer, state, metrics = train_step(optimizer, state, batch,
                                            metrics, learning_rate_fn,
                                            image_format, space_to_depth)
     return optimizer, state, metrics, token, step + 1, loop
예제 #11
0
 def device_train_loop_body(args):
     """On-device loop body: infeed a batch keyed by train_keys, step once."""
     optimizer, dropout_rngs, metrics, token, step, epoch = args
     # One identically-shaped int32 array per training feature key.
     feed_spec = tuple(
         jax.ShapedArray(device_train_input_shape, jnp.int32)
         for _ in train_keys)
     input_data, token = lax.infeed(token, shape=feed_spec)
     batch = dict(zip(train_keys, input_data))
     optimizer, metrics, dropout_rngs = train_step(
         optimizer, batch, metrics, learning_rate_fn,
         dropout_rng=dropout_rngs)
     return optimizer, dropout_rngs, metrics, token, step + 1, epoch
예제 #12
0
 def f(x):
     """Infeed (y, z) with externally supplied shapes/partitions, combine."""
     tok = lax.create_token(x)
     (y, z), tok = lax.infeed(
         tok, infeed_shapes, partitions=infeed_parts)
     return x @ y.T + z
예제 #13
0
 def doubler(_, token):
   """Read one (3, 4) float32 array from infeed and outfeed its double."""
   value, tok = lax.infeed(token,
                           shape=jax.ShapedArray((3, 4), jnp.float32))
   return lax.outfeed(tok, value * np.float32(2))
예제 #14
0
 def f(x):
   """Infeed a (3, 4) array, outfeed it incremented, and return x - 1."""
   tok = lax.create_token(x)
   value, tok = lax.infeed(tok,
                           shape=jax.ShapedArray((3, 4), jnp.float32))
   lax.outfeed(tok, value + np.float32(1))
   return x - 1
예제 #15
0
 def f(x):
   """Return the value dequeued from infeed with shape to_infeed_shape."""
   tok = lax.create_token(x)
   value, _ = lax.infeed(tok, shape=to_infeed_shape)
   return value
예제 #16
0
 def f(x):
     """Infeed (3, 4), outfeed y + 1, and tie x - 1 to the outfeed token.

     ``lax.tie_in`` keeps the outfeed from being pruned as dead code.
     """
     tok = lax.create_token(x)
     y, tok = lax.infeed(
         tok, shape=jax.ShapedArray((3, 4), np.float32))
     tok = lax.outfeed(tok, y + onp.float32(1))
     return lax.tie_in(tok, x - 1)