Example #1
    def testPyTreeArgs(self):
        if jax.device_count() < 2:
            raise SkipTest("requires 2 devices")

        def f(a, b, c):
            a1, a2 = a
            c1, (c2, c3) = c
            return a1 + a2 + b + c1 + c2 + c3

        def _make_arg(*shape):
            return np.arange(prod(shape)).reshape(shape)

        a = (_make_arg(4, 4), 1)
        b = _make_arg(4, 4)
        c = [2, (_make_arg(4, 4), _make_arg(4, 4))]

        in_parts = (None, P(2, 1), [None, P(2, 1)])
        out_parts = P(2, 1)

        result = sharded_jit(f, in_parts, out_parts)(a, b, c)
        expected = f(a, b, c)

        self.assertAllClose(result, expected, check_dtypes=False)
        self.assertIsInstance(result, pxla.ShardedDeviceArray)
        self.assertLen(result.device_buffers, 2)

        in_parts = None
        result = sharded_jit(f, in_parts, out_parts)(a, b, c)
        self.assertAllClose(result, expected, check_dtypes=False)
        self.assertIsInstance(result, pxla.ShardedDeviceArray)
        self.assertLen(result.device_buffers, 2)
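The call above passes in_parts and out_parts positionally; they mirror the pytree structure of the arguments and outputs. The following minimal sketch (not taken from the test suite) shows the same API on a single array, assuming the legacy jax.interpreters.sharded_jit module these examples target (including its PartitionSpec location) and at least two local devices.

    import numpy as np
    from jax.interpreters.sharded_jit import sharded_jit, PartitionSpec as P

    def add_one(x):
        return x + 1

    x = np.arange(16).reshape(4, 4)
    # Split the leading axis of the input and the output across two partitions.
    sharded_add_one = sharded_jit(add_one, in_parts=P(2, 1), out_parts=P(2, 1))
    result = sharded_add_one(x)  # a ShardedDeviceArray backed by two device buffers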
Example #2
    def testShardingConstraint(self):
        def f(x):
            y = x + 1
            y = with_sharding_constraint(y, P(1, 2))
            return y * 2

        shape = (8, 8)
        x = np.arange(np.prod(shape)).reshape(shape)
        expected = (x + 1) * 2

        # Matching sharded_jit partitions
        actual = sharded_jit(f, in_parts=P(2, 1), out_parts=P(2, 1))(x)
        self.assertAllClose(actual, expected, check_dtypes=False)
        self.assertLen(actual.device_buffers, 2)
        self.assertEqual(actual.device_buffers[0].shape().dimensions(), (4, 8))
        self.assertEqual(actual.device_buffers[1].shape().dimensions(), (4, 8))

        # Mismatched sharded_jit partitions
        with self.assertRaisesRegex(
                ValueError,
                r"with_sharding_constraint with partitions=PartitionSpec\(1, 2\) "
                r"\(total partitions: 2\) doesn't match expected number of partitions: "
                r"4. If these partitions look right, check outer sharded_jit and/or "
                r"other with_sharding_constraint calls."):
            sharded_jit(f, in_parts=P(2, 2), out_parts=P(2, 2))(x)

        # Replicated sharded_jit
        actual = sharded_jit(f, in_parts=None, out_parts=None)(x)
        self.assertAllClose(actual, expected, check_dtypes=False)
        self.assertLen(actual.device_buffers, 2)
        self.assertAllClose(actual.device_buffers[0].to_py(),
                            actual.device_buffers[1].to_py(),
                            check_dtypes=False)
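As the mismatch case above demonstrates, the total partition count of a with_sharding_constraint spec must equal that of the enclosing sharded_jit, even if the axes are split differently. A compact sketch of that rule, under the same legacy-API and import-path assumptions as before:

    import numpy as np
    from jax.interpreters.sharded_jit import (sharded_jit,
                                              with_sharding_constraint,
                                              PartitionSpec as P)

    def g(x):
        # 1 * 2 = 2 total partitions: legal inside any 2-partition sharded_jit.
        return with_sharding_constraint(x + 1, P(1, 2)) * 2

    x = np.arange(64).reshape(8, 8)
    ok = sharded_jit(g, in_parts=P(2, 1), out_parts=P(2, 1))(x)   # 2 == 2
    # sharded_jit(g, in_parts=P(2, 2), out_parts=P(2, 2))(x)      # 4 != 2 -> ValueError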
Example #3
    def test_sharded_jit_with_sharding_constraint(self):
        """A sharding constraint in the middle."""
        def jax_func(x, y):
            logits1 = jnp.dot(x, y)
            return jnp.sin(
                sharded_jit.with_sharding_constraint(logits1, P(2, 1)))

        sharded_jax_func = sharded_jit.sharded_jit(jax_func,
                                                   in_parts=(P(1, 2), P(2, 1)),
                                                   out_parts=P(1, 2))
        xshape = (6, 8)
        x = np.arange(np.prod(xshape), dtype=np.float32).reshape(xshape)
        yshape = (8, 10)
        y = np.arange(np.prod(yshape), dtype=np.float32).reshape(yshape)
        self._check_sharding_annotations(
            sharded_jax_func,
            [x, y],
            expected=[
                r"f32\[6,8\].*sharding={devices=\[1,2\]",
                r"f32\[8,10\].*sharding={devices=\[2,1\]",
                r"f32\[6,10\].*sharding={devices=\[2,1\]",
                r"f32\[6,10\].*sine.*sharding={devices=\[1,2\]"
            ],
            expected_opt=[
                # TODO(necula): relax ordering
                r"f32\[4,10\].*sharding={devices=\[2,1\]",
                r"f32\[6,4\].*sharding={devices=\[1,2\]",
            ],
            num_partitions=2)
Example #4
    def testPyTreeArgs(self):
        if jax.local_device_count() < 4:
            raise SkipTest("requires 4 devices")

        def f(a, b, c):
            a1, a2 = a
            c1, (c2, c3) = c
            return a1 + a2 + b + c1 + c2 + c3

        def _make_arg(*shape):
            return np.arange(prod(shape)).reshape(shape)

        a = (_make_arg(2, 4, 4), _make_arg(2))
        b = _make_arg(2, 4, 4)
        c = (_make_arg(2), (_make_arg(2, 4, 4), _make_arg(2, 4, 4)))

        in_parts = (None, P(2, 1), (None, P(2, 1)))
        out_parts = P(2, 1)

        result = pmap(sharded_jit(f, in_parts=in_parts,
                                  out_parts=out_parts))(a, b, c)
        expected = pmap(f)(a, b, c)

        self.assertAllClose(result, expected, check_dtypes=False)
        self.assertTrue(isinstance(result, pxla.ShardedDeviceArray))
        self.assertEqual(len(result.device_buffers), 4)
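When sharded_jit is nested under pmap as above, the device budget is the product of the two: each pmap replica (one per element of the leading axis) runs its own partitioned computation. A minimal sketch of that composition, assuming the legacy API and at least four local devices:

    import numpy as np
    from jax import pmap
    from jax.interpreters.sharded_jit import sharded_jit, PartitionSpec as P

    def double(x):
        return x * 2

    # Leading axis of size 2 -> two pmap replicas; P(2, 1) -> two partitions each.
    x = np.arange(2 * 4 * 4).reshape(2, 4, 4)
    f = pmap(sharded_jit(double, in_parts=P(2, 1), out_parts=P(2, 1)))
    result = f(x)  # 2 replicas * 2 partitions = 4 device buffers in total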
Example #5
  def test_sharded_jit_in_out(self):
    """Test input and output sharding annotations."""
    sharded_jax_func = sharded_jit.sharded_jit(jnp.dot,
                                               in_parts=(P(1, 2), P(2, 1)),
                                               out_parts=P(1, 2))
    xshape = (3, 8)
    x = np.arange(np.prod(xshape), dtype=np.float32).reshape(xshape)
    yshape = (8, 5)
    y = np.arange(np.prod(yshape), dtype=np.float32).reshape(yshape)
    self._check_sharding_annotations(
        sharded_jax_func,
        [x, y],
        expected=[
            r"f32\[3,8\].*sharding={devices=\[1,2\]",
            r"f32\[8,5\].*sharding={devices=\[2,1\]",
            r"f32\[3,5\].*sharding={devices=\[1,2\]"
        ],
        expected_opt=[
            # TODO(necula): relax ordering
            r"f32\[4,5\].*sharding={devices=\[2,1\]",
            r"f32\[3,4\].*sharding={devices=\[1,2\]",
            r"f32\[3,5\].*fusion",
            r"f32\[3,5\].*all-reduce",
        ],
        num_partitions=2)
Example #6
    def testCompilationCache(self):
        f = lambda x: x + 1
        sharded_f = sharded_jit(f, in_parts=P(2), out_parts=P(2))
        shape = (2, )
        x = np.arange(prod(shape), dtype=np.float32).reshape(shape)

        with jtu.assert_num_jit_and_pmap_compilations(1):
            sharded_f(x)
            sharded_f(x)
Example #7
    def testCompilationCache(self):
        f = lambda x: x + 1
        sharded_f = sharded_jit(f, in_parts=P(2), out_parts=P(2))
        shape = (2, )
        x = np.arange(prod(shape), dtype=np.float32).reshape(shape)

        with jtu.count_jit_and_pmap_compiles() as count:
            sharded_f(x)
            sharded_f(x)
        self.assertEqual(count[0], 1)
Example #8
    def testPyTreeOutputs(self):
        if jax.device_count() < 2:
            raise SkipTest("requires 2 devices")

        def f(x):
            return x + 1, ((x + 2, x + 3), x + 4)

        shape = (4, 4)
        x = np.arange(prod(shape)).reshape(shape)
        in_parts = (P(2, 1), )
        out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))

        result = sharded_jit(f, in_parts, out_parts)(x)
        expected = f(x)
        self.assertAllClose(result, expected, check_dtypes=False)

        out_parts = None
        result = sharded_jit(f, in_parts, out_parts)(x)
        self.assertAllClose(result, expected, check_dtypes=False)
Example #9
    def testShardingConstraint(self):
        if jax.local_device_count() < 2:
            raise SkipTest("requires 2 devices")

        def f(x):
            y = x + 1
            y = with_sharding_constraint(y, P(1, 2))
            return y * 2

        shape = (8, 8)
        x = np.arange(prod(shape)).reshape(shape)
        expected = (x + 1) * 2

        # Matching sharded_jit partitions
        actual = sharded_jit(f, in_parts=P(2, 1), out_parts=P(2, 1))(x)
        self.assertAllClose(actual, expected, check_dtypes=False)
        self.assertLen(actual.device_buffers, 2)
        # TODO(jblespiau): We can simply use buf.xla_shape() when version 0.1.58 is
        # the default.
        self.assertEqual(
            getattr(actual.device_buffers[0], "xla_shape",
                    actual.device_buffers[0].shape)().dimensions(), (4, 8))
        self.assertEqual(
            getattr(actual.device_buffers[1], "xla_shape",
                    actual.device_buffers[1].shape)().dimensions(), (4, 8))

        # Mismatched sharded_jit partitions
        with self.assertRaisesRegex(
                ValueError,
                r"with_sharding_constraint with partitions=PartitionSpec\(1, 2\) "
                r"\(total partitions: 2\) doesn't match expected number of partitions: "
                r"4. If these partitions look right, check outer sharded_jit and/or "
                r"other with_sharding_constraint calls."):
            sharded_jit(f, in_parts=P(2, 2), out_parts=P(2, 2))(x)

        # Replicated sharded_jit
        actual = sharded_jit(f, in_parts=None, out_parts=None)(x)
        self.assertAllClose(actual, expected, check_dtypes=False)
        self.assertLen(actual.device_buffers, 2)
        self.assertAllClose(actual.device_buffers[0].to_py(),
                            actual.device_buffers[1].to_py(),
                            check_dtypes=False)
Example #10
  def testPyTreeOutputs(self):
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    def f(x):
      return x + 1, ((x + 2, x + 3), x + 4)

    shape = (2, 4, 4)
    x = np.arange(np.prod(shape)).reshape(shape)
    in_parts = (P(2, 1),)
    out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))

    result = pmap(sharded_jit(f, in_parts=in_parts, out_parts=out_parts))(x)
    expected = pmap(f)(x)

    self.assertAllClose(result, expected, check_dtypes=False)
Example #11
    def testNestedShardingConstraint(self):
        if jax.local_device_count() < 2:
            raise SkipTest("requires 2 devices")

        shape = (8, 8)

        @jit
        def f(x):
            return lax.while_loop(
                lambda i: i[0, 0] < 10.,
                lambda i: with_sharding_constraint(i + 1., P(2, 1)), x)

        x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
        expected = x + 10.
        actual = sharded_jit(f, in_parts=None, out_parts=None)(x)
        self.assertAllClose(actual, expected, check_dtypes=False)
        self.assertLen(actual.device_buffers, 2)
Example #12
  def _runTest(self, f, in_partitions, out_partitions, dtype=np.float32):
    """Compares pmap(sharded_jit(f, ...)) to pmap(f)"""
    shape = (2, 4, 4)
    num_shards = shape[0] * np.prod(in_partitions[0])
    if num_shards > jax.local_device_count():
      raise SkipTest("requires %d devices" % num_shards)

    x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
    y = x + 1
    result = pmap(
        sharded_jit(f, in_parts=in_partitions, out_parts=out_partitions))(x, y)
    expected = pmap(f)(x, y)
    self.assertAllClose(result, expected, check_dtypes=False)

    flat_result = tree_util.tree_flatten(result)[0]
    for r in flat_result:
      self.assertTrue(isinstance(r, pxla.ShardedDeviceArray))
      self.assertEqual(len(r.device_buffers), num_shards)
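The _runTest helper above is parameterized over the function and the partition specs, so individual tests reduce to one call. A hypothetical caller (test name and lambda invented for illustration; assumes the enclosing test class and P are in scope) might look like:

  def testElementwiseAdd(self):
    # _runTest generates x and y with shape (2, 4, 4); P(2, 1) splits each
    # per-replica (4, 4) operand across two partitions, using 4 devices total.
    self._runTest(lambda x, y: x + y,
                  in_partitions=(P(2, 1), P(2, 1)),
                  out_partitions=P(2, 1))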
Example #13
  def testManyArgs(self):
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    num_args = 200

    def f(*args):
      return jnp.sum(args)

    shape = (2, 4, 4)
    args = [np.arange(np.prod(shape)).reshape(shape)] * num_args
    in_partitions = (P(2, 1),) * num_args
    out_partitions = None
    result = pmap(sharded_jit(
        f, in_parts=in_partitions, out_parts=out_partitions))(*args)
    expected = pmap(f)(*args)

    self.assertAllClose(result, expected, check_dtypes=False)
    self.assertTrue(isinstance(result, pxla.ShardedDeviceArray))
    self.assertEqual(len(result.device_buffers), 4)
Example #14
  def testInAxesNone(self):
    shape = (4, 4)
    replicas = 2
    in_partitions = (P(2, 1), None, None)
    out_partitions = P(2, 1)
    in_axes = (None, None, 0)
    x = y = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
    dummy = np.arange(replicas, dtype=np.float32) + 1
    num_shards = replicas * np.prod(in_partitions[0])
    if num_shards > jax.local_device_count():
      raise SkipTest("requires %d devices" % num_shards)

    def f(x, y, _):
      return x @ y

    result = pmap(
        sharded_jit(f, in_parts=in_partitions, out_parts=out_partitions),
        in_axes=in_axes)(x, y, dummy)
    expected = pmap(f, in_axes=in_axes)(x, y, dummy)
    self.assertAllClose(result, expected, check_dtypes=True)
Example #15
  def test_sharded_jit_replicated(self):
    """A replicated input and output."""

    sharded_jax_func = sharded_jit.sharded_jit(
        jnp.dot, in_parts=(P(1, 2), None), out_parts=None)
    xshape = (3, 8)
    x = np.arange(np.prod(xshape), dtype=np.float32).reshape(xshape)
    yshape = (8, 5)
    y = np.arange(np.prod(yshape), dtype=np.float32).reshape(yshape)
    self._check_sharding_annotations(
        sharded_jax_func, [x, y],
        expected=[
            r"f32\[3,8\].*sharding={devices=\[1,2\]",
            r"f32\[8,5\].*sharding={replicated}",
            r"f32\[3,5\].*sharding={replicated}"
        ],
        expected_opt=[
            # TODO(necula): relax ordering
            r"f32\[8,5\].*sharding={replicated}",
            r"f32\[3,4\].*sharding={devices=\[1,2\]",
        ],
        num_partitions=2)
Example #16
def main(argv):
    global BLEU_THRESHOLD_REACHED
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    init_mllogger()
    mllogger.event('cache_clear')
    mllogger.start('init_start')
    mllogger.event('submission_org', 'Google')
    mllogger.event('submission_platform',
                   'TPUv3-{}'.format(jax.device_count()))
    mllogger.event('submission_division', 'closed')
    mllogger.event('submission_status', 'research')
    mllogger.event('submission_benchmark', 'transformer')
    mllogger.event('train_samples', input_pipeline.N_TRAIN)
    mllogger.event('eval_samples', input_pipeline.N_EVAL)

    tf.enable_v2_behavior()

    # Use hardware RNG for bernoulli randoms in dropout mask creation.
    if FLAGS.hardware_rng:
        models.set_hardware_bernoulli()

    num_partitions = FLAGS.num_partitions
    batch_size = FLAGS.batch_size
    if batch_size is None:
        batch_size = min(16 * jax.device_count() // num_partitions, 2048)
    mllogger.event('global_batch_size', batch_size)

    num_eval_steps = FLAGS.num_eval_steps
    max_target_length = FLAGS.max_target_length
    max_eval_target_length = FLAGS.max_eval_target_length
    max_length = max(max_target_length, max_eval_target_length)
    mllogger.event('max_sequence_length',
                   max_length,
                   metadata={'method': 'discard'})
    if FLAGS.random_seed is not None:
        seed = FLAGS.random_seed
    else:
        seed = np.uint32(time.time() if jax.host_id() == 0 else 0)
        seed = per_host_sum_pmap(seed)
    mllogger.event('seed', int(seed))
    steps_per_epoch = int(math.ceil(input_pipeline.N_TRAIN / batch_size))
    logging.info('steps per epoch: %d', steps_per_epoch)
    num_replicas = jax.local_device_count() // num_partitions
    device_train_input_shape = (batch_size //
                                (num_replicas * jax.host_count()),
                                max_target_length)
    # This is per-host; in principle 64/replica or more should fit
    eval_batch_size = min(
        32 * num_replicas,
        int(
            math.ceil(input_pipeline.N_EVAL /
                      (num_replicas * jax.host_count()))) * num_replicas)
    logging.info('eval batch size: %d', eval_batch_size)
    pred_batches = int(
        math.ceil(input_pipeline.N_EVAL /
                  (jax.host_count() * eval_batch_size)))
    logging.info('pred batches: %d', pred_batches)
    broadcast = functools.partial(_broadcast,
                                  num_replicas=num_replicas,
                                  num_partitions=num_partitions)

    if jax.host_id() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'train'))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'eval'))
    else:
        train_summary_writer = None
        eval_summary_writer = None
    # Write summaries in background thread to avoid blocking on device sync
    summary_thread = thread.ThreadPoolExecutor(1, 'summary')
    if FLAGS.infeed:
        # Infeed is currently synchronous, so do it in a background thread too
        infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(),
                                                'infeed')

    # MLPerf 2020 WMT en-de dataset uses a custom T2T dataset:
    #   Shared 32K subword tokenization
    #   256-length packed training examples from WMT17
    #   97-length unpacked evaluation examples from WMT14
    train_keys = [
        'inputs', 'targets', 'inputs_position', 'targets_position',
        'inputs_segmentation', 'targets_segmentation'
    ]
    encoder = mlperf_encoder.SubwordTextEncoder(filename=FLAGS.vocab_path)
    input_encoder = encoder
    target_encoder = encoder
    vocab_size = input_encoder.vocab_size
    output_vocab_size = target_encoder.vocab_size

    input_shape = (batch_size, max_target_length)
    target_shape = (batch_size, max_target_length)

    transformer_kwargs = flax.core.FrozenDict({
        'vocab_size': vocab_size,
        'output_vocab_size': output_vocab_size,
        'emb_dim': 1024,
        'num_heads': 16,
        'num_layers': 6,
        'qkv_dim': 1024,
        'mlp_dim': 4096,
        'max_len': max_length,
        'share_embeddings': FLAGS.share_embeddings,
        'logits_via_embedding': FLAGS.logits_via_embedding,
        'num_partitions': num_partitions,
    })

    rng = random.PRNGKey(seed)
    rng, init_rng = random.split(rng)
    model, cache_def = create_model(init_rng, tuple(input_shape),
                                    tuple(target_shape), transformer_kwargs)
    mllogger.event('opt_name', 'adam')
    if batch_size < 1024:
        learning_rate = 4.0  # 0.0625
        warmup_steps = 1000
        beta1 = 0.9
        beta2 = 0.98
    elif batch_size < 2048:
        learning_rate = 2.0
        warmup_steps = 500  # ??
        beta1 = 0.9  # ??
        beta2 = 0.98  # ??
    else:
        learning_rate = 3.3092157691415953
        warmup_steps = 664
        beta1 = 0.9086575725261137
        beta2 = 0.9198719118104947
    epsilon = 1e-9
    if FLAGS.learning_rate is not None:
        learning_rate = FLAGS.learning_rate
    mllogger.event('opt_adam_beta_1', beta1)
    mllogger.event('opt_adam_beta_2', beta2)
    mllogger.event('opt_adam_epsilon', epsilon)
    optimizer_def = optim.Adam(learning_rate,
                               beta1=beta1,
                               beta2=beta2,
                               eps=epsilon,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(model)
    del model  # don't keep a copy of the initial model

    # Build parameter partition annotations for preserving partitions from train
    # to eval.
    partition_rules = [
        (('encoder', 'posembed_input'), partitions.empty_dict),
        (('decoder', 'posembed_targets'), partitions.empty_dict),
        (('embedding', ), partitions.spec(num_partitions, 1)),
        ((r'LayerNorm_\d+', '(bias|scale)'), None),
        ((r'encoder(decoder)?_norm', '(bias|scale)'), None),
        ((r'MultiHeadDotProductAttention_\d+', '(query|key|value)', 'kernel'),
         partitions.spec(1, num_partitions, 1)),
        ((r'MultiHeadDotProductAttention_\d+', 'out', 'kernel'),
         partitions.spec(num_partitions, 1, 1)),
        ((r'MlpBlock_\d+', r'Dense_\d+', 'bias'), None),
        ((r'MlpBlock_\d+', 'Dense_0', 'kernel'),
         partitions.spec(1, num_partitions)),
        ((r'MlpBlock_\d+', 'Dense_1', 'kernel'),
         partitions.spec(num_partitions, 1)),
        (('state', 'step'), None),
    ]
    optimizer_partitions = optimizer.restore_state(
        partitions.set_partitions(partition_rules, optimizer.state_dict()))

    optimizer = broadcast(optimizer)
    empty_metrics = broadcast({'loss': 0.0, 'accuracy': 0, 'denominator': 0})

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=learning_rate,
        warmup_steps=warmup_steps,
        hidden_size=transformer_kwargs['qkv_dim'])

    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch',
                            in_axes=(None, 0, 0, 0))
    if num_partitions > 1:
        sharded_predict_step = sharded_jit(
            predict_step,
            in_parts=(None, optimizer_partitions.target, None),
            out_parts=None)
    else:
        sharded_predict_step = predict_step
    if FLAGS.extra_eval_metrics:
        p_eval_step = jax.pmap(eval_step, axis_name='batch', in_axes=(None, 0))
    p_pred_step = jax.pmap(sharded_predict_step,
                           axis_name='batch',
                           in_axes=(0, None, None))
    p_allreduce_metrics = jax.pmap(functools.partial(lax.psum,
                                                     axis_name='batch'),
                                   axis_name='batch')

    def device_train_loop_cond(args):
        _, _, _, _, step, epoch = args
        return step // steps_per_epoch == epoch

    def device_train_loop_body(args):
        optimizer, dropout_rngs, metrics, token, step, epoch = args
        input_data, token = lax.infeed(token,
                                       shape=tuple([
                                           jax.ShapedArray(
                                               device_train_input_shape,
                                               jnp.int32) for _ in train_keys
                                       ]))
        batch = {k: v for k, v in zip(train_keys, input_data)}
        optimizer, metrics, dropout_rngs = train_step(optimizer,
                                                      batch,
                                                      metrics,
                                                      learning_rate_fn,
                                                      dropout_rng=dropout_rngs)
        step += 1
        return optimizer, dropout_rngs, metrics, token, step, epoch

    def device_train_loop(optimizer, dropout_rngs, metrics, step, epoch):
        token = lax.create_token(step)
        optimizer, dropout_rngs, metrics, _, step, _ = lax.while_loop(
            device_train_loop_cond, device_train_loop_body,
            (optimizer, dropout_rngs, metrics, token, step, epoch))
        return optimizer, dropout_rngs, metrics, step

    if num_partitions > 1:
        device_train_loop = sharded_jit(device_train_loop,
                                        in_parts=(optimizer_partitions, None,
                                                  None, None, None),
                                        out_parts=(optimizer_partitions, None,
                                                   None, None))
    p_train_epoch = jax.pmap(device_train_loop,
                             axis_name='batch',
                             in_axes=(None, 0, 0, None, None))

    p_allreduce_metrics_train = functools.partial(lax.psum, axis_name='batch')
    if num_partitions > 1:
        p_allreduce_metrics_train = sharded_jit(p_allreduce_metrics_train,
                                                in_parts=None,
                                                out_parts=None,
                                                num_partitions=num_partitions)
    p_allreduce_metrics_train = jax.pmap(p_allreduce_metrics_train,
                                         axis_name='batch')

    # Precompile all needed computations with fake data so as not to include
    # compilation time in MLPerf metrics.
    if FLAGS.precompile:
        logging.info('precompiling step/epoch functions')
        if FLAGS.infeed:
            # the device training loop condition will immediately be false, but
            # the optimizer tree will be resharded here
            optimizer, *_ = p_train_epoch(unbroadcast(optimizer),
                                          random.split(rng, num_replicas),
                                          empty_metrics,
                                          jnp.array(0, dtype=jnp.int32), 1)
        else:
            metrics = empty_metrics
            train_input_shape = (num_replicas, batch_size // num_replicas,
                                 input_pipeline.MAX_TRAIN_LEN)
            fake_batch = {
                k: jnp.ones(train_input_shape, jnp.int32)
                for k in train_keys
            }
            p_train_step(unbroadcast(optimizer),
                         fake_batch,
                         metrics,
                         dropout_rng=random.split(rng, num_replicas))
        eval_input_shape = (num_replicas, eval_batch_size // num_replicas,
                            input_pipeline.MAX_EVAL_LEN)
        fake_eval_batch = {
            'inputs': jnp.ones(eval_input_shape, jnp.int32),
            'targets': jnp.ones(eval_input_shape, jnp.int32),
        }
        if FLAGS.extra_eval_metrics:
            p_eval_step(unbroadcast(optimizer.target), fake_eval_batch)
        fake_cache = cache_def.initialize_cache(
            (eval_input_shape[1], FLAGS.max_predict_length))
        p_pred_step(fake_eval_batch['inputs'], unbroadcast(optimizer.target),
                    fake_cache)
        time.sleep(20)
        sync_devices()
        fake_bleu_1 = np.zeros((4, ), dtype=np.int32)
        fake_bleu_2 = np.zeros((), dtype=np.int32)
        per_host_sum_pmap((fake_bleu_1, fake_bleu_1, fake_bleu_2, fake_bleu_2))
        sync_devices()
        p_allreduce_metrics_train(empty_metrics)
        sync_devices()
        logging.info('finished precompiling step/epoch functions')

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, num_replicas)

    # Record time-0 metrics for proper tensorboard plot x-axis scaling.
    if jax.host_id() == 0:
        if FLAGS.compute_train_metrics:
            train_summary_writer.scalar('loss', 9.999, 0)
            train_summary_writer.scalar('accuracy', 0.0, 0)
            train_summary_writer.flush()
        eval_summary_writer.scalar('bleu', 0.0, 0)
        eval_summary_writer.flush()

    train_ds = input_pipeline.get_wmt_dataset(batch_size=batch_size //
                                              jax.host_count(),
                                              train=True)
    eval_ds = input_pipeline.get_wmt_dataset(batch_size=eval_batch_size,
                                             train=False)
    train_iter = iter(train_ds)
    eval_iter = iter(eval_ds)
    local_devices = jax.local_devices()
    host_step, device_step = 0, broadcast(0)
    gc.disable()
    mllogger.end('init_stop')
    if jax.host_id() == 0:
        mllogger.start('run_start')
    for epoch in range(FLAGS.num_epochs):
        if jax.host_id() == 0 and not BLEU_THRESHOLD_REACHED:
            mllogger.start('block_start',
                           metadata={
                               'first_epoch_num': epoch + 1,
                               'epoch_count': 1
                           })
        metrics = empty_metrics
        if FLAGS.infeed:
            optimizer, dropout_rngs, metrics, device_step = p_train_epoch(
                unbroadcast(optimizer), dropout_rngs, metrics,
                unbroadcast(device_step), epoch)
        while int(host_step // steps_per_epoch) == epoch:
            # pylint: disable=protected-access
            batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))
            # Shard data to devices and do a training step.
            batch = jax.tree_map(
                lambda x: x.reshape((num_replicas, -1) + x.shape[1:]), batch)
            if FLAGS.infeed:
                for i, device in enumerate(local_devices):
                    replica_id = i // num_partitions
                    input_tuple = tuple(
                        [batch[k][replica_id] for k in train_keys])
                    assert input_tuple[0].shape == device_train_input_shape, (
                        'infeed shape error %s != %s' %
                        (input_tuple[0].shape, device_train_input_shape))
                    assert input_tuple[0].dtype == jnp.int32, (
                        'infeed dtype error %s != %s' %
                        (input_tuple[0].dtype, jnp.int32))
                    infeed_pool.submit(
                        functools.partial(device.transfer_to_infeed,
                                          input_tuple))
            else:
                optimizer, metrics, dropout_rngs = p_train_step(
                    unbroadcast(optimizer),
                    batch,
                    metrics,
                    dropout_rng=dropout_rngs)
            host_step += 1

        if FLAGS.compute_train_metrics:
            metrics = p_allreduce_metrics_train(metrics)
            # Schedule training metric handling.
            summary_thread.submit(
                functools.partial(write_train_summary, metrics,
                                  train_summary_writer, host_step))

        # Optional, extra evaluation metrics.
        if FLAGS.extra_eval_metrics:
            eval_metrics = []
            eval_iter = iter(eval_ds)
            for _, eval_batch in zip(range(num_eval_steps), eval_iter):
                eval_batch = common_utils.shard(eval_batch)
                metrics = p_eval_step(unbroadcast(optimizer.target),
                                      eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = p_allreduce_metrics(eval_metrics)
            # Schedule metric summarization/logging.
            summary_thread.submit(
                functools.partial(write_eval_summary, eval_metrics,
                                  eval_summary_writer, host_step))

        # Translation and BLEU Score.
        all_predicted, all_targets, all_bs = [], [], []
        for i in range(pred_batches):
            # pylint: disable=protected-access
            pred_batch = jax.tree_map(lambda x: x._numpy(), next(eval_iter))
            # Handle final odd-sized batch by padding instead of dropping it.
            cur_pred_batch_size = pred_batch['inputs'].shape[0]
            if cur_pred_batch_size != eval_batch_size:
                pred_batch = jax.tree_map(
                    lambda x: pad_examples(x, eval_batch_size), pred_batch)
            pred_batch = jax.tree_map(
                lambda x: x.reshape((num_replicas, -1) + x.shape[1:]),
                pred_batch)
            per_device_batchsize = pred_batch['inputs'].shape[1]
            cache = cache_def.initialize_cache(
                (per_device_batchsize, FLAGS.max_predict_length))
            all_predicted.append(
                p_pred_step(pred_batch['inputs'],
                            unbroadcast(optimizer.target), cache))
            all_targets.append(pred_batch['targets'])
            all_bs.append(cur_pred_batch_size)
        # Schedule BLEU calculation and summarization/logging.
        # BLEU computation uses the ICI (via a pmap), so we call this from the
        # main thread to ensure the BLEU pmap runs before the next train epoch's
        # pmap.
        write_predict_summary(all_predicted, all_targets, all_bs,
                              target_encoder, eval_summary_writer, epoch,
                              host_step, summary_thread)

    # Wait until computations are done before exiting
    sync_devices()
    if jax.host_id() == 0:
        summary_thread.shutdown()
        if not BLEU_THRESHOLD_REACHED:
            mllogger.end('run_stop', metadata={'status': 'aborted'})
Example #17
def main(argv):
    global CFG
    CFG = FLAGS.config

    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Guarantee that the JAX bfloat16 extension is used rather than TF bfloat16.
    _ = np.array(jnp.array([1.0], dtype=jnp.bfloat16))

    # Use hardware RNG for bernoulli randoms in dropout mask creation.
    if CFG.hardware_rng:
        models.set_hardware_bernoulli()

    if 'module_import' in CFG and CFG.module_import:
        for module in CFG.module_import:
            importlib.import_module(module)

    if 'additional_task_cache_dirs' in CFG and CFG.additional_task_cache_dirs:
        t5.data.add_global_cache_dirs(CFG.additional_task_cache_dirs)

    num_partitions = CFG.num_partitions
    topology = train_lib.compute_multihost_topology(num_partitions)
    batch_size = CFG.batch_size
    eval_batch_size = CFG.eval_batch_size
    per_replica_set_eval_batch_size = eval_batch_size // topology.num_replica_sets
    if batch_size % topology.num_replicas:
        raise ValueError(
            'Batch size must be divisible by the number of replicas.')

    steps_per_epoch = CFG.steps_per_epoch
    logging.info('steps per epoch: %d', steps_per_epoch)

    broadcast = functools.partial(
        train_lib.broadcast,
        num_replicas=topology.per_replica_set_num_replicas,
        num_partitions=topology.per_host_num_partitions,
        devices=topology.this_host_device_assignment)

    if jax.host_id() == 0:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        tf.io.gfile.copy(FLAGS['config'].config_filename,
                         os.path.join(FLAGS.model_dir, 'config.py'),
                         overwrite=True)
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'train'))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'eval'))
    else:
        train_summary_writer = None
        eval_summary_writer = None

    # Write summaries in background thread to avoid blocking on device sync
    if CFG.infeed:
        # Infeed is currently synchronous, so do it in a background thread too
        infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(),
                                                'infeed')

    (train_ds, eval_ds), eval_cache = input_pipeline.get_datasets_and_cache(
        CFG, topology.num_replica_sets, topology.replica_set_id,
        topology.per_replica_set_host_id)

    vocab = input_pipeline.get_vocabulary(CFG.mixture_or_task_name)
    encoder = vocab.tf_tokenizer
    eos_id = vocab.tokenizer.eos_id()

    def decode_tokens(toks, eos_id=eos_id, max_id=32000):
        """Decode tokens back to unicode."""
        del eos_id
        # TODO(levskaya): T5 doesn't seem to emit EOS tokens?  double check this
        # is the best decoding function or just switch to using tf_decode.
        # valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        valid_toks = toks.astype(np.int32)
        valid_toks[valid_toks >= max_id] = 3
        return encoder.detokenize(valid_toks).numpy().decode('utf-8')

    logging.info('Initializing model, optimizer, and step functions.')

    train_config, eval_config, predict_config = get_configs(CFG)

    rng = random.PRNGKey(CFG.random_seed)
    rng, init_rng = random.split(rng)
    # This is used for infeed conversion from feature dict <--> tuple
    train_keys = [
        'inputs', 'targets', 'inputs_position', 'targets_position',
        'inputs_segmentation', 'targets_segmentation'
    ]
    device_train_input_shape = tuple([
        (batch_size // topology.num_replicas,
         CFG.max_input_length if 'inputs' in k else CFG.max_target_length)
        for k in train_keys
    ])

    learning_rate_fn = train_lib.create_learning_rate_scheduler(
        factors=CFG.schedule,
        base_learning_rate=CFG.learning_rate,
        warmup_steps=CFG.warmup_steps)

    # First, we only abstractly initialize the optimizer and model parameters,
    # since the parameters may not even fit in device memory!
    # TODO(jekbradbury): make optimizer_defs compare by value so it can be created
    # in get_initial_params without causing pytree incompatibility
    optimizer_def = optim.Adafactor(CFG.learning_rate,
                                    decay_rate=0.8,
                                    step_offset=CFG.step_offset)
    initialize_params_fn = functools.partial(get_initial_params,
                                             config=CFG,
                                             transformer_config=eval_config,
                                             optimizer_def=optimizer_def)
    optimizer = jax.eval_shape(initialize_params_fn, init_rng)
    # tuple-like pytree leaves for global_arg_shapes
    optimizer_shapes = jax.tree_map(lambda x: partitions.Spec(*x.shape),
                                    optimizer)

    # Build parameter partition annotations for preserving partitions from train
    # to eval.
    if num_partitions > 1:
        optimizer_partitions = optimizer.restore_state(
            partitions.set_partitions(num_partitions, optimizer.state_dict()))
        per_host_optimizer_partitions = optimizer.restore_state(
            partitions.set_partitions(topology.per_host_num_partitions,
                                      optimizer.state_dict()))

    # Restore unreplicated optimizer + model state from last checkpoint.
    # TODO(jekbradbury,levskaya): implement sharded native checkpoint/restore
    existing_checkpoint_found = False
    if CFG.restore_checkpoints:
        existing_checkpoint_found = train_lib.checkpoint_exists(
            FLAGS.model_dir)
        optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)

    # Import a pretrained-T5 checkpoint only if we didn't import a local
    # "native" checkpoint (e.g. due to resuming a pre-empted finetuning run.)
    # TODO(jekbradbury,levskaya): implement sharded T5 checkpoint/restore
    if CFG.restore_t5_checkpoint and not existing_checkpoint_found:
        optimizer = checkpoint_importer.restore_from_t5_checkpoint(
            optimizer, CFG.restore_t5_checkpoint)

    if CFG.restore_t5_checkpoint or existing_checkpoint_found:
        if num_partitions > 1:
            # Until checkpoint/restore is sharded, the restored checkpoint is global
            # and we need to slice each sharded parameter into the chunk containing
            # only the partitions that are present on this host.
            def per_host_chunk(x, spec):
                if spec is None or spec is x:  # unsharded or not a parameter
                    return x
                if spec[0] == 1:
                    dim_size = x.shape[1]
                elif spec[1] == 1:
                    dim_size = x.shape[0]
                else:
                    raise NotImplementedError()
                chunk_size = (dim_size * topology.per_host_num_partitions //
                              num_partitions)
                lower = topology.per_replica_set_host_id * chunk_size
                upper = (topology.per_replica_set_host_id + 1) * chunk_size
                if spec[0] == 1:
                    return x[:, lower:upper]
                else:
                    return x[lower:upper]

            optimizer = jax.tree_multimap(per_host_chunk, optimizer,
                                          optimizer_partitions)
    else:
        # If pretraining and no checkpoint imported, we jit the (sharded-) init
        # function to minimize fragmentation. We use the same pmap(sharded_jit)
        # setup as the training step/loop to initialize everything "in-place" and
        # avoid communication or OOM.
        if num_partitions > 1:
            initialize_params_fn = sharded_jit(
                initialize_params_fn,
                in_parts=None,
                local_in_parts=None,
                out_parts=optimizer_partitions,
                local_out_parts=per_host_optimizer_partitions,
                # devices=one_replica_device_assignment,
            )
            initialize_params_fn = jax.pmap(initialize_params_fn,
                                            'batch',
                                            in_axes=0,
                                            axis_size=topology.num_replicas,
                                            devices=topology.device_assignment)
            init_rng = broadcast(init_rng)
            optimizer = initialize_params_fn(init_rng)
            # We maintain the optimizer in unbroadcasted form (i.e. with no leading
            # replica axis). This is equivalent to the as-yet-nonexistent pmap kwarg
            # out_axes=None.
            optimizer = train_lib.unbroadcast(optimizer)
        else:
            optimizer = jax.jit(initialize_params_fn)(init_rng)

    # ---------------------------------------------------------------------------
    # Compile multidevice versions of train/eval/predict step and cache init fn.
    # ---------------------------------------------------------------------------

    # We can use either a single train-step for a host training loop:

    # train_step(optimizer, batch, prev_metrics, dropout_rng, **kwargs)
    #  --> new_optimizer, metrics, new_dropout_rng
    def p_train_step(optimizer, batch, prev_metrics, dropout_rng):
        return train_lib.train_step(optimizer,
                                    batch,
                                    prev_metrics,
                                    dropout_rng,
                                    config=train_config,
                                    learning_rate_fn=learning_rate_fn,
                                    num_microbatches=CFG.microbatches,
                                    label_smoothing=CFG.label_smoothing,
                                    z_loss=CFG.z_loss,
                                    use_bfloat16=CFG.use_bfloat16)

    if num_partitions > 1:
        p_train_step = sharded_jit(
            p_train_step,
            in_parts=(optimizer_partitions, None, None, None),
            local_in_parts=(per_host_optimizer_partitions, None, None, None),
            out_parts=(optimizer_partitions, None, None),
            local_out_parts=(per_host_optimizer_partitions, None, None))
    # TODO(levskaya): the in_axes spec below might be wrong, double-check.
    p_train_step = jax.pmap(p_train_step,
                            axis_name='batch',
                            in_axes=(None, 0, 0, 0),
                            donate_argnums=(0, ),
                            global_arg_shapes=(optimizer_shapes, None, None,
                                               None),
                            axis_size=topology.num_replicas,
                            devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

    # OR, we use an on-device loop that feeds the training step via infeed queue.
    def device_train_loop_cond(args):
        """Stopping criterion for on-device loop."""
        _, _, _, _, step, epoch = args
        return step // steps_per_epoch == epoch

    def device_train_loop_body(args):
        """On-device loop body."""
        optimizer, dropout_rngs, metrics, token, step, epoch = args
        # Ordering input data from infeed requires threading a symbolic token
        # through the computation.
        input_data, token = lax.infeed(token,
                                       shape=tuple([
                                           jax.ShapedArray(s, jnp.int32)
                                           for s in device_train_input_shape
                                       ]))
        # Rebuild input dict from infeed data tuple.
        batch = {k: v for k, v in zip(train_keys, input_data)}
        # Run the train_step function and return the loop state.
        optimizer, metrics, dropout_rngs = train_lib.train_step(
            optimizer,
            batch,
            metrics,
            dropout_rngs,
            train_config,
            learning_rate_fn,
            num_microbatches=CFG.microbatches,
            label_smoothing=CFG.label_smoothing,
            z_loss=CFG.z_loss)
        step += 1
        return optimizer, dropout_rngs, metrics, token, step, epoch

    def device_train_loop(optimizer, dropout_rngs, metrics, step, epoch):
        # Create symbolic token for threading infeed data.
        token = lax.create_token(step)
        # Run on-device loop.
        optimizer, dropout_rngs, metrics, _, step, _ = lax.while_loop(
            device_train_loop_cond, device_train_loop_body,
            (optimizer, dropout_rngs, metrics, token, step, epoch))
        return optimizer, dropout_rngs, metrics, step

    if num_partitions > 1:
        device_train_loop = sharded_jit(
            device_train_loop,
            in_parts=(optimizer_partitions, None, None, None, None),
            local_in_parts=(per_host_optimizer_partitions, None, None, None,
                            None),
            out_parts=(optimizer_partitions, None, None, None),
            local_out_parts=(per_host_optimizer_partitions, None, None, None))
    p_train_epoch = jax.pmap(device_train_loop,
                             axis_name='batch',
                             in_axes=(None, 0, 0, None, None),
                             donate_argnums=(0, ),
                             global_arg_shapes=(optimizer_shapes, None, None,
                                                None, None),
                             axis_size=topology.num_replicas,
                             devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

    # Reduction psum for metric data.

    def p_allreduce_metrics(x):
        return lax.psum(x, axis_name='batch')

    if num_partitions > 1:
        p_allreduce_metrics = sharded_jit(
            p_allreduce_metrics,
            in_parts=None,
            local_in_parts=None,
            out_parts=None,
            local_out_parts=None,
            num_partitions=num_partitions,
            local_num_partitions=topology.per_host_num_partitions)
    p_allreduce_metrics = jax.pmap(p_allreduce_metrics,
                                   axis_name='batch',
                                   global_arg_shapes=None,
                                   axis_size=topology.num_replicas,
                                   devices=topology.device_assignment)

    # Training evaluation computation.

    # eval_step(params, batch, config, label_smoothing=0.0) --> metrics
    def p_eval_step(params, batch):
        return train_lib.eval_step(params,
                                   batch,
                                   config=eval_config,
                                   label_smoothing=CFG.label_smoothing)

    if num_partitions > 1:
        p_eval_step = sharded_jit(
            p_eval_step,
            in_parts=(optimizer_partitions.target, None),
            local_in_parts=(per_host_optimizer_partitions.target, None),
            out_parts=None,
            local_out_parts=None)
    p_eval_step = jax.pmap(p_eval_step,
                           axis_name='batch',
                           in_axes=(None, 0),
                           global_arg_shapes=(optimizer_shapes.target, None),
                           axis_size=topology.num_replicas,
                           devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

    # Fast autoregressive decoding loop.
    # For inference and model evaluation.

    # predict_step(inputs, params,
    #              eos_id, max_decode_len, config, beam_size=4) --> beam_seqs
    def p_pred_step(inputs, params):
        return train_lib.predict_step(inputs, params, eos_id,
                                      CFG.max_eval_target_length,
                                      predict_config, CFG.beam_size)

    if num_partitions > 1:
        p_pred_step = sharded_jit(
            p_pred_step,
            in_parts=(None, optimizer_partitions.target),
            local_in_parts=(None, per_host_optimizer_partitions.target),
            out_parts=None,
            local_out_parts=None)
    p_pred_step = jax.pmap(p_pred_step,
                           axis_name='batch',
                           in_axes=(0, None),
                           global_arg_shapes=(None, optimizer_shapes.target),
                           axis_size=topology.num_replicas,
                           devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

    # ---------------------------------------------------------------------------
    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    # There should be a unique dropout key for each replica represented on this
    # host, but the key should be the same for the same replica on other hosts.
    # Again, this is what the replica set abstraction is for.
    dropout_rngs = random.split(random.fold_in(rng, topology.replica_set_id),
                                topology.per_replica_set_num_replicas)
    # restore step from last checkpoint
    host_step = int(optimizer.state.step)
    empty_metrics = broadcast({
        'loss': 0.0,
        'accuracy': 0.0,
        'learning_rate': 0.0,
        'denominator': 0.0
    })
    if CFG.infeed:
        # TODO(jekbradbury): support something like this for the Python-loop case
        logging.info(
            'Precompiling training loop and moving optimizer to device.')
        optimizer, _, metrics, _ = p_train_epoch(optimizer, dropout_rngs,
                                                 empty_metrics,
                                                 jnp.array(0,
                                                           dtype=jnp.int32), 1)
        optimizer = train_lib.unbroadcast(optimizer)
        metrics['loss'].block_until_ready()

    logging.info('Starting training loop.')

    local_devices = jax.local_devices()
    device_step = broadcast(host_step)
    first_epoch = host_step // steps_per_epoch

    # Main Loop over "epochs".
    train_iter = train_ds.as_numpy_iterator()
    for epoch in range(first_epoch, first_epoch + CFG.num_epochs):
        metrics = empty_metrics

        # NOTE: 'optimizer' is unbroadcast by construction at initialization or
        # when loading a checkpoint. It is maintained in 'unbroadcast' state to
        # enable the XLA cross-replica sharding optimization.  The broadcasting is
        # handled automatically by the pmap'd functions that use it.

        # Gather all task evaluation metrics.
        logging.info('Evaluating tasks.')
        if epoch == first_epoch + 1:
            train_lib.sync_devices()
        for task in eval_cache.tasks:
            logging.info('Evaluating task %s', task.name)
            all_predicted, all_bs = [], []
            for pred_batch in eval_cache.preprocessed_examples[task.name]:
                # Handle final odd-sized batch by padding instead of dropping it.
                input_batch, unpadded_batch_size = train_lib.pad_batch_to_size(
                    pred_batch['inputs'], per_replica_set_eval_batch_size)
                all_bs.append(unpadded_batch_size)
                # Split batch dimensions for pmap.
                input_batch = jax.tree_map(
                    lambda x: x.reshape((topology.per_replica_set_num_replicas,
                                         -1) + x.shape[1:]), input_batch)
                # Run fast inference on batch.
                all_predicted.append(p_pred_step(input_batch,
                                                 optimizer.target))

            # Pad out the number of batches so each host has the same number.
            max_host_batch_number = np.max(
                eval_cache.preprocessed_batch_sizes[task.name])
            batch_shortfall = max_host_batch_number - len(all_predicted)
            if batch_shortfall > 0:
                # TODO(levskaya): Fix for case of entirely empty all_predicted.
                # To make sure the cross-host barriers work, we run the program the same
                # number of times on all hosts. The result of this call is ignored, and
                # the predictions are populated with zeros instead.
                p_pred_step(input_batch, optimizer.target)  # Dummy call.
                all_predicted.extend([jnp.zeros_like(all_predicted[0])] *
                                     batch_shortfall)
                all_bs.extend([0] * batch_shortfall)
            all_predicted = jnp.concatenate(all_predicted)
            all_bs = jnp.array(all_bs)

            # Collect all batches from across hosts and reverse sharding.
            all_predicted = train_lib.host_allgather(
                all_predicted, topology.num_replica_sets,
                topology.replica_set_id, topology.per_replica_set_host_id == 0)
            seqlength = all_predicted.shape[-1]
            total_examples = np.sum(
                train_lib.host_allgather(
                    all_bs, topology.num_replica_sets, topology.replica_set_id,
                    topology.per_replica_set_host_id == 0))
            del all_bs
            assert total_examples == len(eval_cache.examples[task.name]), (
                'Total number of batches incorrect for task %s.' % task.name)
            # De-shard the collected predicted tokens and remove padding.
            all_predicted = np.transpose(all_predicted, (1, 2, 0, 3)).reshape(
                -1, seqlength)[:total_examples]

            # We now run the post-processing and metric-fns on a single host.
            if jax.host_id() == 0:
                assert eval_summary_writer
                raw_predictions = []
                for tokens in all_predicted:
                    raw_predictions.append(decode_tokens(tokens))

                # post-process predictions for metric fns
                predictions = [
                    task.postprocess_fn(p, example=ex) for p, ex in zip(
                        raw_predictions, eval_cache.examples[task.name])
                ]

                for metric_fn in task.metric_fns:
                    scores = metric_fn(eval_cache.targets[task.name],
                                       predictions)
                    for metric_name, metric_value in scores.items():
                        tag = f'eval/{task.name}/{metric_name}'
                        eval_summary_writer.scalar(tag, metric_value,
                                                   host_step)
                        logging.info('EVAL %s at step %d: %.3f', tag,
                                     host_step, metric_value)
                    eval_summary_writer.flush()

                # Save text samples for tensorboard.
                exemplars = ''
                for n in np.random.choice(np.arange(len(predictions)), 8):
                    tgt_txt = tf.compat.as_text(
                        eval_cache.examples[task.name][n]['targets_plaintext'])
                    pred_txt = raw_predictions[n]
                    exemplars += (f'{eval_cache.inputs[task.name][n]}\n\n'
                                  f'target: {tgt_txt}\n\n'
                                  f'prediction: {pred_txt}\n\n')
                eval_summary_writer.text(f'{task.name} samples', exemplars,
                                         host_step)
                eval_summary_writer.flush()

        # Take an Xprof trace after the first loop has compiled everything.
        if epoch == first_epoch + 1:
            train_lib.sync_devices()

        # For on-device loop, we launch the computation before feeding data.
        logging.info('BEGIN Train loop.')
        if CFG.infeed:
            optimizer, dropout_rngs, metrics, device_step = p_train_epoch(
                optimizer, dropout_rngs, metrics,
                train_lib.unbroadcast(device_step), epoch)
            optimizer = train_lib.unbroadcast(optimizer)

        # Epoch loop.
        while int(host_step // steps_per_epoch) == epoch:
            batch = next(train_iter)
            batch = jax.tree_map(
                lambda x: x.reshape(
                    (topology.per_replica_set_num_replicas, -1) + x.shape[1:]),
                batch)
            # Feed the on-device training loop.
            if CFG.infeed:
                for i, device in enumerate(local_devices):
                    # When using infeed to provide data to the computation, we're on our
                    # own for feeding the right values to the right devices. Each device
                    # should get the minibatch corresponding to its replica, a slice of
                    # the larger batch corresponding to the host's replica set.
                    if device.platform == 'tpu':
                        device_coords = (*device.coords, device.id % 2)
                    else:
                        device_coords = (device.host_id, i)
                    per_replica_set_device_coords = tuple(
                        dc % prsm for dc, prsm in zip(
                            device_coords, topology.per_replica_set_mesh))
                    per_replica_set_replica_coords = tuple(
                        prsdc // prm
                        for prsdc, prm in zip(per_replica_set_device_coords,
                                              topology.per_replica_mesh))
                    per_replica_set_replica_id = 0
                    for prsm, prm, prsrc in zip(
                            topology.per_replica_set_mesh,
                            topology.per_replica_mesh,
                            per_replica_set_replica_coords):
                        per_replica_set_replica_id = (
                            per_replica_set_replica_id * prsm // prm + prsrc)
                    input_tuple = tuple([
                        batch[k][per_replica_set_replica_id]
                        for k in train_keys
                    ])
                    # Safety check: infeed does not check shape or types but requires
                    # them to agree with on-device spec, otherwise the queue and program
                    # stalls.
                    tuple_shapes = jax.tree_map(jnp.shape, input_tuple)
                    tuple_dtypes = jax.tree_map(lambda x: x.dtype, input_tuple)
                    assert tuple_shapes == device_train_input_shape, (
                        'infeed shape error %s != %s' %
                        (tuple_shapes, device_train_input_shape))
                    assert tuple(set(tuple_dtypes)) == (jnp.int32,), \
                        ('infeed dtype error %s not all of type %s' % (
                            tuple_dtypes, jnp.int32))
                    infeed_pool.submit(
                        functools.partial(device.transfer_to_infeed,
                                          input_tuple))
            # Host training loop.
            else:
                optimizer, metrics, dropout_rngs = p_train_step(
                    optimizer, batch, metrics, dropout_rngs)
                optimizer = train_lib.unbroadcast(optimizer)
            host_step += 1
        logging.info('END Train loop.')

        # Maybe save a checkpoint on one host.
        if (CFG.save_checkpoints
                and epoch % CFG.checkpoint_freq == CFG.checkpoint_freq - 1
                and jax.host_id() == 0):
            checkpoints.save_checkpoint(FLAGS.model_dir, optimizer, host_step)

        # Gather training metrics.
        metrics = p_allreduce_metrics(metrics)
        metrics = jax.tree_map(lambda x: jax.device_get(x[0]), metrics)
        denominator = metrics.pop('denominator')
        summary = jax.tree_map(lambda x: x / denominator, metrics)  # pylint: disable=cell-var-from-loop
        logging.info('train in step: %s, %s', host_step, summary)
        if jax.host_id() == 0:
            assert train_summary_writer
            for key, val in summary.items():
                train_summary_writer.scalar(key, val, host_step)
            train_summary_writer.flush()

        # Gather training evaluation metrics.
        logging.info('Gathering training evaluation metrics.')
        eval_metrics = []
        eval_iter = eval_ds.as_numpy_iterator()
        for _, eval_batch in zip(range(CFG.num_eval_steps), eval_iter):
            eval_batch = jax.tree_map(
                lambda x: x.reshape(
                    (topology.per_replica_set_num_replicas, -1) + x.shape[1:]),
                eval_batch)
            metrics = p_eval_step(optimizer.target, eval_batch)
            eval_metrics.append(metrics)
        # average metrics across devices
        eval_metrics = p_allreduce_metrics(eval_metrics)
        eval_metrics = common_utils.get_metrics(eval_metrics)
        # average metrics across steps
        eval_metrics = jax.tree_map(np.sum, eval_metrics)
        eval_denominator = eval_metrics.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics)
        logging.info('eval in step: %s, %s', host_step, eval_summary)
        if jax.host_id() == 0:
            assert eval_summary_writer
            for key, val in eval_summary.items():
                eval_summary_writer.scalar(key, val, host_step)
            eval_summary_writer.flush()

    # Wait until computations are done before exiting
    logging.info('Finished.')
    train_lib.sync_devices()
    # Shut down the infeed threadpool.
    if CFG.infeed:
        infeed_pool.shutdown()