def testPyTreeArgs(self):
  """sharded_jit accepts pytree (tuple/list) arguments and partition specs.

  Runs once with a per-argument `in_parts` pytree and once with
  `in_parts=None`. In both cases the output is split into two device buffers
  according to `out_parts`.
  """
  # sharded_jit places partitions on local devices, so check the local count
  # (like the other sharded_jit tests) and give the skip a reason.
  if jax.local_device_count() < 2:
    raise SkipTest("requires 2 devices")

  def f(a, b, c):
    a1, a2 = a
    c1, (c2, c3) = c
    return a1 + a2 + b + c1 + c2 + c3

  def _make_arg(*shape):
    return np.arange(prod(shape)).reshape(shape)

  a = (_make_arg(4, 4), 1)
  b = _make_arg(4, 4)
  c = [2, (_make_arg(4, 4), _make_arg(4, 4))]

  # in_parts mirrors the argument pytree; a None entry leaves the
  # corresponding argument subtree unpartitioned.
  in_parts = (None, P(2, 1), [None, P(2, 1)])
  out_parts = P(2, 1)

  result = sharded_jit(f, in_parts, out_parts)(a, b, c)
  expected = f(a, b, c)
  self.assertAllClose(result, expected, check_dtypes=False)
  self.assertIsInstance(result, pxla.ShardedDeviceArray)
  self.assertLen(result.device_buffers, 2)

  # Unpartitioned inputs, partitioned output.
  in_parts = None
  result = sharded_jit(f, in_parts, out_parts)(a, b, c)
  self.assertAllClose(result, expected, check_dtypes=False)
  self.assertIsInstance(result, pxla.ShardedDeviceArray)
  self.assertLen(result.device_buffers, 2)
def testShardingConstraint(self):
  """with_sharding_constraint inside sharded_jit.

  Covers three cases:
  - a constraint whose partition count matches the enclosing sharded_jit;
  - a constraint whose count does not match, which must raise ValueError;
  - a replicated sharded_jit (in_parts/out_parts None).
  """
  # The assertions below expect two device buffers, so skip rather than fail
  # when fewer local devices are available.
  if jax.local_device_count() < 2:
    raise SkipTest("requires 2 devices")

  def f(x):
    y = x + 1
    y = with_sharding_constraint(y, P(1, 2))
    return y * 2

  shape = (8, 8)
  x = np.arange(np.prod(shape)).reshape(shape)
  expected = (x + 1) * 2

  # Matching sharded_jit partitions
  actual = sharded_jit(f, in_parts=P(2, 1), out_parts=P(2, 1))(x)
  self.assertAllClose(actual, expected, check_dtypes=False)
  self.assertLen(actual.device_buffers, 2)
  # Newer jaxlib exposes the XLA shape as `xla_shape()`, while older versions
  # provide a `shape()` method. Prefer the former when available.
  for buf in actual.device_buffers:
    shape_fn = getattr(buf, "xla_shape", None) or buf.shape
    self.assertEqual(shape_fn().dimensions(), (4, 8))

  # Mismatched sharded_jit partitions
  with self.assertRaisesRegex(
      ValueError,
      r"with_sharding_constraint with partitions=PartitionSpec\(1, 2\) "
      r"\(total partitions: 2\) doesn't match expected number of partitions: "
      r"4. If these partitions look right, check outer sharded_jit and/or "
      r"other with_sharding_constraint calls."):
    sharded_jit(f, in_parts=P(2, 2), out_parts=P(2, 2))(x)

  # Replicated sharded_jit
  actual = sharded_jit(f, in_parts=None, out_parts=None)(x)
  self.assertAllClose(actual, expected, check_dtypes=False)
  self.assertLen(actual.device_buffers, 2)
  self.assertAllClose(actual.device_buffers[0].to_py(),
                      actual.device_buffers[1].to_py(),
                      check_dtypes=False)
def test_sharded_jit_with_sharding_constraint(self):
  """A sharding constraint in the middle.

  The dot product of a column-sharded `x` and a row-sharded `y` is
  constrained to row sharding before `sin`. The annotations that
  `_check_sharding_annotations` finds are then matched against the
  `expected` / `expected_opt` regex lists.
  """
  def _iota(*dims):
    return np.arange(np.prod(dims), dtype=np.float32).reshape(dims)

  def jax_func(x, y):
    product = jnp.dot(x, y)
    constrained = sharded_jit.with_sharding_constraint(product, P(2, 1))
    return jnp.sin(constrained)

  partitioned_func = sharded_jit.sharded_jit(
      jax_func, in_parts=(P(1, 2), P(2, 1)), out_parts=P(1, 2))

  lhs = _iota(6, 8)
  rhs = _iota(8, 10)

  annotation_patterns = [
      r"f32\[6,8\].*sharding={devices=\[1,2\]",
      r"f32\[8,10\].*sharding={devices=\[2,1\]",
      r"f32\[6,10\].*sharding={devices=\[2,1\]",
      r"f32\[6,10\].*sine.*sharding={devices=\[1,2\]",
  ]
  # TODO(necula): relax ordering
  optimized_patterns = [
      r"f32\[4,10\].*sharding={devices=\[2,1\]",
      r"f32\[6,4\].*sharding={devices=\[1,2\]",
  ]
  self._check_sharding_annotations(
      partitioned_func,
      [lhs, rhs],
      expected=annotation_patterns,
      expected_opt=optimized_patterns,
      num_partitions=2)
def testPyTreeArgs(self):
  """pmap(sharded_jit(f)) with pytree arguments matches pmap(f)."""
  if jax.local_device_count() < 4:
    raise SkipTest("requires 4 devices")

  def f(a, b, c):
    a1, a2 = a
    c1, (c2, c3) = c
    return a1 + a2 + b + c1 + c2 + c3

  def _make_arg(*shape):
    return np.arange(prod(shape)).reshape(shape)

  # The leading size-2 axis of every argument is the axis pmap maps over.
  a = (_make_arg(2, 4, 4), _make_arg(2))
  b = _make_arg(2, 4, 4)
  c = (_make_arg(2), (_make_arg(2, 4, 4), _make_arg(2, 4, 4)))

  in_parts = (None, P(2, 1), (None, P(2, 1)))
  out_parts = P(2, 1)

  result = pmap(sharded_jit(f, in_parts=in_parts, out_parts=out_parts))(a, b, c)
  expected = pmap(f)(a, b, c)

  self.assertAllClose(result, expected, check_dtypes=False)
  # assertIsInstance/assertLen report the offending type/length on failure.
  self.assertIsInstance(result, pxla.ShardedDeviceArray)
  # 2 mapped replicas x 2 partitions each.
  self.assertLen(result.device_buffers, 4)
