Example #1
    def testAppendPytree(self):
        """Test appending and loading pytrees."""
        pytrees = [{'a': i} for i in range(10)]
        pytree_path = os.path.join(self.test_dir, 'pytree.ckpt')
        logger = utils.MetricLogger(pytree_path=pytree_path)

        for pytree in pytrees:
            logger.append_pytree(pytree)

        latest = checkpoint.load_latest_checkpoint(pytree_path)
        saved_pytrees = latest.pytree if latest else []
        self.assertEqual(pytrees, saved_pytrees)
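Below is a minimal usage sketch of the same API outside a test class. It assumes only the utils module and the MetricLogger behavior shown in the test above; treat everything else as illustrative.

import os
import tempfile

# Hedged sketch: append one metrics pytree per step; each call persists the
# growing list of pytrees under pytree_path (assumes utils.MetricLogger
# behaves as in the test above).
log_dir = tempfile.mkdtemp()
logger = utils.MetricLogger(pytree_path=os.path.join(log_dir, 'metrics.ckpt'))
for step in range(3):
    logger.append_pytree({'step': step, 'train_loss': 1.0 / (step + 1)})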
Example #2
def set_up_loggers(train_dir,
                   xm_work_unit=None,
                   use_deprecated_checkpointing=True):
    """Creates a logger for eval metrics as well as initialization metrics."""
    csv_path = os.path.join(train_dir, 'measurements.csv')
    pytree_path = os.path.join(train_dir, 'training_metrics')
    metrics_logger = utils.MetricLogger(
        csv_path=csv_path,
        pytree_path=pytree_path,
        xm_work_unit=xm_work_unit,
        events_dir=train_dir,
        use_deprecated_checkpointing=use_deprecated_checkpointing)

    init_csv_path = os.path.join(train_dir, 'init_measurements.csv')
    init_json_path = os.path.join(train_dir, 'init_scalars.json')
    init_logger = utils.MetricLogger(
        csv_path=init_csv_path,
        json_path=init_json_path,
        xm_work_unit=xm_work_unit,
        use_deprecated_checkpointing=use_deprecated_checkpointing)
    return metrics_logger, init_logger
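A hedged usage sketch of wiring these loggers to a run directory. append_scalar_metrics is assumed to be the MetricLogger method that writes a CSV row; treat the method name and the metric keys as assumptions, not something established by the example above.

import tempfile

# Hypothetical run directory; both loggers write their files under it.
train_dir = tempfile.mkdtemp()
metrics_logger, init_logger = set_up_loggers(train_dir)
# Assumed method name: append_scalar_metrics appends one row to the CSV.
metrics_logger.append_scalar_metrics({'global_step': 0, 'valid/ce_loss': 2.3})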
Example #3
    def __init__(self, model, optimizer, batch_stats, optimizer_state, dataset,
                 hps, callback_config, train_dir, rng):
        del hps
        del rng
        del optimizer
        del batch_stats
        del optimizer_state
        del callback_config  # In future CL's we will use this.
        checkpoint_dir = os.path.join(train_dir, 'checkpoints')
        pytree_path = os.path.join(checkpoint_dir, 'debugger')
        logger = utils.MetricLogger(pytree_path=pytree_path)

        get_act_stats_fn = model_debugger.create_forward_pass_stats_fn(
            model.flax_module, capture_activation_norms=True)

        debugger = model_debugger.ModelDebugger(use_pmap=True,
                                                forward_pass=get_act_stats_fn,
                                                metrics_logger=logger)
        # pmap get_stats for the training loop.
        # in_axes=(0, 0, 0, None, None, 0): optimizer, batch_stats, and batch
        # are sharded across devices; step and rng are broadcast; the
        # local_device_index is sharded so each device sees its own index.
        # The 'optimizer', 'batch_stats', and 'batch' buffers could also be
        # donated to the pmapped computation.

        self.get_stats_pmapped = jax.pmap(
            functools.partial(get_stats, training_cost=model.training_cost),
            axis_name='batch',
            in_axes=(0, 0, 0, None, None, 0))
        self.debugger = debugger
        self.logger = logger
        self.dataset = dataset

        batch = next(dataset.train_iterator_fn())
        self.batch = data_utils.shard(batch)
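Because the in_axes mapping above is easy to misread, here is a self-contained sketch (independent of get_stats) showing how pmap treats the two kinds of entries: 0 splits an argument along its leading axis across devices, while None broadcasts the same value to every device.

import jax
import jax.numpy as jnp

n = jax.local_device_count()

def scale(x, c):
    # x arrives sharded, one row per device; c is identical on all devices.
    return x * c

scale_pmapped = jax.pmap(scale, axis_name='batch', in_axes=(0, None))
x = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)
print(scale_pmapped(x, 2.0))  # each device's row of x, doubled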
Example #4
def set_up_hessian_eval(model, flax_module, batch_stats, dataset,
                        checkpoint_dir, hessian_eval_config):
    """Builds the CurvatureEvaluator object."""

    # First copy, then unreplicate batch_stats. batch_stats does not affect the
    # forward pass in the hessian eval because we always run the model in train
    # mode; however, the model.training_cost API still requires it.
    # The copy is needed because the trainer will modify the underlying arrays.
    batch_stats = jax.tree_map(lambda x: x[:][0], batch_stats)

    def batch_loss(module, batch_rng):
        batch, rng = batch_rng
        return model.training_cost(module,
                                   batch,
                                   batch_stats=batch_stats,
                                   dropout_rng=rng)[0]

    pytree_path = os.path.join(checkpoint_dir, hessian_eval_config['name'])
    logger = utils.MetricLogger(pytree_path=pytree_path)
    hessian_evaluator = hessian_eval.CurvatureEvaluator(flax_module,
                                                        hessian_eval_config,
                                                        dataset=dataset,
                                                        loss=batch_loss)
    return hessian_evaluator, logger
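A hedged sketch of a call site for the returned pair. The evaluate_spectrum method name, its arguments, and the shape of its result are assumptions for illustration; only set_up_hessian_eval's signature comes from the example above.

# Hypothetical call site: run the curvature evaluation at one step and
# persist whatever pytree of statistics it returns (names assumed).
hessian_evaluator, logger = set_up_hessian_eval(
    model, flax_module, batch_stats, dataset, checkpoint_dir,
    {'name': 'hessian'})  # other required config keys elided
curvature_stats = hessian_evaluator.evaluate_spectrum(flax_module, step=0)
logger.append_pytree(curvature_stats)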