def test_ignores_incomplete_checkpoint(self):
    """Checks that a half-written checkpoint is skipped and can be replaced.

    A save is forced to fail after the Flax file for step=2 is on disk but
    before the TensorFlow index file exists. A fresh `Checkpoint` over the
    same directory must then restore step=1, and a later successful save of
    step=2 must produce both files.
    """
    workdir = tempfile.mkdtemp()
    manager = checkpoint.Checkpoint(workdir)

    # The first call initializes from the given state (step=1) ...
    manager.restore_or_initialize(TrainState(step=1))
    # ... so the second call restores step=1 over the step=0 template.
    restored = manager.restore_or_initialize(TrainState(step=0))
    self.assertEqual(restored.step, 1)

    # Failed save: replacing the TensorFlow manager's `save` with None makes
    # the call raise after the Flax file for step=2 has been stored.
    manager.tf_checkpoint_manager.save = None
    with self.assertRaisesRegex(TypeError,
                                r"'NoneType' object is not callable"):
        manager.save(TrainState(step=2))
    listing = os.listdir(workdir)
    self.assertIn("ckpt-2.flax", listing)
    self.assertNotIn("ckpt-2.index", listing)

    # A new manager over the same directory still restores step=1.
    manager = checkpoint.Checkpoint(workdir)
    restored = manager.restore_or_initialize(TrainState(step=0))
    self.assertEqual(restored.step, 1)

    # Saving step=2 again now writes both the Flax and the index file.
    saved_path = manager.save(TrainState(step=2))
    self.assertEqual(_checkpoint_number(saved_path), 2)
    listing = os.listdir(workdir)
    self.assertIn("ckpt-2.flax", listing)
    self.assertIn("ckpt-2.index", listing)

    # The completed checkpoint is the one that gets restored.
    restored = manager.restore_or_initialize(TrainState(step=0))
    self.assertEqual(restored.step, 2)
def test_overwrite(self):
    """Checks that saving after restoring an old checkpoint overwrites later ones.

    Steps 1..5 are written, step 1 is restored explicitly, and the next save
    must become checkpoint number 2 and be reported as the latest checkpoint.
    """
    workdir = tempfile.mkdtemp()
    tf_step = tf.Variable(1)
    manager = checkpoint.Checkpoint(workdir, {"step": tf_step})

    # Initialize at step=1, both on the Flax and on the TensorFlow side.
    state = manager.restore_or_initialize(TrainState(step=1))
    self.assertEqual(state.step, 1)
    self.assertEqual(tf_step.numpy(), 1)
    first_info = checkpoint.CheckpointInfo.from_path(
        manager.current_checkpoint)
    first_path = str(first_info)

    # Store steps 2, 3, 4 and 5.
    for _ in range(4):
        tf_step.assign_add(1)
        state = state.replace(step=state.step + 1)
        manager.save(state)
    expected_latest = str(first_info._replace(number=5))
    self.assertEqual(manager.current_checkpoint, expected_latest)
    self.assertEqual(manager.latest_checkpoint, expected_latest)

    # Restore the very first checkpoint with a fresh manager: the current
    # checkpoint now lags behind the latest one.
    manager = checkpoint.Checkpoint(workdir, {"step": tf_step})
    state = manager.restore(state, checkpoint=first_path)
    self.assertEqual(state.step, 1)
    self.assertEqual(tf_step.numpy(), 1)
    self.assertNotEqual(manager.current_checkpoint,
                        manager.latest_checkpoint)
    self.assertEqual(manager.current_checkpoint, first_path)
    self.assertEqual(manager.latest_checkpoint, expected_latest)

    # Saving from here overwrites step=2 (and deletes 3, 4, 5).
    tf_step.assign_add(1)
    state = state.replace(step=state.step + 1)
    manager.save(state)
    expected_latest = str(first_info._replace(number=2))
    self.assertEqual(manager.current_checkpoint, expected_latest)
    self.assertEqual(manager.latest_checkpoint, expected_latest)
