def testAppendPytree(self):
  """Test appending and loading pytrees."""
  pytrees = [{'a': i} for i in range(10)]
  pytree_path = os.path.join(self.test_dir, 'pytree.ckpt')
  logger = utils.MetricLogger(pytree_path=pytree_path)
  for pytree in pytrees:
    logger.append_pytree(pytree)

  latest = checkpoint.load_latest_checkpoint(pytree_path)
  saved_pytrees = latest.pytree if latest else []
  self.assertEqual(pytrees, saved_pytrees)
def set_up_loggers(train_dir,
                   xm_work_unit=None,
                   use_deprecated_checkpointing=True):
  """Creates a logger for eval metrics as well as initialization metrics."""
  csv_path = os.path.join(train_dir, 'measurements.csv')
  pytree_path = os.path.join(train_dir, 'training_metrics')
  metrics_logger = utils.MetricLogger(
      csv_path=csv_path,
      pytree_path=pytree_path,
      xm_work_unit=xm_work_unit,
      events_dir=train_dir,
      use_deprecated_checkpointing=use_deprecated_checkpointing)

  init_csv_path = os.path.join(train_dir, 'init_measurements.csv')
  init_json_path = os.path.join(train_dir, 'init_scalars.json')
  init_logger = utils.MetricLogger(
      csv_path=init_csv_path,
      json_path=init_json_path,
      xm_work_unit=xm_work_unit,
      use_deprecated_checkpointing=use_deprecated_checkpointing)
  return metrics_logger, init_logger
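# Usage sketch, not part of the library: shows how the two loggers returned by
# set_up_loggers might be exercised. Only append_pytree is demonstrated
# elsewhere in this file; the directory and metric names below are made up.
def _example_set_up_loggers_usage():
  """Creates both loggers in a temp dir and appends a few eval metrics."""
  import tempfile

  train_dir = tempfile.mkdtemp()
  metrics_logger, init_logger = set_up_loggers(train_dir)

  # metrics_logger was built with a pytree_path, so per-eval metrics can be
  # appended the same way the checkpoint test above appends pytrees.
  for step in (100, 200):
    metrics_logger.append_pytree({'global_step': step, 'valid/ce_loss': 0.5})

  # init_logger only has csv/json paths; it is meant for one-time
  # initialization measurements via MetricLogger's scalar/JSON methods
  # (not shown here).
  del init_logger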
def __init__(self, model, optimizer, batch_stats, optimizer_state, dataset,
             hps, callback_config, train_dir, rng):
  del hps
  del rng
  del optimizer
  del batch_stats
  del optimizer_state
  del callback_config  # In future CLs we will use this.

  checkpoint_dir = os.path.join(train_dir, 'checkpoints')
  pytree_path = os.path.join(checkpoint_dir, 'debugger')
  logger = utils.MetricLogger(pytree_path=pytree_path)
  get_act_stats_fn = model_debugger.create_forward_pass_stats_fn(
      model.flax_module, capture_activation_norms=True)
  debugger = model_debugger.ModelDebugger(
      use_pmap=True,
      forward_pass=get_act_stats_fn,
      metrics_logger=logger)

  # pmap functions for the training loop.
  # in_axes = (optimizer=0, batch_stats=0, batch=0, step=None, lr=None,
  #            rng=None, local_device_index=0, training_metrics_grabber=0)
  # Also, we can donate buffers for 'optimizer', 'batch_stats', 'batch' and
  # 'training_metrics_grabber' for update's pmapped computation.
  self.get_stats_pmapped = jax.pmap(
      functools.partial(
          get_stats,
          training_cost=model.training_cost,
      ),
      axis_name='batch',
      in_axes=(0, 0, 0, None, None, 0))

  self.debugger = debugger
  self.logger = logger
  self.dataset = dataset
  batch = next(dataset.train_iterator_fn())
  self.batch = data_utils.shard(batch)
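# Self-contained illustration of the shard-then-pmap pattern used in __init__
# above. `toy_stats` stands in for `get_stats`; its signature and the metric
# it returns are made up, but the in_axes mix of sharded (0) and broadcast
# (None) arguments, and the functools.partial closure, mirror the real call.
def _example_shard_and_pmap():
  """Runs a toy pmapped stats function over a manually sharded batch."""
  import functools

  import jax
  import jax.numpy as jnp
  import numpy as np

  def toy_stats(batch, step, *, scale):
    # Per-device mean of the already-sharded batch; `step` is broadcast.
    return {'step': step, 'mean': scale * jnp.mean(batch['inputs'])}

  # Close over a constant via functools.partial, like training_cost above.
  toy_stats_pmapped = jax.pmap(
      functools.partial(toy_stats, scale=2.0),
      axis_name='batch',
      in_axes=(0, None))  # batch is sharded, step is broadcast.

  n_devices = jax.local_device_count()
  batch = {'inputs': np.ones((n_devices * 8, 32), dtype=np.float32)}
  # Reshape the leading dim to (n_devices, per_device_batch), which is what
  # data_utils.shard does for the real batch.
  sharded = jax.tree_map(
      lambda x: x.reshape((n_devices, -1) + x.shape[1:]), batch)
  stats = toy_stats_pmapped(sharded, 0)
  # Each output leaf carries a leading device axis of size n_devices.
  assert stats['mean'].shape == (n_devices,)
  return stats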
def set_up_hessian_eval(model, flax_module, batch_stats, dataset,
                        checkpoint_dir, hessian_eval_config):
  """Builds the CurvatureEvaluator object."""
  # First copy, then unreplicate batch_stats. Note batch_stats doesn't affect
  # the forward pass in the hessian eval because we always run the model in
  # training mode. However, we need to provide batch_stats for the
  # model.training_cost API. The copy is needed b/c the trainer will modify
  # the underlying arrays.
  batch_stats = jax.tree_map(lambda x: x[:][0], batch_stats)

  def batch_loss(module, batch_rng):
    batch, rng = batch_rng
    return model.training_cost(
        module, batch, batch_stats=batch_stats, dropout_rng=rng)[0]

  pytree_path = os.path.join(checkpoint_dir, hessian_eval_config['name'])
  logger = utils.MetricLogger(pytree_path=pytree_path)
  hessian_evaluator = hessian_eval.CurvatureEvaluator(
      flax_module,
      hessian_eval_config,
      dataset=dataset,
      loss=batch_loss)
  return hessian_evaluator, logger
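# Standalone sketch of the copy-then-unreplicate step above, using a fake
# replicated batch_stats pytree (the key names are made up). As the comment
# above notes, x[:] makes a copy so the trainer's later updates can't alias
# these arrays, and [0] drops the leading device axis, since every replica
# holds the same values.
def _example_unreplicate_batch_stats():
  """Copies and unreplicates a fake replicated pytree of batch statistics."""
  import jax
  import jax.numpy as jnp

  n_devices = jax.local_device_count()
  replicated = {
      'mean': jnp.ones((n_devices, 4)),
      'var': jnp.ones((n_devices, 4)),
  }
  unreplicated = jax.tree_map(lambda x: x[:][0], replicated)
  assert unreplicated['mean'].shape == (4,)
  return unreplicated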