def test_sharded_jit_in_out(self):
  """Test input and output sharding annotations.

  `jnp.dot` is wrapped with a column-sharded lhs, a row-sharded rhs and a
  column-sharded result. The annotations are then matched against the
  `expected` / `expected_opt` regex lists on 2 partitions.
  """
  def _iota(*dims):
    return np.arange(np.prod(dims), dtype=np.float32).reshape(dims)

  partitioned_dot = sharded_jit.sharded_jit(
      jnp.dot, in_parts=(P(1, 2), P(2, 1)), out_parts=P(1, 2))

  lhs = _iota(3, 8)
  rhs = _iota(8, 5)

  annotation_patterns = [
      r"f32\[3,8\].*sharding={devices=\[1,2\]",
      r"f32\[8,5\].*sharding={devices=\[2,1\]",
      r"f32\[3,5\].*sharding={devices=\[1,2\]",
  ]
  # TODO(necula): relax ordering
  optimized_patterns = [
      r"f32\[4,5\].*sharding={devices=\[2,1\]",
      r"f32\[3,4\].*sharding={devices=\[1,2\]",
      r"f32\[3,5\].*fusion",
      r"f32\[3,5\].*all-reduce",
  ]
  self._check_sharding_annotations(
      partitioned_dot,
      [lhs, rhs],
      expected=annotation_patterns,
      expected_opt=optimized_patterns,
      num_partitions=2)
def testCompilationCache(self):
  """Two calls with the same input trigger exactly one compilation."""
  sharded_f = sharded_jit(lambda v: v + 1, in_parts=P(2), out_parts=P(2))
  shape = (2,)
  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)

  with jtu.assert_num_jit_and_pmap_compilations(1):
    for _ in range(2):
      sharded_f(x)
def testCompilationCache(self):
  """Two calls with the same input trigger exactly one compilation."""
  sharded_f = sharded_jit(lambda v: v + 1, in_parts=P(2), out_parts=P(2))
  shape = (2,)
  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)

  with jtu.count_jit_and_pmap_compiles() as count:
    for _ in range(2):
      sharded_f(x)
  num_compiles = count[0]
  self.assertEqual(num_compiles, 1)
def testPyTreeOutputs(self):
  """sharded_jit handles pytree outputs with per-leaf `out_parts`.

  Runs once with a nested `out_parts` that mixes partitioned and unpartitioned
  leaves, and once with `out_parts=None`.
  """
  # sharded_jit places partitions on local devices, so check the local count
  # (like the other sharded_jit tests) and give the skip a reason.
  if jax.local_device_count() < 2:
    raise SkipTest("requires 2 devices")

  def f(x):
    return x + 1, ((x + 2, x + 3), x + 4)

  shape = (4, 4)
  x = np.arange(prod(shape)).reshape(shape)
  in_parts = (P(2, 1),)
  # Mirrors the structure of f's output; a None leaf is not partitioned.
  out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))

  result = sharded_jit(f, in_parts, out_parts)(x)
  expected = f(x)
  self.assertAllClose(result, expected, check_dtypes=False)

  out_parts = None
  result = sharded_jit(f, in_parts, out_parts)(x)
  self.assertAllClose(result, expected, check_dtypes=False)
def testShardingConstraint(self):
  """with_sharding_constraint inside sharded_jit.

  Checks a constraint whose partition count agrees with sharded_jit, one that
  does not (must raise ValueError), and a replicated sharded_jit.
  """
  if jax.local_device_count() < 2:
    raise SkipTest("requires 2 devices")

  def f(x):
    return with_sharding_constraint(x + 1, P(1, 2)) * 2

  shape = (8, 8)
  x = np.arange(prod(shape)).reshape(shape)
  expected = (x + 1) * 2

  # Matching sharded_jit partitions
  actual = sharded_jit(f, in_parts=P(2, 1), out_parts=P(2, 1))(x)
  self.assertAllClose(actual, expected, check_dtypes=False)
  self.assertLen(actual.device_buffers, 2)
  # TODO(jblespiau): We can simply use buf.xla_shape() when version 0.1.58 is
  # the default.
  for buf in actual.device_buffers:
    shape_fn = getattr(buf, "xla_shape", buf.shape)
    self.assertEqual(shape_fn().dimensions(), (4, 8))

  # Mismatched sharded_jit partitions
  mismatch_pattern = (
      r"with_sharding_constraint with partitions=PartitionSpec\(1, 2\) "
      r"\(total partitions: 2\) doesn't match expected number of partitions: "
      r"4. If these partitions look right, check outer sharded_jit and/or "
      r"other with_sharding_constraint calls.")
  with self.assertRaisesRegex(ValueError, mismatch_pattern):
    sharded_jit(f, in_parts=P(2, 2), out_parts=P(2, 2))(x)

  # Replicated sharded_jit
  actual = sharded_jit(f, in_parts=None, out_parts=None)(x)
  self.assertAllClose(actual, expected, check_dtypes=False)
  self.assertLen(actual.device_buffers, 2)
  first_buf, second_buf = actual.device_buffers
  self.assertAllClose(first_buf.to_py(), second_buf.to_py(),
                      check_dtypes=False)
def testPyTreeOutputs(self):
  """pmap(sharded_jit(f)) with a pytree of outputs matches pmap(f)."""
  if jax.local_device_count() < 4:
    raise SkipTest("requires 4 devices")

  def f(x):
    return x + 1, ((x + 2, x + 3), x + 4)

  shape = (2, 4, 4)
  x = np.arange(np.prod(shape)).reshape(shape)

  partitioned_f = sharded_jit(
      f,
      in_parts=(P(2, 1),),
      out_parts=(P(2, 1), ((P(1, 2), None), P(2, 1))))
  self.assertAllClose(pmap(partitioned_f)(x), pmap(f)(x), check_dtypes=False)
def testNestedShardingConstraint(self):
  """A sharding constraint inside a jitted while_loop under sharded_jit."""
  if jax.local_device_count() < 2:
    raise SkipTest("requires 2 devices")

  shape = (8, 8)

  def keep_going(i):
    return i[0, 0] < 10.

  def step(i):
    return with_sharding_constraint(i + 1., P(2, 1))

  @jit
  def f(x):
    return lax.while_loop(keep_going, step, x)

  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  # The loop adds 1 until element [0, 0] (initially 0) reaches 10.
  expected = x + 10.

  actual = sharded_jit(f, in_parts=None, out_parts=None)(x)
  self.assertAllClose(actual, expected, check_dtypes=False)
  self.assertLen(actual.device_buffers, 2)
def _runTest(self, f, in_partitions, out_partitions, dtype=np.float32):
  """Compares pmap(sharded_jit(f, ...)) to pmap(f).

  Both are applied to `x` (an arange of shape (2, 4, 4)) and `y = x + 1`.
  Every output leaf must be a ShardedDeviceArray with one buffer per shard.

  Args:
    f: two-argument function under test.
    in_partitions: sharded_jit `in_parts`. The element count of its first
      entry is used as the number of partitions per replica when sizing the
      device requirement.
    out_partitions: sharded_jit `out_parts`.
    dtype: dtype of the generated inputs.
  """
  shape = (2, 4, 4)
  # pmap maps the leading axis; each replica is further split into
  # prod(in_partitions[0]) partitions.
  num_shards = shape[0] * np.prod(in_partitions[0])
  if num_shards > jax.local_device_count():
    raise SkipTest("requires %d devices" % num_shards)

  # Apply `dtype` to the array itself. np.arange would otherwise choose its
  # default dtype (float64/int64) regardless of the requested one.
  x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
  y = x + 1
  result = pmap(
      sharded_jit(f, in_parts=in_partitions, out_parts=out_partitions))(x, y)
  expected = pmap(f)(x, y)
  self.assertAllClose(result, expected, check_dtypes=False)

  flat_result = tree_util.tree_flatten(result)[0]
  for r in flat_result:
    self.assertIsInstance(r, pxla.ShardedDeviceArray)
    self.assertLen(r.device_buffers, num_shards)