def test_fails_if_save_counter_mismatch(self):
    """Checks that saving without restoring first raises a `RuntimeError`.

    One checkpoint is written, then a second `Checkpoint` is created over the
    same directory and asked to save directly. The save is expected to fail
    with a message matching "Expected ... to match".
    """
    base_dir = tempfile.mkdtemp()
    ckpt = checkpoint.Checkpoint(base_dir, max_to_keep=1)
    state = TrainState(step=1)
    state = ckpt.restore_or_initialize(state)
    ckpt.save(state)

    # New manager over a non-empty directory; no restore before saving.
    ckpt = checkpoint.Checkpoint(base_dir, max_to_keep=1)
    state = TrainState(step=2)
    # `assertRaisesRegexp` is a deprecated alias that was removed in
    # Python 3.12; `assertRaisesRegex` is the supported spelling.
    with self.assertRaisesRegex(RuntimeError, r"^Expected.*to match"):
        ckpt.save(state)
def test_restore_flax_alone(self):
    """Checks that the Flax state restores without the TensorFlow objects.

    The checkpoint is written by a manager that also tracks a dataset
    iterator, and read back by a manager created without it.
    """
    workdir = tempfile.mkdtemp()
    dataset_iter = iter(_make_dataset())

    # Initialize (step=1) with the iterator registered as TensorFlow state.
    writer_manager = checkpoint.Checkpoint(workdir, {"ds_iter": dataset_iter})
    writer_manager.restore_or_initialize(TrainState(step=1))

    # Restore step=1 through a manager that knows only the Flax state.
    reader_manager = checkpoint.Checkpoint(workdir)
    restored = reader_manager.restore_or_initialize(TrainState(step=0))
    self.assertEqual(restored.step, 1)
def test_restore_dict(self):
    """Exercises `restore_dict` for missing, explicit and latest checkpoints."""
    workdir = tempfile.mkdtemp()
    dataset_iter = iter(_make_dataset())
    manager = checkpoint.Checkpoint(workdir, {"ds_iter": dataset_iter})

    # Nothing has been saved yet, so both lookups must fail.
    with self.assertRaisesRegex(FileNotFoundError, r"No checkpoint found at"):
        manager.restore_dict()
    with self.assertRaisesRegex(FileNotFoundError,
                                r"Checkpoint invalid does not exist"):
        manager.restore_dict(checkpoint="invalid")

    # After one save the default lookup returns that state as a dict.
    manager.save(TrainState(step=1))
    self.assertEqual(manager.restore_dict(), {"step": 1})
    step1_checkpoint = manager.latest_checkpoint

    # After a second save the first checkpoint stays addressable by path,
    # while the default lookup and the latest path both yield step=2.
    manager.save(TrainState(step=2))
    self.assertEqual(manager.restore_dict(checkpoint=step1_checkpoint),
                     {"step": 1})
    self.assertEqual(manager.restore_dict(), {"step": 2})
    self.assertEqual(
        manager.restore_dict(checkpoint=manager.latest_checkpoint),
        {"step": 2})
def test_restores_tf_state(self):
    """Checks that the dataset iterator position is saved and restored."""
    workdir = tempfile.mkdtemp()
    dataset_iter = iter(_make_dataset())
    manager = checkpoint.Checkpoint(workdir, {"ds_iter": dataset_iter})

    def assert_same(expected, actual):
        self.assertAllEqual(expected["x"], actual["x"])
        self.assertAllEqual(expected["y"], actual["y"])

    def assert_differ(first, second):
        self.assertNotAllEqual(first["x"], second["x"])
        self.assertNotAllEqual(first["y"], second["y"])

    # Advance the iterator by one element and discard it.
    next(dataset_iter)

    # Initialize with the iterator positioned at the second element.
    state = manager.restore_or_initialize(TrainState(step=1))
    features1 = next(dataset_iter)
    features2 = next(dataset_iter)
    assert_differ(features1, features2)

    # Restoring rewinds the iterator to where it was at initialization.
    state = manager.restore_or_initialize(state)
    assert_same(features1, next(dataset_iter))

    # Save with the iterator positioned just after features1.
    saved_path = manager.save(state)
    self.assertEqual(_checkpoint_number(saved_path), 2)
    features2 = next(dataset_iter)
    features3 = next(dataset_iter)
    assert_differ(features2, features3)

    # Restoring returns to the position of the second save.
    state = manager.restore_or_initialize(state)
    assert_same(features2, next(dataset_iter))

    # Restoring as a dictionary resets the iterator as well.
    state = manager.restore_dict()
    assert_same(features2, next(dataset_iter))
