def test_synchronize_multiple_hosts(self, process_index_mock):
  multihost_base_dir = os.path.join(tempfile.mkdtemp(), "test")
  state = TrainState(step=1)
  process_index_mock.return_value = 0
  ckpt_0 = checkpoint.MultihostCheckpoint(multihost_base_dir)
  process_index_mock.return_value = 1
  ckpt_1 = checkpoint.MultihostCheckpoint(multihost_base_dir)
  # Initialize both at step=1.
  state_0 = ckpt_0.restore_or_initialize(state)
  state_1 = ckpt_1.restore_or_initialize(state)
  # Update both at step=2.
  state_0 = state_0.replace(step=2)
  ckpt_0.save(state_0)
  state_1 = state_1.replace(step=2)
  ckpt_1.save(state_1)
  # Update ckpt_1 at step=3.
  state_1 = state_1.replace(step=3)
  ckpt_1.save(state_1)
  # Reload both at step=2.
  process_index_mock.return_value = 0
  ckpt_0 = checkpoint.MultihostCheckpoint(multihost_base_dir)
  process_index_mock.return_value = 1
  ckpt_1 = checkpoint.MultihostCheckpoint(multihost_base_dir)
  self.assertEqual(ckpt_0.latest_checkpoint,
                   ckpt_0.get_latest_checkpoint_to_restore_from())
  self.assertNotEqual(ckpt_1.latest_checkpoint,
                      ckpt_1.get_latest_checkpoint_to_restore_from())
  state_0 = ckpt_0.restore_or_initialize(state)
  state_1 = ckpt_1.restore_or_initialize(state)
  self.assertEqual(state_0.step, 2)
  self.assertEqual(state_1.step, 2)
def test_initialize_mkdir(self, process_index_mock):
  multihost_base_dir = os.path.join(tempfile.mkdtemp(), "test")
  state = TrainState(step=1)
  process_index_mock.return_value = 0
  base_dir = f"{multihost_base_dir}-0"
  ckpt = checkpoint.MultihostCheckpoint(multihost_base_dir)
  self.assertIsNone(ckpt.latest_checkpoint)
  self.assertFalse(os.path.isdir(base_dir))
  state = ckpt.restore_or_initialize(state)
  self.assertIsNotNone(ckpt.latest_checkpoint)
  self.assertTrue(os.path.isdir(base_dir))
def test_preemption(self):
  multihost_base_dir = os.path.join(tempfile.mkdtemp(), "test")
  state = TrainState(step=1)
  state0 = state.replace(step=0)
  ckpt_0 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=0)
  ckpt_1 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=1)
  # Initialize both at step=1.
  state_0 = ckpt_0.restore_or_initialize(state)
  state_1 = ckpt_1.restore_or_initialize(state)
  self.assertEqual(state_0.step, 1)
  self.assertEqual(state_1.step, 1)
  # Restore both at step=1.
  state_0 = ckpt_0.restore_or_initialize(state0)
  state_1 = ckpt_1.restore_or_initialize(state0)
  self.assertEqual(state_0.step, 1)
  self.assertEqual(state_1.step, 1)
  # Update only ckpt_0 to step=2.
  state_0 = state_0.replace(step=2)
  ckpt_0.save(state_0)
  # Load both checkpoints at last common step=1.
  ckpt_0 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=0)
  ckpt_1 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=1)
  state_0 = ckpt_0.restore_or_initialize(state)
  state_1 = ckpt_1.restore_or_initialize(state)
  self.assertEqual(state_0.step, 1)
  self.assertEqual(state_1.step, 1)
  # Store both at step=2.
  state_0 = state_0.replace(step=2)
  state_1 = state_1.replace(step=2)
  ckpt_0.save(state_0)
  ckpt_1.save(state_1)
  # Restore both at step=2.
  state_0 = ckpt_0.restore_or_initialize(state0)
  state_1 = ckpt_1.restore_or_initialize(state0)
  self.assertEqual(state_0.step, 2)
  self.assertEqual(state_1.step, 2)
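# Illustrative sketch (not part of the original test suite): the minimal
# save/restore round trip that the tests above exercise. It reuses the same
# `TrainState` dataclass, the `checkpoint` module, and the module-level imports
# (`os`) that those tests rely on; `tmp_dir` is any writable scratch directory.
def sketch_checkpoint_round_trip(tmp_dir):
  base_dir = os.path.join(tmp_dir, "test")
  ckpt = checkpoint.MultihostCheckpoint(base_dir, host_id=0)
  # No checkpoint exists yet, so the passed-in state is saved and returned.
  state = ckpt.restore_or_initialize(TrainState(step=1))
  # Save an updated state; the latest checkpoint now holds step=2.
  state = state.replace(step=2)
  ckpt.save(state)
  # A later restore ignores the template's values and returns the saved step.
  restored = ckpt.restore_or_initialize(TrainState(step=0))
  assert int(restored.step) == 2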
def train_and_evaluate(self, workdir):
  """Runs a training and evaluation loop.

  Args:
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
  tf.io.gfile.makedirs(workdir)
  config = self.config
  substeps = config.training.substeps

  # Learning rate schedule.
  num_train_steps = config.training.num_train_steps
  logging.info('num_train_steps=%d', num_train_steps)

  # Get train state.
  state = self._train_state

  # Set up checkpointing of the model and the input pipeline.
  checkpoint_dir = os.path.join(workdir, 'checkpoints')
  ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=5)
  state = ckpt.restore_or_initialize(state)
  initial_step = int(state.step)

  # Distribute training.
  state = flax_utils.replicate(state)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.process_index() > 0)
  if initial_step == 0:
    writer.write_hparams(dict(config))

  logging.info('Starting training loop at step %d.', initial_step)
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=num_train_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
    ]
  step = initial_step
  with metric_writers.ensure_flushes(writer):
    while step < num_train_steps:
      # `step` is a Python integer. `state.step` is a JAX integer on the
      # GPU/TPU devices.
      is_last_step = step + substeps >= num_train_steps

      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        inputs = jax.tree_map(np.asarray, next(self._train_iter))
        state, outputs = self._update_func(state, inputs)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      new_step = int(state.step[0])
      assert new_step == step + substeps
      step = new_step

      is_eval = step % config.logs.eval_full_every_steps == 0 or is_last_step
      if step % config.logs.log_loss_every_steps == 0 and not is_eval:

        def avg_over_substeps(x):
          assert x.shape[0] == substeps
          return float(x.mean(axis=0))

        # Extract scalars and images.
        outputs = flax_utils.unreplicate(outputs)
        outputs = jax.tree_map(avg_over_substeps, outputs)
        scalars = outputs['scalars']
        writer.write_scalars(step, scalars)

      if is_eval:
        with report_progress.timed('eval_full'):
          outputs = self._eval_epoch(params=state.ema_params)
          outputs = flax_utils.unreplicate(outputs)
          scalars = outputs['scalars']
          writer.write_scalars(step, scalars)

      if step % config.logs.checkpoint_every_steps == 0 or is_last_step:
        with report_progress.timed('checkpoint'):
          ckpt.save(flax_utils.unreplicate(state))

  logging.info('Finishing training at step %d', num_train_steps)
def train_and_evaluate(config, workdir):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
  logging.info('Starting training at %s', workdir)
  tf.io.gfile.makedirs(workdir)
  if jax.process_index() == 0:
    with tf.io.gfile.GFile(os.path.join(workdir, 'config.json'), 'w') as f:
      json.dump(config.to_dict(), f, indent=2)
  rng = jax.random.PRNGKey(config.seed)

  # Build input pipeline.
  rng, data_rng = jax.random.split(rng)
  data_rng = jax.random.fold_in(data_rng, jax.process_index())
  train_ds, eval_ds = input_pipeline.create_datasets(config.dataset, data_rng)
  train_iter = iter(train_ds)

  test_ds = []
  for split in config.dataset.test_splits:
    ds = input_pipeline.create_val_dataset(
        config.dataset, split, config.dataset.test_per_device_batch_size,
        config.dataset.test_pad_last_batch)
    test_ds.append(ds)

  # Learning rate schedule.
  num_train_steps = config.num_train_steps
  if num_train_steps == -1:
    num_train_steps = train_ds.cardinality().numpy()
  steps_per_epoch = num_train_steps // config.dataset.num_epochs
  logging.info('num_train_steps=%d, steps_per_epoch=%d', num_train_steps,
               steps_per_epoch)

  learning_rate_fn = functools.partial(
      train_utils.get_learning_rate,
      base_learning_rate=config.learning_rate,
      num_train_steps=num_train_steps,
      schedule_type=config.learning_rate_schedule,
      warmup_proportion=config.warmup_proportion,
      step_boundaries=config.learning_rate_step_boundaries)

  # Initialize model.
  inputs = train_utils.get_init_inputs(train_ds)
  rng, model_rng = jax.random.split(rng)
  eval_config = models.TransformerConfig(**config.model.to_dict())
  train_config = eval_config.replace(deterministic=False)
  model = models.Model(eval_config)
  state = train_utils.create_train_state(model, config, model_rng,
                                         inputs=inputs)

  # Set up checkpointing of the model and the input pipeline.
  checkpoint_dir = os.path.join(workdir, 'checkpoints')
  ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=3)
  state = ckpt.restore_or_initialize(state)
  initial_step = int(state.step) + 1

  # Distribute training.
  state = flax_utils.replicate(state)
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          grad_clip=config.grad_clip),
      axis_name='batch',
      donate_argnums=(0,))
  p_eval_step = jax.pmap(
      functools.partial(eval_step, config=eval_config), axis_name='batch')

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.process_index() > 0)
  if initial_step == 1:
    writer.write_hparams(train_utils.flatten_config(config))

  logging.info('Starting training loop at step %d.', initial_step)
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=num_train_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(
            num_profile_steps=config.num_profile_steps, logdir=workdir)
    ]

  rng, train_rngs = jax.random.split(rng)
  train_rngs = jax.random.fold_in(train_rngs, jax.process_index())
  train_rngs = jax.random.split(train_rngs, jax.local_device_count())

  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(initial_step, num_train_steps + 1):
      is_last_step = step == num_train_steps

      with jax.profiler.StepTraceContext('train', step_num=step):
        batch = jax.tree_map(np.asarray, next(train_iter))
        state, metrics = p_train_step(batch=batch, rng=train_rngs, state=state)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      if config.log_loss_every_steps > 0 and (
          step % config.log_loss_every_steps == 0 or is_last_step):
        train_metrics = common_utils.get_metrics(train_metrics)
        lr = train_metrics.pop('learning_rate').mean()
        train_summary = train_utils.metrics_summary(train_metrics, 'train')
        train_summary['learning_rate'] = lr
        writer.write_scalars(step, train_summary)
        train_metrics = []

      if config.eval_every_steps > 0 and (
          step % config.eval_every_steps == 0 or is_last_step):
        with report_progress.timed('eval'):
          eval_summary = evaluate(p_eval_step, state, eval_ds,
                                  config.num_eval_steps)
        writer.write_scalars(step, eval_summary)

      if config.checkpoint_every_steps > 0 and (
          step % config.checkpoint_every_steps == 0 or is_last_step):
        with report_progress.timed('checkpoint'):
          ckpt.save(flax_utils.unreplicate(state))
        logging.info('Checkpoint saved to %s', checkpoint_dir)

  logging.info('Finishing training at step %d', num_train_steps)
def training_loop(
    *,
    module,
    rng,
    train_ds,
    eval_ds,
    loss_fn,
    optimizer,
    train_metrics_dict,
    eval_metrics_dict,
    stats_aggregators,
    config,
    workdir,
):
  """Runs a training and evaluation loop.

  Args:
    module: The module that should be trained.
    rng: A jax pseudo-random number generator key.
    train_ds: Dataset used for training.
    eval_ds: Dataset used for evaluation.
    loss_fn: Loss function to use for training.
    optimizer: Optax optimizer to use for training.
    train_metrics_dict: Collection of metrics to be collected during training.
    eval_metrics_dict: Collection of metrics to be collected during evaluation.
    stats_aggregators: Dictionary of statistics aggregator functions to be run
      on the first evaluation batch. These functions ingest the stats returned
      by the model and output a Dict[str, image/scalar] that will be logged.
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.

  Raises:
    RuntimeError: If a training metric is NaN or inf.

  Returns:
    Training state.
  """
  rng, model_rng = jax.random.split(rng)
  input_shape = tuple(train_ds.element_spec["image"].shape[1:])
  model, init_params, init_state = create_model(module, input_shape, model_rng)
  parameter_overview.log_parameter_overview(model.params)

  # Load pretrained model parameters and state. Ignore the step and the
  # optimizer state in the checkpoint.
  pretrained_path = config.get("pretrained_checkpoint", "")
  if pretrained_path:
    logging.info("Load pretrained weights from '%s'", pretrained_path)
    state_dict = checkpoint.load_state_dict(pretrained_path)
    flatten_model_params = utils.flatten_dict(state_dict["model_params"],
                                              sep="/")
    model_state = state_dict["model_state"]

    # A prefix can be used to replace only a subpart of the network (e.g. the
    # encoder). Prepend the prefix (if any) to model parameters and states.
    prefix = config.get("pretrained_prefix", "")
    if prefix:
      flatten_model_params = utils.add_prefix_to_dict_keys(
          flatten_model_params, f"{prefix}/")
      model_state = utils.add_prefix_to_dict_keys(model_state, f"/{prefix}")

    # Merge the params/state from the checkpoint into the initial params/state.
    flatten_init_params = utils.flatten_dict(init_params, sep="/")
    flatten_init_params, ignored_params = utils.override_dict(
        flatten_init_params, flatten_model_params)
    init_params = utils.unflatten_dict(flatten_init_params, delimiter="/")
    init_state, _ = utils.override_dict(init_state, model_state)

    if ignored_params:
      logging.warning(
          "%d/%d parameters from the pretrained checkpoint were ignored: %s",
          len(ignored_params), len(flatten_init_params), ignored_params)

  optimizer_state = optimizer.init(init_params)

  state = TrainState(
      step=1,
      model_params=init_params,
      model_state=init_state,
      optimizer_state=optimizer_state)  # type: ignore
  # Do not keep a copy of the initial model.
  del init_params, init_state, optimizer_state

  train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
  checkpoint_dir = os.path.join(workdir, "checkpoints")
  ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
  state = ckpt.restore_or_initialize(state)
  initial_step = int(state.step)

  # Replicate our parameters.
  state = flax.jax_utils.replicate(state)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)
  step_timer = utils.StepTimer(
      batch_size=config.batch_size, initial_step=initial_step)

  # Write config to the summary files. This makes the hyperparameters available
  # in TensorBoard and makes comparison of runs with tensorboard/ easier.
  if initial_step == 1:
    writer.write_hparams(utils.flatten_dict(config.to_dict()))

  # Generate per-device PRNG keys for the training loop.
  rng, train_rng = jax.random.split(rng)
  train_rngs = jax.random.split(train_rng, jax.local_device_count())

  # Generate per-device PRNG keys for model evaluation.
  rng, eval_rng = jax.random.split(rng)
  eval_rngs = jax.random.split(eval_rng, jax.local_device_count())

  logging.info("Starting training loop at step %d.", initial_step)
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  train_metrics = utils.Means()

  do_eval_only = config.get("do_eval_only", False)
  if do_eval_only:
    config.num_train_steps = 1

  debug_enabled = config.get("debug", False)
  previous_grads = grads = None
  previous_updates = updates = None
  previous_state = None
  for step in range(initial_step, config.num_train_steps + 1):
    is_last_step = step == config.num_train_steps
    if debug_enabled:
      previous_grads = grads
      previous_updates = updates
      previous_state = state

    # Skip training if we only do the eval.
    if not do_eval_only:
      # Use ._numpy() to avoid copy.
      batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))  # pylint: disable=protected-access
      state, grads, updates, metrics, training_stats, train_rngs = train_step(
          state, batch, module, loss_fn, optimizer, train_metrics_dict,
          train_rngs)
      train_metrics.append(flax.jax_utils.unreplicate(metrics))

      # Update topk temperature with linearly decreasing schedule if enabled.
      if (config.get("linear_decrease_perturbed_sigma", False) and
          config.get("selection_method", "") == "perturbed-topk"):
        model_state = state.model_state.as_dict()
        if "/PatchNet_0" in model_state:
          net_str = "/PatchNet_0"
        else:
          net_str = "/"
        progress = step / config.num_train_steps
        sigma_multiplier = 1. - progress
        previous_mult = model_state[net_str]["sigma_mutiplier"]
        sigma_multiplier = sigma_multiplier + jnp.zeros_like(previous_mult)
        model_state[net_str]["sigma_mutiplier"] = sigma_multiplier
        state = state.replace(model_state=nn.Collection(model_state))

      if debug_enabled:
        if utils.has_any_inf_or_nan(metrics):
          # Save checkpoint.
          if previous_state:
            ckpt.save(flax.jax_utils.unreplicate(previous_state))
          ckpt.save(flax.jax_utils.unreplicate(state))

          # Log gradients and updates.
          if previous_grads or previous_updates:
            write_gradient_histogram(writer, step, grads=previous_grads,
                                     updates=previous_updates)
          write_gradient_histogram(writer, step + 1, grads=grads,
                                   updates=updates)

          raise RuntimeError("A training metric took an invalid value: "
                             f"{metrics}.")

      logging.log_first_n(logging.INFO, "Finished training step %d.", 3, step)
      report_progress(step)

      if step % config.log_loss_every_steps == 0 or is_last_step:
        results = train_metrics.result()
        writer.write_scalars(step, results)
        writer.write_scalars(step, step_timer.get_and_reset(step))
        if utils.has_any_inf_or_nan(results):
          raise ValueError("A training metric took an invalid value.")
        train_metrics.reset()

    if step % config.checkpoint_every_steps == 0 or is_last_step:
      with step_timer.paused():
        ckpt.save(flax.jax_utils.unreplicate(state))

    # Evaluation.
    if step % config.eval_every_steps == 0 or is_last_step:
      with step_timer.paused():
        eval_metrics, first_batch_stats, eval_rngs = evaluate(
            state, module, eval_ds, eval_metrics_dict, eval_rngs)

      if jax.host_id() == 0:
        log_histograms = config.get("log_histograms", False)
        log_images = config.get("log_images", True)
        # Log the last gradients and updates histograms.
        if not do_eval_only:
          write_stats_results(writer, step, training_stats, stats_aggregators,
                              prefix="train/", log_images=log_images)
          if log_histograms:
            write_gradient_histogram(writer, step, grads=grads,
                                     updates=updates)

        write_stats_results(writer, step, first_batch_stats, stats_aggregators,
                            prefix="eval/", log_images=log_images)

        # Write patch representation histograms.
        if (log_histograms and first_batch_stats and
            "patch_representations" in first_batch_stats):
          patch_representations = first_batch_stats["patch_representations"]
          writer.write_histograms(
              step, {"patch_representations": patch_representations})

        if eval_metrics:
          writer.write_scalars(step, eval_metrics)

      writer.flush()

  return state
def train_and_evaluate(config, workdir):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
  tf.io.gfile.makedirs(workdir)

  rng = jax.random.PRNGKey(config.seed)

  # Build input pipeline.
  rng, data_rng = jax.random.split(rng)
  data_rng = jax.random.fold_in(data_rng, jax.host_id())
  splits = input_pipeline.create_datasets(config, data_rng)
  num_classes = splits.info.features["label"].num_classes
  train_iter = iter(splits.train)  # pytype: disable=wrong-arg-types

  # Learning rate schedule.
  num_train_steps = config.num_train_steps
  if num_train_steps == -1:
    num_train_steps = splits.train.cardinality().numpy()
  steps_per_epoch = num_train_steps // config.num_epochs
  logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
               steps_per_epoch)

  # We treat the learning rate in the config as the learning rate for batch
  # size 32 but scale it according to our batch size.
  global_batch_size = config.per_device_batch_size * jax.device_count()
  base_learning_rate = config.learning_rate * global_batch_size / 32.0
  learning_rate_fn = functools.partial(
      get_learning_rate,
      base_learning_rate=base_learning_rate,
      steps_per_epoch=steps_per_epoch,
      num_epochs=config.num_epochs,
      warmup_epochs=config.warmup_epochs)

  # Initialize model.
  rng, model_rng = jax.random.split(rng)
  model, state = create_train_state(
      config, model_rng,
      input_shape=splits.train.element_spec["input"].shape[1:],
      num_classes=num_classes)

  # Set up checkpointing of the model and the input pipeline.
  checkpoint_dir = os.path.join(workdir, "checkpoints")
  ckpt = checkpoint.MultihostCheckpoint(
      checkpoint_dir, {"train_iter": train_iter}, max_to_keep=2)
  state = ckpt.restore_or_initialize(state)
  initial_step = int(state.step) + 1

  # Count number of trainable parameters. This must be done before replicating
  # the state to avoid double-counting replicated parameters.
  param_count = sum(p.size for p in jax.tree_leaves(state.optimizer.target))

  # Distribute training over local devices.
  state = flax_utils.replicate(state)

  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          model=model,
          learning_rate_fn=learning_rate_fn,
          weight_decay=config.weight_decay),
      axis_name=_PMAP_AXIS_NAME)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)
  if initial_step == 1:
    writer.write_hparams(dict(config))
    # Log the number of trainable params.
    writer.write_scalars(initial_step, {"param_count": param_count})

  logging.info("Starting training loop at step %d.", initial_step)
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=num_train_steps, writer=writer)
  if jax.host_id() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
    ]
  train_metrics = None
  with metric_writers.ensure_flushes(writer):
    for step in range(initial_step, num_train_steps + 1):
      # `step` is a Python integer. `state.step` is a JAX integer on the
      # GPU/TPU devices.
      is_last_step = step == num_train_steps

      with jax.profiler.StepTraceContext("train", step_num=step):
        batch = jax.tree_map(np.asarray, next(train_iter))
        state, metrics_update = p_train_step(state=state, batch=batch)
        metric_update = flax_utils.unreplicate(metrics_update)
        train_metrics = (
            metric_update
            if train_metrics is None else train_metrics.merge(metric_update))

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)

      if step % config.log_loss_every_steps == 0 or is_last_step:
        writer.write_scalars(step, train_metrics.compute())
        train_metrics = None

      # When combining train and eval, we do not evaluate while training.
      if ((step % config.eval_every_steps == 0 or is_last_step) and
          not config.combine_train_val_and_eval_on_test):
        with report_progress.timed("eval"):
          eval_metrics = evaluate(model, state, splits.validation,
                                  config.num_eval_steps)
        writer.write_scalars(step, eval_metrics.compute())

      if step % config.checkpoint_every_steps == 0 or is_last_step:
        with report_progress.timed("checkpoint"):
          ckpt.save(flax_utils.unreplicate(state))

      if is_last_step and config.combine_train_val_and_eval_on_test:
        # Evaluate a single time on the test set when requested.
        with report_progress.timed("test"):
          test_metrics = evaluate(model, state, splits.test,
                                  config.num_eval_steps)
        writer.write_scalars(step, test_metrics.compute())

  logging.info("Finishing training at step %d", num_train_steps)
def predict_and_evaluate(config, workdir, ckpt_path=None):
  """Runs a testing and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
    ckpt_path: The checkpoint to evaluate. If not specified, use the latest
      checkpoint.
  """
  logging.info('Starting testing at %s', workdir)
  tf.io.gfile.makedirs(workdir)

  rng = jax.random.PRNGKey(config.seed)

  # Build input pipeline.
  rng, data_rng = jax.random.split(rng)
  data_rng = jax.random.fold_in(data_rng, jax.process_index())
  test_ds = []
  for split in config.dataset.test_splits:
    ds = input_pipeline.create_val_dataset(
        config.dataset, split, config.dataset.test_per_device_batch_size,
        config.dataset.test_pad_last_batch)
    test_ds.append(ds)

  # Initialize model.
  inputs = train_utils.get_init_inputs(test_ds[0])
  rng, model_rng = jax.random.split(rng)
  predict_config = models.TransformerConfig(**config.model.to_dict())
  predict_config = predict_config.replace(decode=True)
  model = models.Model(predict_config)
  state = train_utils.create_train_state(model, config, model_rng,
                                         inputs=inputs)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.process_index() > 0)

  # Set up checkpointing of the model and the input pipeline.
  checkpoint_dir = os.path.join(workdir, 'checkpoints')
  ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=3)

  logging.info('Testing and evaluating checkpoint %s', ckpt_path)
  try:
    state = ckpt.restore(state, ckpt_path)
  except FileNotFoundError:
    state = ckpt.restore_or_initialize(state)
  step = int(state.step)

  p_pred_step = jax.pmap(
      functools.partial(predict_step, config=predict_config),
      axis_name='batch',
      static_broadcasted_argnums=(3,))
  p_init_cache = jax.pmap(
      functools.partial(init_cache, config=predict_config), axis_name='batch')

  # Distribute testing.
  state = flax_utils.replicate(state)
  with metric_writers.ensure_flushes(writer):
    test_metrics = {}
    for ds, split in zip(test_ds, config.dataset.test_splits):
      ds_metrics = evaluate_sequence_accuracy(p_pred_step, p_init_cache, state,
                                              ds, config, split, workdir,
                                              config.num_test_steps)
      ds_metrics = {f'{k}_{split}': v for k, v in ds_metrics.items()}
      test_metrics.update(ds_metrics)
    writer.write_scalars(step, test_metrics)