def testManyArgs(self):
  """pmap(sharded_jit(f)) with 200 partitioned arguments matches pmap(f)."""
  if jax.local_device_count() < 4:
    raise SkipTest("requires 4 devices")

  num_args = 200

  def f(*args):
    return jnp.sum(args)

  shape = (2, 4, 4)
  args = [np.arange(np.prod(shape)).reshape(shape)] * num_args
  in_partitions = (P(2, 1),) * num_args
  out_partitions = None
  result = pmap(sharded_jit(
      f, in_parts=in_partitions, out_parts=out_partitions))(*args)
  expected = pmap(f)(*args)

  self.assertAllClose(result, expected, check_dtypes=False)
  # assertIsInstance/assertLen report the offending type/length on failure.
  self.assertIsInstance(result, pxla.ShardedDeviceArray)
  self.assertLen(result.device_buffers, 4)
def testInAxesNone(self):
  """pmap(sharded_jit) with unmapped (in_axes=None) partitioned arguments."""
  replicas = 2
  in_partitions = (P(2, 1), None, None)
  num_shards = replicas * np.prod(in_partitions[0])
  if num_shards > jax.local_device_count():
    raise SkipTest("requires %d devices" % num_shards)

  def f(x, y, _):
    return x @ y

  shape = (4, 4)
  # Only the third (dummy) argument is mapped; its length sets the number of
  # replicas.
  in_axes = (None, None, 0)
  mat = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
  dummy = np.arange(replicas, dtype=np.float32) + 1

  partitioned_f = sharded_jit(f, in_parts=in_partitions, out_parts=P(2, 1))
  result = pmap(partitioned_f, in_axes=in_axes)(mat, mat, dummy)
  expected = pmap(f, in_axes=in_axes)(mat, mat, dummy)
  self.assertAllClose(result, expected, check_dtypes=True)
def test_sharded_jit_replicated(self):
  """A replicated input and output.

  `jnp.dot` gets a column-sharded lhs and a replicated (None) rhs, and
  produces a replicated result. The annotations are matched against the
  `expected` / `expected_opt` regex lists on 2 partitions.
  """
  def _iota(*dims):
    return np.arange(np.prod(dims), dtype=np.float32).reshape(dims)

  partitioned_dot = sharded_jit.sharded_jit(
      jnp.dot, in_parts=(P(1, 2), None), out_parts=None)

  lhs = _iota(3, 8)
  rhs = _iota(8, 5)

  annotation_patterns = [
      r"f32\[3,8\].*sharding={devices=\[1,2\]",
      r"f32\[8,5\].*sharding={replicated}",
      r"f32\[3,5\].*sharding={replicated}",
  ]
  # TODO(necula): relax ordering
  optimized_patterns = [
      r"f32\[8,5\].*sharding={replicated}",
      r"f32\[3,4\].*sharding={devices=\[1,2\]",
  ]
  self._check_sharding_annotations(
      partitioned_dot,
      [lhs, rhs],
      expected=annotation_patterns,
      expected_opt=optimized_patterns,
      num_partitions=2)