def test_restores_flax_state(self):
    """Checks restoring the latest and an explicitly named Flax checkpoint."""
    workdir = tempfile.mkdtemp()
    manager = checkpoint.Checkpoint(workdir, max_to_keep=2)

    # Initialize at step=1, then restore it over a step=0 template.
    manager.restore_or_initialize(TrainState(step=1))
    restored = manager.restore_or_initialize(TrainState(step=0))
    self.assertEqual(restored.step, 1)

    # Store step=2 and restore it as the latest checkpoint.
    step2_path = manager.save(TrainState(step=2))
    self.assertEqual(_checkpoint_number(step2_path), 2)
    restored = manager.restore(TrainState(step=0))
    self.assertEqual(restored.step, 2)

    # Store step=3, then restore step=2 again by passing its path.
    step3_path = manager.save(TrainState(step=3))
    self.assertEqual(_checkpoint_number(step3_path), 3)
    restored = manager.restore(TrainState(step=0), step2_path)
    self.assertEqual(restored.step, 2)
def test_initialize_mkdir(self):
    """Checks that the base directory is created on initialization only."""
    workdir = os.path.join(tempfile.mkdtemp(), "test")
    manager = checkpoint.Checkpoint(workdir)

    # Constructing the manager neither creates the directory nor reports a
    # checkpoint.
    self.assertIsNone(manager.latest_checkpoint)
    self.assertFalse(os.path.isdir(workdir))

    # Initializing writes a checkpoint, which requires the directory.
    manager.restore_or_initialize(TrainState(step=1))
    self.assertIsNotNone(manager.latest_checkpoint)
    self.assertTrue(os.path.isdir(workdir))
def test_fails_when_restoring_superset(self):
    """Checks that restoring into a state with extra fields raises."""
    workdir = tempfile.mkdtemp()
    manager = checkpoint.Checkpoint(workdir)

    # Initialize the checkpoint from a plain TrainState.
    manager.restore_or_initialize(TrainState(step=0))

    # Restoring into TrainStateExtended must fail with a "Missing field"
    # error.
    extended = TrainStateExtended(step=1, name="test")
    with self.assertRaisesRegex(ValueError, r"^Missing field"):
        extended = manager.restore_or_initialize(extended)
def test_load_state_dict(self):
    """Checks the module-level `load_state_dict` helper.

    It must return the stored Flax state as a dictionary for a directory
    that holds a checkpoint, and raise `FileNotFoundError` for an empty one.
    """
    base_dir = tempfile.mkdtemp()
    state = TrainState(step=1)
    ckpt = checkpoint.Checkpoint(base_dir)
    # Initializes, which writes the first checkpoint.
    state = ckpt.restore_or_initialize(state)

    # Load via load_state_dict().
    flax_dict = checkpoint.load_state_dict(base_dir)
    self.assertEqual(flax_dict, dict(step=1))

    # `assertRaisesRegexp` is a deprecated alias that was removed in
    # Python 3.12; `assertRaisesRegex` is the supported spelling.
    with self.assertRaisesRegex(FileNotFoundError, r"^No checkpoint found"):
        checkpoint.load_state_dict(tempfile.mkdtemp())
def test_max_to_keep(self):
    """Checks that `max_to_keep=1` replaces the old checkpoint's files."""
    workdir = tempfile.mkdtemp()
    manager = checkpoint.Checkpoint(workdir, max_to_keep=1)

    manager.restore_or_initialize(TrainState(step=1))
    listing_before = os.listdir(workdir)

    saved_path = manager.save(TrainState(step=2))
    self.assertEqual(_checkpoint_number(saved_path), 2)
    listing_after = os.listdir(workdir)

    # Same number of files as before, but not the same files.
    self.assertEqual(len(listing_before), len(listing_after))
    self.assertNotEqual(listing_before, listing_after)
def main(_):
    """Entry point: builds random Psi/Phi matrices, trains, and pickles the result.

    Reads the configuration from `_CONFIG` and the work directory from
    `_WORKDIR`. The initial step and Phi are restored from a checkpoint in
    the work directory when one exists. Whatever `train` returns is written
    to `phis.pkl` inside the work directory.
    """
    jax.config.update('jax_enable_x64', True)
    config: config_dict.ConfigDict = _CONFIG.value
    logging.info(config)

    # Independent keys for the two random matrices; `key` goes on to train().
    key, psi_key, phi_key = jax.random.split(
        jax.random.PRNGKey(config.seed), 3)
    psi_matrix = jax.random.normal(
        psi_key, (config.S, config.T), dtype=jnp.float64)
    # Wrap feature matrix in np array to allow for indexing.
    phi_matrix = np.array(
        jax.random.normal(phi_key, (config.S, config.d), dtype=jnp.float64))

    # Resume from the latest checkpoint if there is one, else start at step 0.
    chkpt_manager = checkpoint.Checkpoint(base_directory=_WORKDIR.value)
    initial_step, phi_matrix = chkpt_manager.restore_or_initialize(
        (0, phi_matrix))

    optimal_subspace = compute_optimal_subspace(psi_matrix, config.d)

    workdir = epath.Path(_WORKDIR.value)
    workdir.mkdir(exist_ok=True)

    train_kwargs = dict(
        workdir=workdir,
        initial_step=initial_step,
        chkpt_manager=chkpt_manager,
        Phi=phi_matrix,
        Psi=psi_matrix,
        optimal_subspace=optimal_subspace,
        num_epochs=config.num_epochs,
        learning_rate=config.lr,
        key=key,
        method=config.method,
        lissa_kappa=config.kappa,
        optimizer=config.optimizer,
        covariance_batch_size=config.covariance_batch_size,
        main_batch_size=config.main_batch_size,
        weight_batch_size=config.weight_batch_size,
        estimate_feature_norm=config.estimate_feature_norm,
    )
    phis = train(**train_kwargs)

    with (workdir / 'phis.pkl').open('wb') as fout:
        pickle.dump(phis, fout, protocol=4)
def evaluate(self, workdir, dir_name='eval', ckpt_name=None):
    """Perform one evaluation.

    Restores the EMA parameters and the step from a checkpoint under
    `<workdir>/checkpoints-0`, runs `self._eval_epoch` on them, and writes the
    resulting scalars and images to `<workdir>/<dir_name>`.

    Args:
      workdir: Directory that contains the `checkpoints-0` sub-directory and
        under which the evaluation log directory is created.
      dir_name: Name of the sub-directory of `workdir` that receives the
        evaluation metrics.
      ckpt_name: File name of the checkpoint inside `checkpoints-0`. When
        `None` (the default) no path is passed to `restore_dict`, which then
        restores the latest checkpoint.
    """
    checkpoint_dir = os.path.join(workdir, 'checkpoints-0')
    ckpt = checkpoint.Checkpoint(checkpoint_dir)
    # `os.path.join(checkpoint_dir, None)` raises TypeError, so a path is only
    # built when a name was actually supplied.
    if ckpt_name is None:
        ckpt_path = None
    else:
        ckpt_path = os.path.join(checkpoint_dir, ckpt_name)
    state_dict = ckpt.restore_dict(ckpt_path)
    ema_params = flax.core.FrozenDict(state_dict['ema_params'])
    step = int(state_dict['step'])

    # Replicate the parameters before handing them to the eval epoch.
    ema_params = flax_utils.replicate(ema_params)

    eval_logdir = os.path.join(workdir, dir_name)
    tf.io.gfile.makedirs(eval_logdir)
    # Only process 0 writes real summaries; other processes just log.
    writer = metric_writers.create_default_writer(
        eval_logdir, just_logging=jax.process_index() > 0)

    outputs = self._eval_epoch(params=ema_params)
    outputs = flax_utils.unreplicate(outputs)
    scalars, images = outputs['scalars'], outputs['images']
    writer.write_scalars(step, scalars)
    writer.write_images(step, images)
    # Make sure the metrics are written out before returning.
    writer.flush()