def main(argv):
  """Runs MLPerf WMT en-de Transformer training with per-epoch BLEU eval.

  Steps:
  - Emit MLPerf logging events.
  - Build the model and Adam optimizer.
  - Build the pmapped train/eval/predict functions, wrapped in sharded_jit
    when `FLAGS.num_partitions > 1`.
  - Optionally precompile them with fake data.
  - Run `FLAGS.num_epochs` epochs of training, each followed by translation of
    the eval set and a call to `write_predict_summary`.

  Per-epoch MLPerf `block_start` events are emitted only while the
  module-level `BLEU_THRESHOLD_REACHED` flag is false. That flag is only read
  here; presumably it is set elsewhere (e.g. by the BLEU summary code) —
  TODO confirm.

  Args:
    argv: positional command-line arguments left after flag parsing; only the
      program name is allowed.

  Raises:
    app.UsageError: if extra positional arguments are given.
  """
  global BLEU_THRESHOLD_REACHED
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  init_mllogger()
  mllogger.event('cache_clear')
  mllogger.start('init_start')
  mllogger.event('submission_org', 'Google')
  mllogger.event('submission_platform', 'TPUv3-{}'.format(jax.device_count()))
  mllogger.event('submission_division', 'closed')
  mllogger.event('submission_status', 'research')
  mllogger.event('submission_benchmark', 'transformer')
  mllogger.event('train_samples', input_pipeline.N_TRAIN)
  mllogger.event('eval_samples', input_pipeline.N_EVAL)

  tf.enable_v2_behavior()

  # Use hardware RNG for bernoulli randoms in dropout mask creation.
  if FLAGS.hardware_rng:
    models.set_hardware_bernoulli()

  num_partitions = FLAGS.num_partitions
  batch_size = FLAGS.batch_size
  if batch_size is None:
    batch_size = min(16 * jax.device_count() // num_partitions, 2048)
  mllogger.event('global_batch_size', batch_size)

  num_eval_steps = FLAGS.num_eval_steps
  max_target_length = FLAGS.max_target_length
  max_eval_target_length = FLAGS.max_eval_target_length
  max_length = max(max_target_length, max_eval_target_length)
  mllogger.event('max_sequence_length',
                 max_length,
                 metadata={'method': 'discard'})
  if FLAGS.random_seed is not None:
    seed = FLAGS.random_seed
  else:
    # Only host 0 contributes a time-based seed. Summing across hosts gives
    # every host the same value.
    seed = np.uint32(time.time() if jax.host_id() == 0 else 0)
    seed = per_host_sum_pmap(seed)
  mllogger.event('seed', int(seed))

  steps_per_epoch = int(math.ceil(input_pipeline.N_TRAIN / batch_size))
  logging.info('steps per epoch: %d', steps_per_epoch)
  num_replicas = jax.local_device_count() // num_partitions
  device_train_input_shape = (batch_size // (num_replicas * jax.host_count()),
                              max_target_length)
  # This is per-host; in principle 64/replica or more should fit
  eval_batch_size = min(
      32 * num_replicas,
      int(math.ceil(input_pipeline.N_EVAL /
                    (num_replicas * jax.host_count()))) * num_replicas)
  logging.info('eval batch size: %d', eval_batch_size)
  pred_batches = int(
      math.ceil(input_pipeline.N_EVAL / (jax.host_count() * eval_batch_size)))
  logging.info('pred batches: %d', pred_batches)
  broadcast = functools.partial(_broadcast,
                                num_replicas=num_replicas,
                                num_partitions=num_partitions)

  if jax.host_id() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'eval'))
  else:
    train_summary_writer = None
    eval_summary_writer = None
  # Write summaries in background thread to avoid blocking on device sync
  summary_thread = thread.ThreadPoolExecutor(1, 'summary')
  if FLAGS.infeed:
    # Infeed is currently synchronous, so do it in a background thread too
    infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(), 'infeed')

  # MLPerf 2020 WMT en-de dataset uses a custom T2T dataset:
  #   Shared 32K subword tokenization
  #   256-length packed training examples from WMT17
  #   97-length unpacked evaluation examples from WMT14
  train_keys = [
      'inputs', 'targets', 'inputs_position', 'targets_position',
      'inputs_segmentation', 'targets_segmentation'
  ]
  encoder = mlperf_encoder.SubwordTextEncoder(filename=FLAGS.vocab_path)
  input_encoder = encoder
  target_encoder = encoder
  vocab_size = input_encoder.vocab_size
  output_vocab_size = target_encoder.vocab_size

  input_shape = (batch_size, max_target_length)
  target_shape = (batch_size, max_target_length)

  transformer_kwargs = flax.core.FrozenDict({
      'vocab_size': vocab_size,
      'output_vocab_size': output_vocab_size,
      'emb_dim': 1024,
      'num_heads': 16,
      'num_layers': 6,
      'qkv_dim': 1024,
      'mlp_dim': 4096,
      'max_len': max_length,
      'share_embeddings': FLAGS.share_embeddings,
      'logits_via_embedding': FLAGS.logits_via_embedding,
      'num_partitions': num_partitions,
  })

  rng = random.PRNGKey(seed)
  rng, init_rng = random.split(rng)
  model, cache_def = create_model(init_rng, tuple(input_shape),
                                  tuple(target_shape), transformer_kwargs)

  mllogger.event('opt_name', 'adam')
  # Three batch-size tiers. `elif` is required: with a plain `if`, every
  # batch size below 1024 also satisfied `< 2048`, so the first tier's
  # settings were always overwritten by the second tier's.
  if batch_size < 1024:
    learning_rate = 4.0  # 0.0625
    warmup_steps = 1000
    beta1 = 0.9
    beta2 = 0.98
  elif batch_size < 2048:
    learning_rate = 2.0
    warmup_steps = 500  # ??
    beta1 = 0.9  # ??
    beta2 = 0.98  # ??
  else:
    learning_rate = 3.3092157691415953
    warmup_steps = 664
    beta1 = 0.9086575725261137
    beta2 = 0.9198719118104947
  epsilon = 1e-9
  if FLAGS.learning_rate is not None:
    learning_rate = FLAGS.learning_rate
  mllogger.event('opt_adam_beta_1', beta1)
  mllogger.event('opt_adam_beta_2', beta2)
  mllogger.event('opt_adam_epsilon', epsilon)
  optimizer_def = optim.Adam(learning_rate,
                             beta1=beta1,
                             beta2=beta2,
                             eps=epsilon,
                             weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(model)
  del model  # don't keep a copy of the initial model

  # Build parameter partition annotations for preserving partitions from train
  # to eval.
  partition_rules = [
      (('encoder', 'posembed_input'), partitions.empty_dict),
      (('decoder', 'posembed_targets'), partitions.empty_dict),
      (('embedding',), partitions.spec(num_partitions, 1)),
      ((r'LayerNorm_\d+', '(bias|scale)'), None),
      ((r'encoder(decoder)?_norm', '(bias|scale)'), None),
      ((r'MultiHeadDotProductAttention_\d+', '(query|key|value)', 'kernel'),
       partitions.spec(1, num_partitions, 1)),
      ((r'MultiHeadDotProductAttention_\d+', 'out', 'kernel'),
       partitions.spec(num_partitions, 1, 1)),
      ((r'MlpBlock_\d+', r'Dense_\d+', 'bias'), None),
      ((r'MlpBlock_\d+', 'Dense_0', 'kernel'),
       partitions.spec(1, num_partitions)),
      ((r'MlpBlock_\d+', 'Dense_1', 'kernel'),
       partitions.spec(num_partitions, 1)),
      (('state', 'step'), None),
  ]
  optimizer_partitions = optimizer.restore_state(
      partitions.set_partitions(partition_rules, optimizer.state_dict()))

  optimizer = broadcast(optimizer)
  empty_metrics = broadcast({'loss': 0.0, 'accuracy': 0, 'denominator': 0})

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=learning_rate,
      warmup_steps=warmup_steps,
      hidden_size=transformer_kwargs['qkv_dim'])

  p_train_step = jax.pmap(functools.partial(train_step,
                                            learning_rate_fn=learning_rate_fn),
                          axis_name='batch',
                          in_axes=(None, 0, 0, 0))
  if num_partitions > 1:
    sharded_predict_step = sharded_jit(
        predict_step,
        in_parts=(None, optimizer_partitions.target, None),
        out_parts=None)
  else:
    sharded_predict_step = predict_step
  if FLAGS.extra_eval_metrics:
    p_eval_step = jax.pmap(eval_step, axis_name='batch', in_axes=(None, 0))
  p_pred_step = jax.pmap(sharded_predict_step,
                         axis_name='batch',
                         in_axes=(0, None, None))
  p_allreduce_metrics = jax.pmap(functools.partial(lax.psum,
                                                   axis_name='batch'),
                                 axis_name='batch')

  def device_train_loop_cond(args):
    # Keep looping while the step counter is still inside `epoch`.
    _, _, _, _, step, epoch = args
    return step // steps_per_epoch == epoch

  def device_train_loop_body(args):
    # One training step on a batch pulled from the infeed queue.
    optimizer, dropout_rngs, metrics, token, step, epoch = args
    input_data, token = lax.infeed(token,
                                   shape=tuple([
                                       jax.ShapedArray(
                                           device_train_input_shape, jnp.int32)
                                       for _ in train_keys
                                   ]))
    batch = dict(zip(train_keys, input_data))
    optimizer, metrics, dropout_rngs = train_step(optimizer,
                                                  batch,
                                                  metrics,
                                                  learning_rate_fn,
                                                  dropout_rng=dropout_rngs)
    step += 1
    return optimizer, dropout_rngs, metrics, token, step, epoch

  def device_train_loop(optimizer, dropout_rngs, metrics, step, epoch):
    # Runs the on-device training loop for one epoch.
    token = lax.create_token(step)
    optimizer, dropout_rngs, metrics, _, step, _ = lax.while_loop(
        device_train_loop_cond, device_train_loop_body,
        (optimizer, dropout_rngs, metrics, token, step, epoch))
    return optimizer, dropout_rngs, metrics, step

  if num_partitions > 1:
    device_train_loop = sharded_jit(device_train_loop,
                                    in_parts=(optimizer_partitions, None, None,
                                              None, None),
                                    out_parts=(optimizer_partitions, None,
                                               None, None))
  p_train_epoch = jax.pmap(device_train_loop,
                           axis_name='batch',
                           in_axes=(None, 0, 0, None, None))

  p_allreduce_metrics_train = functools.partial(lax.psum, axis_name='batch')
  if num_partitions > 1:
    p_allreduce_metrics_train = sharded_jit(p_allreduce_metrics_train,
                                            in_parts=None,
                                            out_parts=None,
                                            num_partitions=num_partitions)
  p_allreduce_metrics_train = jax.pmap(p_allreduce_metrics_train,
                                       axis_name='batch')

  # Precompile all needed computations with fake data so as not to include
  # compilation time in MLPerf metrics.
  if FLAGS.precompile:
    logging.info('precompiling step/epoch functions')
    if FLAGS.infeed:
      # the device training loop condition will immediately be false, but
      # the optimizer tree will be resharded here
      optimizer, *_ = p_train_epoch(unbroadcast(optimizer),
                                    random.split(rng, num_replicas),
                                    empty_metrics,
                                    jnp.array(0, dtype=jnp.int32), 1)
    else:
      metrics = empty_metrics
      train_input_shape = (num_replicas, batch_size // num_replicas,
                           input_pipeline.MAX_TRAIN_LEN)
      fake_batch = {
          k: jnp.ones(train_input_shape, jnp.int32) for k in train_keys
      }
      p_train_step(unbroadcast(optimizer),
                   fake_batch,
                   metrics,
                   dropout_rng=random.split(rng, num_replicas))
    eval_input_shape = (num_replicas, eval_batch_size // num_replicas,
                        input_pipeline.MAX_EVAL_LEN)
    fake_eval_batch = {
        'inputs': jnp.ones(eval_input_shape, jnp.int32),
        'targets': jnp.ones(eval_input_shape, jnp.int32),
    }
    if FLAGS.extra_eval_metrics:
      p_eval_step(unbroadcast(optimizer.target), fake_eval_batch)
    fake_cache = cache_def.initialize_cache(
        (eval_input_shape[1], FLAGS.max_predict_length))
    p_pred_step(fake_eval_batch['inputs'], unbroadcast(optimizer.target),
                fake_cache)
    time.sleep(20)
    sync_devices()
    fake_bleu_1 = np.zeros((4,), dtype=np.int32)
    fake_bleu_2 = np.zeros((), dtype=np.int32)
    per_host_sum_pmap((fake_bleu_1, fake_bleu_1, fake_bleu_2, fake_bleu_2))
    sync_devices()
    p_allreduce_metrics_train(empty_metrics)
    sync_devices()
    logging.info('finished precompiling step/epoch functions')

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, num_replicas)

  # Record time-0 metrics for proper tensorboard plot x-axis scaling.
  if jax.host_id() == 0:
    if FLAGS.compute_train_metrics:
      train_summary_writer.scalar('loss', 9.999, 0)
      train_summary_writer.scalar('accuracy', 0.0, 0)
      train_summary_writer.flush()
    eval_summary_writer.scalar('bleu', 0.0, 0)
    eval_summary_writer.flush()

  train_ds = input_pipeline.get_wmt_dataset(batch_size=batch_size //
                                            jax.host_count(),
                                            train=True)
  eval_ds = input_pipeline.get_wmt_dataset(batch_size=eval_batch_size,
                                           train=False)
  train_iter = iter(train_ds)
  eval_iter = iter(eval_ds)
  local_devices = jax.local_devices()
  host_step, device_step = 0, broadcast(0)
  # NOTE(review): the GC is disabled for the rest of the run and never
  # re-enabled here — presumably to avoid collection pauses during timing.
  gc.disable()
  mllogger.end('init_stop')
  if jax.host_id() == 0:
    mllogger.start('run_start')
  for epoch in range(FLAGS.num_epochs):
    if jax.host_id() == 0 and not BLEU_THRESHOLD_REACHED:
      mllogger.start('block_start',
                     metadata={
                         'first_epoch_num': epoch + 1,
                         'epoch_count': 1
                     })

    metrics = empty_metrics
    if FLAGS.infeed:
      # Dispatch the whole epoch on device; the host loop below only feeds it.
      optimizer, dropout_rngs, metrics, device_step = p_train_epoch(
          unbroadcast(optimizer), dropout_rngs, metrics,
          unbroadcast(device_step), epoch)

    while int(host_step // steps_per_epoch) == epoch:
      # pylint: disable=protected-access
      batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))

      # Shard data to devices and do a training step.
      batch = jax.tree_map(
          lambda x: x.reshape((num_replicas, -1) + x.shape[1:]), batch)
      if FLAGS.infeed:
        for i, device in enumerate(local_devices):
          # Consecutive groups of num_partitions local devices share one
          # replica's slice of the batch.
          replica_id = i // num_partitions
          input_tuple = tuple([batch[k][replica_id] for k in train_keys])
          assert input_tuple[0].shape == device_train_input_shape, (
              'infeed shape error %s != %s' %
              (input_tuple[0].shape, device_train_input_shape))
          assert input_tuple[0].dtype == jnp.int32, (
              'infeed dtype error %s != %s' %
              (input_tuple[0].dtype, jnp.int32))
          infeed_pool.submit(
              functools.partial(device.transfer_to_infeed, input_tuple))
      else:
        optimizer, metrics, dropout_rngs = p_train_step(
            unbroadcast(optimizer), batch, metrics, dropout_rng=dropout_rngs)
      host_step += 1

    if FLAGS.compute_train_metrics:
      metrics = p_allreduce_metrics_train(metrics)
      # Schedule training metric handling.
      summary_thread.submit(
          functools.partial(write_train_summary, metrics,
                            train_summary_writer, host_step))

    # Optional, extra evaluation metrics.
    if FLAGS.extra_eval_metrics:
      eval_metrics = []
      eval_iter = iter(eval_ds)
      for _, eval_batch in zip(range(num_eval_steps), eval_iter):
        eval_batch = common_utils.shard(eval_batch)
        # Separate name so the per-batch eval result does not overwrite the
        # training `metrics` already handed to the summary thread.
        step_metrics = p_eval_step(unbroadcast(optimizer.target), eval_batch)
        eval_metrics.append(step_metrics)
      eval_metrics = p_allreduce_metrics(eval_metrics)
      # Schedule metric summarization/logging.
      summary_thread.submit(
          functools.partial(write_eval_summary, eval_metrics,
                            eval_summary_writer, host_step))

    # Translation and BLEU Score.
    all_predicted, all_targets, all_bs = [], [], []
    for _ in range(pred_batches):
      # pylint: disable=protected-access
      pred_batch = jax.tree_map(lambda x: x._numpy(), next(eval_iter))
      # Handle final odd-sized batch by padding instead of dropping it.
      cur_pred_batch_size = pred_batch['inputs'].shape[0]
      if cur_pred_batch_size != eval_batch_size:
        pred_batch = jax.tree_map(
            lambda x: pad_examples(x, eval_batch_size), pred_batch)
      pred_batch = jax.tree_map(
          lambda x: x.reshape((num_replicas, -1) + x.shape[1:]), pred_batch)
      per_device_batchsize = pred_batch['inputs'].shape[1]
      cache = cache_def.initialize_cache(
          (per_device_batchsize, FLAGS.max_predict_length))
      all_predicted.append(
          p_pred_step(pred_batch['inputs'], unbroadcast(optimizer.target),
                      cache))
      all_targets.append(pred_batch['targets'])
      all_bs.append(cur_pred_batch_size)
    # Schedule BLEU calculation and summarization/logging.
    # We use the ICI as part of BLEU score computation, so we call this from the
    # main thread so the BLEU pmap runs before the next train epoch pmap
    write_predict_summary(all_predicted, all_targets, all_bs, target_encoder,
                          eval_summary_writer, epoch, host_step,
                          summary_thread)

  # Wait until computations are done before exiting
  sync_devices()
  if jax.host_id() == 0:
    summary_thread.shutdown()
    if not BLEU_THRESHOLD_REACHED:
      mllogger.end('run_stop', metadata={'status': 'aborted'})
def main(argv): global CFG CFG = FLAGS.config if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') # Guarantee that the JAX bfloat16 extension is used rather than TF bfloat16. _ = np.array(jnp.array([1.0], dtype=jnp.bfloat16)) # Use hardware RNG for bernoulli randoms in dropout mask creation. if CFG.hardware_rng: models.set_hardware_bernoulli() if 'module_import' in CFG and CFG.module_import: for module in CFG.module_import: importlib.import_module(module) if 'additional_task_cache_dirs' in CFG and CFG.additional_task_cache_dirs: t5.data.add_global_cache_dirs(CFG.additional_task_cache_dirs) num_partitions = CFG.num_partitions topology = train_lib.compute_multihost_topology(num_partitions) batch_size = CFG.batch_size eval_batch_size = CFG.eval_batch_size per_replica_set_eval_batch_size = eval_batch_size // topology.num_replica_sets if batch_size % topology.num_replicas: raise ValueError( 'Batch size must be divisible by the number of replicas.') steps_per_epoch = CFG.steps_per_epoch logging.info('steps per epoch: %d', steps_per_epoch) broadcast = functools.partial( train_lib.broadcast, num_replicas=topology.per_replica_set_num_replicas, num_partitions=topology.per_host_num_partitions, devices=topology.this_host_device_assignment) if jax.host_id() == 0: tf.io.gfile.makedirs(FLAGS.model_dir) tf.io.gfile.copy(FLAGS['config'].config_filename, os.path.join(FLAGS.model_dir, 'config.py'), overwrite=True) train_summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.model_dir, 'train')) eval_summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.model_dir, 'eval')) else: train_summary_writer = None eval_summary_writer = None # Write summaries in background thread to avoid blocking on device sync if CFG.infeed: # Infeed is currently synchronous, so do it in a background thread too infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(), 'infeed') (train_ds, eval_ds), eval_cache = input_pipeline.get_datasets_and_cache( CFG, 
topology.num_replica_sets, topology.replica_set_id, topology.per_replica_set_host_id) vocab = input_pipeline.get_vocabulary(CFG.mixture_or_task_name) encoder = vocab.tf_tokenizer eos_id = vocab.tokenizer.eos_id() def decode_tokens(toks, eos_id=eos_id, max_id=32000): """Decode tokens back to unicode.""" del eos_id # TODO(levskaya): T5 doesn't seem to emit EOS tokens? double check this # is the best decoding function or just switch to using tf_decode. # valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32) valid_toks = toks.astype(np.int32) valid_toks[valid_toks >= max_id] = 3 return encoder.detokenize(valid_toks).numpy().decode('utf-8') logging.info('Initializing model, optimizer, and step functions.') train_config, eval_config, predict_config = get_configs(CFG) rng = random.PRNGKey(CFG.random_seed) rng, init_rng = random.split(rng) # This is used for infeed conversion from feature dict <--> tuple train_keys = [ 'inputs', 'targets', 'inputs_position', 'targets_position', 'inputs_segmentation', 'targets_segmentation' ] device_train_input_shape = tuple([ (batch_size // topology.num_replicas, CFG.max_input_length if 'inputs' in k else CFG.max_target_length) for k in train_keys ]) learning_rate_fn = train_lib.create_learning_rate_scheduler( factors=CFG.schedule, base_learning_rate=CFG.learning_rate, warmup_steps=CFG.warmup_steps) # First, we only abstractly initialize the optimizer and model parameters, # since the parameters may not even fit in device memory! 
# TODO(jekbradbury): make optimizer_defs compare by value so it can be created # in get_initial_params without causing pytree incompatibility optimizer_def = optim.Adafactor(CFG.learning_rate, decay_rate=0.8, step_offset=CFG.step_offset) initialize_params_fn = functools.partial(get_initial_params, config=CFG, transformer_config=eval_config, optimizer_def=optimizer_def) optimizer = jax.eval_shape(initialize_params_fn, init_rng) # tuple-like pytree leaves for global_arg_shapes optimizer_shapes = jax.tree_map(lambda x: partitions.Spec(*x.shape), optimizer) # Build parameter partition annotations for preserving partitions from train # to eval. if num_partitions > 1: optimizer_partitions = optimizer.restore_state( partitions.set_partitions(num_partitions, optimizer.state_dict())) per_host_optimizer_partitions = optimizer.restore_state( partitions.set_partitions(topology.per_host_num_partitions, optimizer.state_dict())) # Restore unreplicated optimizer + model state from last checkpoint. # TODO(jekbradbury,levskaya): implement sharded native checkpoint/restore existing_checkpoint_found = False if CFG.restore_checkpoints: existing_checkpoint_found = train_lib.checkpoint_exists( FLAGS.model_dir) optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer) # Import a pretrained-T5 checkpoint only if we didn't import a local # "native" checkpoint (e.g. due to resuming a pre-empted finetuning run.) # TODO(jekbradbury,levskaya): implement sharded T5 checkpoint/restore if CFG.restore_t5_checkpoint and not existing_checkpoint_found: optimizer = checkpoint_importer.restore_from_t5_checkpoint( optimizer, CFG.restore_t5_checkpoint) if CFG.restore_t5_checkpoint or existing_checkpoint_found: if num_partitions > 1: # Until checkpoint/restore is sharded, the restored checkpoint is global # and we need to slice each sharded parameter into the chunk containing # only the partitions that are present on this host. 
def per_host_chunk(x, spec): if spec is None or spec is x: # unsharded or not a parameter return x if spec[0] == 1: dim_size = x.shape[1] elif spec[1] == 1: dim_size = x.shape[0] else: raise NotImplementedError() chunk_size = (dim_size * topology.per_host_num_partitions // num_partitions) lower = topology.per_replica_set_host_id * chunk_size upper = (topology.per_replica_set_host_id + 1) * chunk_size if spec[0] == 1: return x[:, lower:upper] else: return x[lower:upper] optimizer = jax.tree_multimap(per_host_chunk, optimizer, optimizer_partitions) else: # If pretraining and no checkpoint imported, we jit the (sharded-) init # function to minimize fragmentation. We use the same pmap(sharded_jit) # setup as the training step/loop to initialize everything "in-place" and # avoid communication or OOM. if num_partitions > 1: initialize_params_fn = sharded_jit( initialize_params_fn, in_parts=None, local_in_parts=None, out_parts=optimizer_partitions, local_out_parts=per_host_optimizer_partitions, # devices=one_replica_device_assignment, ) initialize_params_fn = jax.pmap(initialize_params_fn, 'batch', in_axes=0, axis_size=topology.num_replicas, devices=topology.device_assignment) init_rng = broadcast(init_rng) optimizer = initialize_params_fn(init_rng) # We maintain the optimizer in unbroadcasted form (i.e. with no leading # replica axis). This is equivalent to the as-yet-nonexistent pmap kwarg # out_axes=None. optimizer = train_lib.unbroadcast(optimizer) else: optimizer = jax.jit(initialize_params_fn)(init_rng) # --------------------------------------------------------------------------- # Compile multidevice versions of train/eval/predict step and cache init fn. 
# --------------------------------------------------------------------------- # We can use either a single train-step for a host training loop: # train_step(optimizer, batch, prev_metrics, dropout_rng, **kwargs) # --> new_optimizer, metrics, new_dropout_rng def p_train_step(optimizer, batch, prev_metrics, dropout_rng): return train_lib.train_step(optimizer, batch, prev_metrics, dropout_rng, config=train_config, learning_rate_fn=learning_rate_fn, num_microbatches=CFG.microbatches, label_smoothing=CFG.label_smoothing, z_loss=CFG.z_loss, use_bfloat16=CFG.use_bfloat16) if num_partitions > 1: p_train_step = sharded_jit( p_train_step, in_parts=(optimizer_partitions, None, None, None), local_in_parts=(per_host_optimizer_partitions, None, None, None), out_parts=(optimizer_partitions, None, None), local_out_parts=(per_host_optimizer_partitions, None, None)) # TODO(levskaya): the in_axes spec below might be wrong, double-check. p_train_step = jax.pmap(p_train_step, axis_name='batch', in_axes=(None, 0, 0, 0), donate_argnums=(0, ), global_arg_shapes=(optimizer_shapes, None, None, None), axis_size=topology.num_replicas, devices=topology.device_assignment) # pytype: disable=wrong-arg-types # OR, we use an on-device loop that feeds the training step via infeed queue. def device_train_loop_cond(args): """Stopping criterion for on-device loop.""" _, _, _, _, step, epoch = args return step // steps_per_epoch == epoch def device_train_loop_body(args): """On-device loop body.""" optimizer, dropout_rngs, metrics, token, step, epoch = args # Ordering input data from infeed requires threading a symbolic token # through the computation. input_data, token = lax.infeed(token, shape=tuple([ jax.ShapedArray(s, jnp.int32) for s in device_train_input_shape ])) # Rebuild input dict from infeed data tuple. batch = {k: v for k, v in zip(train_keys, input_data)} # Run the train_step function and return the loop state. 
optimizer, metrics, dropout_rngs = train_lib.train_step( optimizer, batch, metrics, dropout_rngs, train_config, learning_rate_fn, num_microbatches=CFG.microbatches, label_smoothing=CFG.label_smoothing, z_loss=CFG.z_loss) step += 1 return optimizer, dropout_rngs, metrics, token, step, epoch def device_train_loop(optimizer, dropout_rngs, metrics, step, epoch): # Create symbolic token for threading infeed data. token = lax.create_token(step) # Run on-device loop. optimizer, dropout_rngs, metrics, _, step, _ = lax.while_loop( device_train_loop_cond, device_train_loop_body, (optimizer, dropout_rngs, metrics, token, step, epoch)) return optimizer, dropout_rngs, metrics, step if num_partitions > 1: device_train_loop = sharded_jit( device_train_loop, in_parts=(optimizer_partitions, None, None, None, None), local_in_parts=(per_host_optimizer_partitions, None, None, None, None), out_parts=(optimizer_partitions, None, None, None), local_out_parts=(per_host_optimizer_partitions, None, None, None)) p_train_epoch = jax.pmap(device_train_loop, axis_name='batch', in_axes=(None, 0, 0, None, None), donate_argnums=(0, ), global_arg_shapes=(optimizer_shapes, None, None, None, None), axis_size=topology.num_replicas, devices=topology.device_assignment) # pytype: disable=wrong-arg-types # Reduction psum for metric data. def p_allreduce_metrics(x): return lax.psum(x, axis_name='batch') if num_partitions > 1: p_allreduce_metrics = sharded_jit( p_allreduce_metrics, in_parts=None, local_in_parts=None, out_parts=None, local_out_parts=None, num_partitions=num_partitions, local_num_partitions=topology.per_host_num_partitions) p_allreduce_metrics = jax.pmap(p_allreduce_metrics, axis_name='batch', global_arg_shapes=None, axis_size=topology.num_replicas, devices=topology.device_assignment) # Training evaluation computation. 
# eval_step(params, batch, config, label_smoothing=0.0) --> metrics def p_eval_step(params, batch): return train_lib.eval_step(params, batch, config=eval_config, label_smoothing=CFG.label_smoothing) if num_partitions > 1: p_eval_step = sharded_jit( p_eval_step, in_parts=(optimizer_partitions.target, None), local_in_parts=(per_host_optimizer_partitions.target, None), out_parts=None, local_out_parts=None) p_eval_step = jax.pmap(p_eval_step, axis_name='batch', in_axes=(None, 0), global_arg_shapes=(optimizer_shapes.target, None), axis_size=topology.num_replicas, devices=topology.device_assignment) # pytype: disable=wrong-arg-types # Fast autoregressive decoding loop. # For inference and model evaluation. # predict_step(inputs, params, # eos_id, max_decode_len, config, beam_size=4) --> beam_seqs def p_pred_step(inputs, params): return train_lib.predict_step(inputs, params, eos_id, CFG.max_eval_target_length, predict_config, CFG.beam_size) if num_partitions > 1: p_pred_step = sharded_jit( p_pred_step, in_parts=(None, optimizer_partitions.target), local_in_parts=(None, per_host_optimizer_partitions.target), out_parts=None, local_out_parts=None) p_pred_step = jax.pmap(p_pred_step, axis_name='batch', in_axes=(0, None), global_arg_shapes=(None, optimizer_shapes.target), axis_size=topology.num_replicas, devices=topology.device_assignment) # pytype: disable=wrong-arg-types # --------------------------------------------------------------------------- # Main Train Loop # --------------------------------------------------------------------------- # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap'd training update for performance. # There should be a unique dropout key for each replica represented on this # host, but the key should be the same for the same replica on other hosts. # Again, this is what the replica set abstraction is for. 
dropout_rngs = random.split(random.fold_in(rng, topology.replica_set_id), topology.per_replica_set_num_replicas) # restore step from last checkpoint host_step = int(optimizer.state.step) empty_metrics = broadcast({ 'loss': 0.0, 'accuracy': 0.0, 'learning_rate': 0.0, 'denominator': 0.0 }) if CFG.infeed: # TODO(jekbradbury): support something like this for the Python-loop case logging.info( 'Precompiling training loop and moving optimizer to device.') optimizer, _, metrics, _ = p_train_epoch(optimizer, dropout_rngs, empty_metrics, jnp.array(0, dtype=jnp.int32), 1) optimizer = train_lib.unbroadcast(optimizer) metrics['loss'].block_until_ready() logging.info('Starting training loop.') local_devices = jax.local_devices() device_step = broadcast(host_step) first_epoch = host_step // steps_per_epoch # Main Loop over "epochs". train_iter = train_ds.as_numpy_iterator() for epoch in range(first_epoch, first_epoch + CFG.num_epochs): metrics = empty_metrics # NOTE: 'optimizer' is unbroadcast by construction at initialization or # when loading a checkpoint. It is maintained in 'unbroadcast' state to # enable the XLA cross-replica sharding optimization. The broadcasting is # handled automatically by the pmap'd functions that use it. # Gather all task evaluation metrics. logging.info('Evaluating tasks.') if epoch == first_epoch + 1: train_lib.sync_devices() for task in eval_cache.tasks: logging.info('Evaluating task %s', task.name) all_predicted, all_bs = [], [] for pred_batch in eval_cache.preprocessed_examples[task.name]: # Handle final odd-sized batch by padding instead of dropping it. input_batch, unpadded_batch_size = train_lib.pad_batch_to_size( pred_batch['inputs'], per_replica_set_eval_batch_size) all_bs.append(unpadded_batch_size) # Split batch dimensions for pmap. input_batch = jax.tree_map( lambda x: x.reshape((topology.per_replica_set_num_replicas, -1) + x.shape[1:]), input_batch) # Run fast inference on batch. 
all_predicted.append(p_pred_step(input_batch, optimizer.target)) # Pad out the number of batches so each host has the same number. max_host_batch_number = np.max( eval_cache.preprocessed_batch_sizes[task.name]) batch_shortfall = max_host_batch_number - len(all_predicted) if batch_shortfall > 0: # TODO(levskaya): Fix for case of entirely empty all_predicted. # To make sure the cross-host barriers work, we run the program the same # number of times on all hosts. The results of this call is ignored, and # the predictions are populated with zeros instead. p_pred_step(input_batch, optimizer.target) # Dummy call. all_predicted.extend([jnp.zeros_like(all_predicted[0])] * batch_shortfall) all_bs.extend([0] * batch_shortfall) all_predicted = jnp.concatenate(all_predicted) all_bs = jnp.array(all_bs) # Collect all batches from across hosts and reverse sharding. all_predicted = train_lib.host_allgather( all_predicted, topology.num_replica_sets, topology.replica_set_id, topology.per_replica_set_host_id == 0) seqlength = all_predicted.shape[-1] total_examples = np.sum( train_lib.host_allgather( all_bs, topology.num_replica_sets, topology.replica_set_id, topology.per_replica_set_host_id == 0)) del all_bs assert total_examples == len(eval_cache.examples[task.name]), ( 'Total number of batches incorrect for task %s.' % task.name) # De-shard the collected predicted tokens and remove padding. all_predicted = np.transpose(all_predicted, (1, 2, 0, 3)).reshape( -1, seqlength)[:total_examples] # We now run the post-processing and metric-fns on a single host. 
if jax.host_id() == 0: assert eval_summary_writer raw_predictions = [] for tokens in all_predicted: raw_predictions.append(decode_tokens(tokens)) # post-process predictions for metric fns predictions = [ task.postprocess_fn(p, example=ex) for p, ex in zip( raw_predictions, eval_cache.examples[task.name]) ] for metric_fn in task.metric_fns: scores = metric_fn(eval_cache.targets[task.name], predictions) for metric_name, metric_value in scores.items(): tag = f'eval/{task.name}/{metric_name}' eval_summary_writer.scalar(tag, metric_value, host_step) logging.info('EVAL %s at step %d: %.3f', tag, host_step, metric_value) eval_summary_writer.flush() # Save text samples for tensorboard. exemplars = '' for n in np.random.choice(np.arange(len(predictions)), 8): tgt_txt = tf.compat.as_text( eval_cache.examples[task.name][n]['targets_plaintext']) pred_txt = raw_predictions[n] exemplars += (f'{eval_cache.inputs[task.name][n]}\n\n' f'target: {tgt_txt}\n\n' f'prediction: {pred_txt}\n\n') eval_summary_writer.text(f'{task.name} samples', exemplars, host_step) eval_summary_writer.flush() # Take an Xprof trace after the first loop has compiled everything. if epoch == first_epoch + 1: train_lib.sync_devices() # For on-device loop, we launch the computation before feeding data. logging.info('BEGIN Train loop.') if CFG.infeed: optimizer, dropout_rngs, metrics, device_step = p_train_epoch( optimizer, dropout_rngs, metrics, train_lib.unbroadcast(device_step), epoch) optimizer = train_lib.unbroadcast(optimizer) # Epoch loop. while int(host_step // steps_per_epoch) == epoch: batch = next(train_iter) batch = jax.tree_map( lambda x: x.reshape( (topology.per_replica_set_num_replicas, -1) + x.shape[1:]), batch) # Feed the on-device training loop. if CFG.infeed: for i, device in enumerate(local_devices): # When using infeed to provide data to the computation, we're on our # own for feeding the right values to the right devices. 
Each device # should get the minibatch corresponding to its replica, a slice of # the larger batch corresponding to the host's replica set. if device.platform == 'tpu': device_coords = (*device.coords, device.id % 2) else: device_coords = (device.host_id, i) per_replica_set_device_coords = tuple( dc % prsm for dc, prsm in zip( device_coords, topology.per_replica_set_mesh)) per_replica_set_replica_coords = tuple( prsdc // prm for prsdc, prm in zip(per_replica_set_device_coords, topology.per_replica_mesh)) per_replica_set_replica_id = 0 for prsm, prm, prsrc in zip( topology.per_replica_set_mesh, topology.per_replica_mesh, per_replica_set_replica_coords): per_replica_set_replica_id = ( per_replica_set_replica_id * prsm // prm + prsrc) input_tuple = tuple([ batch[k][per_replica_set_replica_id] for k in train_keys ]) # Safety check: infeed does not check shape or types but requires # them to agree with on-device spec, otherwise the queue and program # stalls. tuple_shapes = jax.tree_map(jnp.shape, input_tuple) tuple_dtypes = jax.tree_map(lambda x: x.dtype, input_tuple) assert tuple_shapes == device_train_input_shape, ( 'infeed shape error %s != %s' % (tuple_shapes, device_train_input_shape)) assert tuple(set(tuple_dtypes)) == (jnp.int32,), \ ('infeed dtype error %s not all of type %s' % ( tuple_dtypes, jnp.int32)) infeed_pool.submit( functools.partial(device.transfer_to_infeed, input_tuple)) # Host training loop. else: optimizer, metrics, dropout_rngs = p_train_step( optimizer, batch, metrics, dropout_rngs) optimizer = train_lib.unbroadcast(optimizer) host_step += 1 logging.info('END Train loop.') # Maybe save a checkpoint on one host. if (CFG.save_checkpoints and epoch % CFG.checkpoint_freq == CFG.checkpoint_freq - 1 and jax.host_id() == 0): checkpoints.save_checkpoint(FLAGS.model_dir, optimizer, host_step) # Gather training metrics. 
metrics = p_allreduce_metrics(metrics) metrics = jax.tree_map(lambda x: jax.device_get(x[0]), metrics) denominator = metrics.pop('denominator') summary = jax.tree_map(lambda x: x / denominator, metrics) # pylint: disable=cell-var-from-loop logging.info('train in step: %s, %s', host_step, summary) if jax.host_id() == 0: assert train_summary_writer for key, val in summary.items(): train_summary_writer.scalar(key, val, host_step) train_summary_writer.flush() # Gather training evaluation metrics. logging.info('Gathering training evaluation metrics.') eval_metrics = [] eval_iter = eval_ds.as_numpy_iterator() for _, eval_batch in zip(range(CFG.num_eval_steps), eval_iter): eval_batch = jax.tree_map( lambda x: x.reshape( (topology.per_replica_set_num_replicas, -1) + x.shape[1:]), eval_batch) metrics = p_eval_step(optimizer.target, eval_batch) eval_metrics.append(metrics) # average metrics across devices eval_metrics = p_allreduce_metrics(eval_metrics) eval_metrics = common_utils.get_metrics(eval_metrics) # average metrics across steps eval_metrics = jax.tree_map(np.sum, eval_metrics) eval_denominator = eval_metrics.pop('denominator') eval_summary = jax.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics) logging.info('eval in step: %s, %s', host_step, eval_summary) if jax.host_id() == 0: assert eval_summary_writer for key, val in eval_summary.items(): eval_summary_writer.scalar(key, val, host_step) eval_summary_writer.flush() # Wait until computations are done before exiting logging.info('Finished.') train_lib.sync_devices() # Shut down the infeed threadpool. if CFG.infeed: infeed_pool.shutdown()