def evaluate(base_dir, config, *, train_state):
    """Eval function.

    Fits a `LinearModule` with Adam for `config.num_eval_steps` steps via
    `evaluate_step`, then reports an L2 test loss on a batch drawn from
    `dataset.EvalDataset`. The linear model's state is checkpointed under
    `base_dir / 'eval'` and saved when SIGTERM is received.

    Args:
      base_dir: Base directory (supports the `/` operator); also used as the
        profiler log directory.
      config: Experiment configuration; the `eval`, `encoder`,
        `num_eval_steps` and `log_metrics_every` entries are read here.
      train_state: State of the trained model. Its `apply_fn` and `params`
        are applied to the test states, and the `phi` attribute of the output
        is fed to the linear model.
    """
    chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'eval'))
    writer = create_default_writer()

    model_init_key, ds_key = jax.random.split(
        jax.random.PRNGKey(config.eval.seed))

    # Linear model with one output per evaluation task.
    linear_module = LinearModule(config.eval.num_tasks)
    dummy_input = jnp.zeros((config.encoder.embedding_dim, ))
    params = linear_module.init(model_init_key, dummy_input)

    lr_schedule = optax.cosine_decay_schedule(config.eval.learning_rate,
                                              config.num_eval_steps)
    optimizer = optax.adam(lr_schedule)

    ds = dataset.get_dataset(config, ds_key, num_tasks=config.eval.num_tasks)
    ds_iter = iter(ds)

    state = TrainState.create(
        apply_fn=linear_module.apply, params=params, tx=optimizer)
    state = chkpt_manager.restore_or_initialize(state)

    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_eval_steps, writer=writer)
    profile_hook = periodic_actions.Profile(
        num_profile_steps=5, logdir=str(base_dir))
    hooks = [report_progress, profile_hook]

    def handle_preemption(signal_number, _):
        # `state` is looked up when the signal fires, so the most recent
        # state is the one that gets saved.
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        with report_progress.timed('checkpointing'):
            chkpt_manager.save(state)
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, handle_preemption)

    metrics = EvalMetrics.empty()
    with metric_writers.ensure_flushes(writer):
        remaining_steps = range(state.step, config.num_eval_steps)
        for step in tqdm.tqdm(remaining_steps):
            with jax.profiler.StepTraceAnnotation('eval', step_num=step):
                batch_states, batch_targets = next(ds_iter)
                state, metrics = evaluate_step(
                    train_state, state, metrics, batch_states, batch_targets)

            # Report the accumulated metrics and start a new accumulation.
            if step % config.log_metrics_every == 0:
                writer.write_scalars(step, metrics.compute())
                metrics = EvalMetrics.empty()

            for hook in hooks:
                hook(step)

        # Finally, evaluate on the true(ish) test aux task matrix.
        test_states, test_targets = dataset.EvalDataset(
            config, ds_key).get_batch()

        @jax.jit
        def loss_fn():
            outputs = train_state.apply_fn(train_state.params, test_states)
            phis = outputs.phi
            apply_per_phi = jax.vmap(state.apply_fn, in_axes=(None, 0))
            predictions = apply_per_phi(state.params, phis)
            return jnp.mean(optax.l2_loss(predictions, test_targets))

        test_loss = loss_fn()
        writer.write_scalars(config.num_eval_steps + 1,
                             {'test_loss': test_loss})
def test_fails_if_not_registered(self):
    """Checks that a state object that cannot be serialized is rejected."""
    workdir = tempfile.mkdtemp()
    manager = checkpoint.Checkpoint(workdir)
    unsupported_state = NotTrainState()

    # The error message is expected to mention "serialize".
    with self.assertRaisesRegex(TypeError, r"serialize"):
        manager.restore_or_initialize(unsupported_state)
def test_checkpoint_name(self):
    """Checks that a custom `checkpoint_name` shows up in the saved path."""
    workdir = tempfile.mkdtemp()
    manager = checkpoint.Checkpoint(workdir, checkpoint_name="test")

    saved_path = manager.save(TrainState(step=1))
    self.assertIn("test", saved_path)
def train_and_evaluate(config, workdir):
    """Execute model training and evaluation loop.

    Training resumes from the latest checkpoint under `<workdir>/checkpoints`
    when one exists. The loop stops early once the training privacy
    accountant reports an epsilon at or above the configured maximum.

    Args:
      config: Hyperparameter configuration for training and evaluation.
      workdir: Directory where the tensorboard summaries are written to.
        Checkpoints and profiler traces are stored under it as well.

    Returns:
      The train state (which includes the `.params`).
    """
    # Seed for reproducibility.
    rng = jax.random.PRNGKey(config.rng_seed)

    # Set up logging.
    summary_writer = metric_writers.create_default_writer(workdir)
    summary_writer.write_hparams(dict(config))

    # Get datasets.
    # `jax.tree_util.tree_map` is used instead of the deprecated alias
    # `jax.tree_map`, which has been removed from recent JAX releases.
    rng, dataset_rng = jax.random.split(rng)
    dataset = input_pipeline.get_dataset(config, dataset_rng)
    graph, labels, masks = jax.tree_util.tree_map(jnp.asarray, dataset)
    labels = jax.nn.one_hot(labels, config.num_classes)
    train_mask = masks['train']
    train_indices = jnp.where(train_mask)[0]
    train_labels = labels[train_indices]
    num_training_nodes = len(train_indices)

    # Get subgraphs.
    if config.differentially_private_training:
        # Subgraphs are computed on NumPy arrays; the graph is converted back
        # to JAX arrays afterwards.
        graph = jax.tree_util.tree_map(np.asarray, graph)
        subgraphs = get_subgraphs(graph, pad_to=config.pad_subgraphs_to)
        graph = jax.tree_util.tree_map(jnp.asarray, graph)

        # We only need the subgraphs for training nodes.
        train_subgraphs = subgraphs[train_indices]
        del subgraphs
    else:
        train_subgraphs = None

    # Initialize privacy accountant.
    training_privacy_accountant = (
        privacy_accountants.get_training_privacy_accountant(
            config, num_training_nodes, compute_max_terms_per_node(config)))

    # Construct and initialize model.
    rng, init_rng = jax.random.split(rng)
    estimation_indices = get_estimation_indices(train_indices, config)
    state = create_train_state(init_rng, config, graph, train_labels,
                               train_subgraphs, estimation_indices)

    # Set up checkpointing of the model.
    checkpoint_dir = os.path.join(workdir, 'checkpoints')
    ckpt = checkpoint.Checkpoint(checkpoint_dir, max_to_keep=2)
    state = ckpt.restore_or_initialize(state)
    initial_step = int(state.step) + 1

    # Log overview of parameters.
    parameter_overview.log_parameter_overview(state.params)

    # Log metrics after initialization.
    logits = compute_logits(state, graph)
    metrics_after_init = compute_metrics(logits, labels, masks)
    metrics_after_init['epsilon'] = 0
    log_metrics(0, metrics_after_init, summary_writer, postfix='_after_init')

    # Train model.
    rng, train_rng = jax.random.split(rng)
    max_training_epsilon = get_max_training_epsilon(config)

    # Hooks called periodically during training.
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_training_steps, writer=summary_writer)
    profiler = periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
    hooks = [report_progress, profiler]

    for step in range(initial_step, config.num_training_steps):

        # Perform one step of training.
        with jax.profiler.StepTraceAnnotation('train', step_num=step):
            # Sample batch. The RNG is derived from the step, so a resumed
            # run samples the same indices for a given step.
            step_rng = jax.random.fold_in(train_rng, step)
            indices = jax.random.choice(step_rng, num_training_nodes,
                                        (config.batch_size, ))

            # Compute gradients.
            if config.differentially_private_training:
                grads = compute_updates_for_dp(state, graph, train_labels,
                                               train_subgraphs, indices,
                                               config.adjacency_normalization)
            else:
                grads = compute_updates(state, graph, train_labels, indices)

            # Update parameters.
            state = update_model(state, grads)

        # Quick indication that training is happening.
        logging.log_first_n(logging.INFO, 'Finished training step %d.', 10,
                            step)
        for hook in hooks:
            hook(step)

        # Evaluate, if required.
        is_last_step = (step == config.num_training_steps - 1)
        if step % config.evaluate_every_steps == 0 or is_last_step:
            with report_progress.timed('eval'):
                # Check if privacy budget exhausted.
                training_epsilon = training_privacy_accountant(step + 1)
                if (max_training_epsilon is not None and
                        training_epsilon >= max_training_epsilon):
                    break

                # Compute metrics.
                logits = compute_logits(state, graph)
                metrics_during_training = compute_metrics(
                    logits, labels, masks)
                metrics_during_training['epsilon'] = training_epsilon
                log_metrics(step, metrics_during_training, summary_writer)

        # Checkpoint, if required.
        if step % config.checkpoint_every_steps == 0 or is_last_step:
            with report_progress.timed('checkpoint'):
                ckpt.save(state)

    return state
def train(*,
          workdir,
          compute_phi,
          compute_psi,
          params,
          optimal_subspace,
          num_epochs,
          learning_rate,
          key,
          method,
          lissa_kappa,
          optimizer,
          covariance_batch_size,
          main_batch_size,
          weight_batch_size,
          d,
          num_tasks,
          compute_feature_norm_on_oracle_states,
          sample_states,
          eval_states,
          use_tabular_gradient=True):
    """Training function.

    Runs gradient steps up to `num_epochs`, resuming from the latest
    checkpoint when one exists, and periodically logs metrics and saves
    checkpoints.

    For lissa, the total number of samples is
    2 x covariance_batch_size + main_batch_size + 2 x weight_batch_size.

    Args:
      workdir: Work directory, where we'll save logs.
      compute_phi: A function that takes params and states and returns a
        matrix of phis.
      compute_psi: A function that takes an array of states and an array of
        tasks and returns Psi[states, tasks].
      params: Parameters used as the first argument for compute_phi. When
        `method` is 'explicit', an 'explicit_weight_matrix' entry is added to
        it in place.
      optimal_subspace: Top-d left singular vectors of Psi.
      num_epochs: How many gradient steps to perform. (Not really epochs)
      learning_rate: The step size parameter for the optimizer.
      key: The jax prng key.
      method: Name of the training method, e.g. 'naive', 'lissa' or 'oracle'.
        'explicit' additionally creates an explicit weight matrix.
      lissa_kappa: The parameter of the lissa method, if used.
      optimizer: Which optimizer to use: 'sgd' or 'adam'.
      covariance_batch_size: the 'J' parameter. For the naive method, this is
        how many states we sample to construct the inverse. For the lissa
        method, ditto -- these are also "iterations".
      main_batch_size: How many states to update at once.
      weight_batch_size: How many states to construct the weight vector.
      d: The dimension of the representation.
      num_tasks: The total number of tasks.
      compute_feature_norm_on_oracle_states: If True, computes the feature
        norm using the oracle states (all the states in synthetic
        experiments). Otherwise, computes the norm using the sampled batch.
        Only applies to LISSA.
      sample_states: A function that takes an rng key and a number of states
        to sample, and returns a tuple containing (a vector of sampled
        states, an updated rng key).
      eval_states: An array of states to use to compute metrics on. This will
        be used to compute Phi = compute_phi(params, eval_states).
      use_tabular_gradient: If true, the train step will calculate the
        gradient using the tabular calculation. Otherwise, it will use a
        jax.vjp to backpropagate the gradient.

    Raises:
      ValueError: If `optimizer` is neither 'sgd' nor 'adam'.
    """
    # Create an explicit weight vector (needed for explicit method only).
    if method == 'explicit':
        key, weight_key = jax.random.split(key)
        params['explicit_weight_matrix'] = jax.random.normal(
            weight_key, (d, num_tasks), dtype=jnp.float32)

    # Turn the optimizer name into an optax gradient transformation.
    if optimizer == 'sgd':
        tx = optax.sgd(learning_rate)
    elif optimizer == 'adam':
        tx = optax.adam(learning_rate)
    else:
        raise ValueError(f'Unknown optimizer {optimizer}.')
    optimizer_state = tx.init(params)

    # NOTE(review): checkpoints go to the `_WORKDIR` flag value rather than
    # the `workdir` argument -- confirm the two always agree.
    chkpt_manager = checkpoint.Checkpoint(base_directory=_WORKDIR.value)
    restored = chkpt_manager.restore_or_initialize(
        (0, params, optimizer_state))
    initial_step, params, optimizer_state = restored

    writer = metric_writers.create_default_writer(logdir=str(workdir))

    # Checkpointing and logging too much can use a lot of disk space.
    # Therefore, we don't want to checkpoint more than 10 times an experiment,
    # or keep more than 1k Phis per experiment.
    checkpoint_period = max(num_epochs // 10, 100_000)
    log_period = max(1_000, num_epochs // 1_000)

    def _checkpoint_callback(step, t, params, optimizer_state):
        del t  # Unused.
        chkpt_manager.save((step, params, optimizer_state))

    hooks = [
        periodic_actions.PeriodicCallback(
            every_steps=checkpoint_period, callback_fn=_checkpoint_callback),
    ]

    # Arguments of the train step that stay the same for every step.
    fixed_train_kwargs = dict(
        compute_phi=compute_phi,
        compute_psi=compute_psi,
        optimizer=tx,
        method=method,
        # In the tabular case, the eval_states are all the states.
        oracle_states=eval_states,
        lissa_kappa=lissa_kappa,
        main_batch_size=main_batch_size,
        covariance_batch_size=covariance_batch_size,
        weight_batch_size=weight_batch_size,
        d=d,
        num_tasks=num_tasks,
        compute_feature_norm_on_oracle_states=(
            compute_feature_norm_on_oracle_states),
        sample_states=sample_states,
        use_tabular_gradient=use_tabular_gradient,
    )
    # Arguments that each train step consumes and returns updated.
    variable_kwargs = dict(
        params=params,
        optimizer_state=optimizer_state,
        key=key,
    )

    @jax.jit
    def _eval_step(phi_params):
        eval_phi = compute_phi(phi_params, eval_states)
        eval_psi = compute_psi(eval_states)  # pytype: disable=wrong-arg-count

        metrics = compute_metrics(eval_phi, optimal_subspace)
        metrics |= {'frob_norm': utils.outer_objective_mc(eval_phi, eval_psi)}
        return metrics

    # Perform num_epochs gradient steps.
    with metric_writers.ensure_flushes(writer):
        step_range = range(initial_step + 1, num_epochs + 1)
        for step in etqdm.tqdm(
                step_range, initial=initial_step, total=num_epochs):
            variable_kwargs = _train_step(**fixed_train_kwargs,
                                          **variable_kwargs)

            if step % log_period == 0:
                current_params = variable_kwargs['params']
                metrics = _eval_step(current_params['phi_params'])
                writer.write_scalars(step, metrics)

            for hook in hooks:
                hook(step,
                     params=variable_kwargs['params'],
                     optimizer_state=variable_kwargs['optimizer_state'])

    writer.flush()
def train(base_dir, config):
    """Train function.

    Builds the dataset, the module named in `config.train.module` and the
    matching `train_step_<method>` function, then trains for
    `config.num_train_steps` steps. The state is checkpointed under
    `base_dir / 'train'`: on SIGTERM and once after the loop.

    Args:
      base_dir: Base directory (supports the `/` operator); also used as the
        profiler log directory.
      config: Experiment configuration.

    Returns:
      The train state after the last step.
    """
    print(config)
    chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'train'))
    writer = create_default_writer()

    # Initialize dataset.
    key = jax.random.PRNGKey(config.seed)
    key, dataset_key = jax.random.split(key)
    ds = dataset.get_dataset(config, dataset_key, num_tasks=config.num_tasks)
    ds_iter = iter(ds)

    key, init_key = jax.random.split(key)
    encoder = MLPEncoder(**config.encoder)

    # The train config names the step function ('method') and the module
    # class ('module.name'); whatever remains in each dict is forwarded as
    # keyword arguments to the step function and the module respectively.
    train_config = config.train.to_dict()
    train_method = train_config.pop('method')
    module_config = train_config.pop('module')
    module_class = module_config.pop('name')

    module_ctor = globals().get(module_class)
    module = module_ctor(encoder, **module_config)
    step_fn = globals().get(f'train_step_{train_method}')
    train_step = functools.partial(step_fn, **train_config)

    # The first batch's states serve as the example input for init.
    params = module.init(init_key, next(ds_iter)[0])

    lr_schedule = optax.cosine_decay_schedule(config.learning_rate,
                                              config.num_train_steps)
    # Only Adam is chained; optax.adaptive_grad_clip(0.15) is disabled.
    optimizer = optax.chain(optax.adam(lr_schedule))

    state = TrainState.create(
        apply_fn=module.apply, params=params, tx=optimizer)
    state = chkpt_manager.restore_or_initialize(state)

    # Hooks.
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    profile_hook = periodic_actions.Profile(
        num_profile_steps=5, logdir=str(base_dir))
    hooks = [report_progress, profile_hook]

    def handle_preemption(signal_number, _):
        # `state` is looked up when the signal fires, so the most recent
        # state is the one that gets saved.
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        with report_progress.timed('checkpointing'):
            chkpt_manager.save(state)
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, handle_preemption)

    metrics = TrainMetrics.empty()
    with metric_writers.ensure_flushes(writer):
        remaining_steps = range(state.step, config.num_train_steps)
        for step in tqdm.tqdm(remaining_steps):
            with jax.profiler.StepTraceAnnotation('train', step_num=step):
                batch_states, batch_targets = next(ds_iter)
                state, metrics = train_step(state, metrics, batch_states,
                                            batch_targets)

            logging.log_first_n(logging.INFO, 'Finished training step %d', 5,
                                step)

            # Report the accumulated metrics and start a new accumulation.
            if step % config.log_metrics_every == 0:
                writer.write_scalars(step, metrics.compute())
                metrics = TrainMetrics.empty()

            # NOTE: periodic MDP evaluation (evaluate_mdp on
            # ds.aux_task_matrix every config.log_eval_metrics_every steps)
            # is disabled here.

            for hook in hooks:
                hook(step)

    chkpt_manager.save(state)
    return state