Exemplo n.º 1
0
def _get_temp_gcs_path():
    """Yield a file path inside a freshly created, randomly named GCS directory.

    A directory with a random 16-letter lowercase name is created under
    ``gs://{GCS_TEST_BUCKET}/``. The generator yields ``<dir>/file.name`` and
    removes the whole directory when the consumer is finished. The removal
    is in a ``finally`` clause so it also runs when an exception is thrown
    into the generator at the ``yield`` or when the generator is closed
    early; otherwise the temporary directory would be left in the bucket.

    Yields:
        str: path of a file name inside the temporary directory. The file
        itself is not created here.
    """
    dir_name = "".join(
        random.choice(string.ascii_lowercase) for _ in range(16)
    )
    path = f"gs://{GCS_TEST_BUCKET}/" + dir_name
    gfile.mkdir(path)
    try:
        yield path + "/file.name"
    finally:
        # Always clean up, even if the consumer failed.
        gfile.rmtree(path)
 def __init__(self, logdir, flush_secs=2, is_dummy=False, dummy_time=None):
     """Initialise logger state and, unless in dummy mode, open the event file.

     Args:
         logdir: directory in which the event file is created.
         flush_secs: stored on the instance; not otherwise used here.
         is_dummy: when true, nothing is written; ``dummy_log`` (a
             ``defaultdict(list)``) is created instead and ``_writer``
             stays ``None``.
         dummy_time: stored as ``_dummy_time``; presumably consulted by
             ``self._time()`` -- confirm against that method.
     """
     self.logdir = logdir
     self.flush_secs = flush_secs  # TODO
     self.is_dummy = is_dummy
     self._dummy_time = dummy_time
     self._name_to_tf_name = {}
     self._tf_names = set()
     self._writer = None

     if is_dummy:
         self.dummy_log = defaultdict(list)
         return

     gfile.mkdir(self.logdir)
     host = socket.gethostname()
     # Event file name: fixed prefix, integer timestamp, then host name.
     event_file_name = 'events.out.tfevents.{}.{}'.format(
         int(self._time()), host)
     event_file_path = os.path.join(self.logdir, event_file_name)
     self._writer = gfile.GFile(event_file_path, 'wb')

     # The first record written carries the file version marker.
     version_event = event_pb2.Event(
         wall_time=self._time(),
         step=0,
         file_version='brain.Event:2')
     self._write_event(version_event)
Exemplo n.º 3
0
    def _make_root_class_dirs(self):
        """Create the dataset root directory and one subdirectory per class.

        Directories that already exist are left untouched.
        """
        root_dir = self.new_dataset_dir
        if not gfile.exists(root_dir):
            gfile.makedirs(root_dir)

        for fg_class in self.fg_classes:
            # The empty last component is joined on as well (with os.path
            # this yields a trailing separator -- assumes `path` is os.path).
            target_dir = path.join(root_dir, fg_class, '')
            if gfile.exists(target_dir):
                continue
            gfile.mkdir(target_dir)
Exemplo n.º 4
0
def _make_directory(path):
  """Create `path` via gfile, logging the creation, unless it already exists."""
  if gfile.exists(path):
    return
  log("Creating directory %s" % path)
  gfile.mkdir(path)
Exemplo n.º 5
0
def mkdir(path: str) -> None:
    """Create the directory at ``path`` by delegating to ``gfile.mkdir``."""
    gfile.mkdir(path)
Exemplo n.º 6
0
def eval_checkpoints(
    checkpoint_dir,
    hps,
    rng,
    eval_num_batches,
    model_cls,
    dataset_builder,
    dataset_meta_data,
    hessian_eval_config,
    min_global_step=None,
    max_global_step=None,
):
  """Evaluate the Hessian of the given checkpoints.

  Iterates over all checkpoints in the specified directory, loads the checkpoint
  then evaluates the Hessian on the given checkpoint. A list of dicts will be
  saved to cns at checkpoint_dir/hessian_eval_config['name'].

  For every checkpoint the loop also runs trainer.eval_metrics and logs its
  report on process 0. One row of results per checkpoint is appended through
  utils.MetricLogger, on process 0 only.

  Args:
    checkpoint_dir: Directory of checkpoints to load.
    hps: (tf.HParams) Model, initialization and training hparams.
    rng: (jax.random.PRNGKey) Rng seed used in model initialization and data
      shuffling.
    eval_num_batches: (int) Forwarded twice to trainer.eval_metrics.
      NOTE(review): the name suggests a number of batches, although this was
      previously documented as a batch size for the validation and test sets,
      with None meaning the whole test set -- confirm against
      trainer.eval_metrics.
    model_cls: One of the model classes (not an instance) defined in model_lib.
    dataset_builder: dataset builder returned by datasets.get_dataset.
    dataset_meta_data: dict of meta_data about the dataset.
    hessian_eval_config: a dict specifying the configuration of the Hessian
      eval. The keys 'name', 'compute_stats' and 'compute_interps' are read
      here; the whole dict is also passed to hessian_eval.CurvatureEvaluator.
    min_global_step: Lower bound on what steps to filter checkpoints. Set to
      None to evaluate all checkpoints in the directory. hparams.json and
      hconfig.json are written only when this is exactly 0.
    max_global_step: Upper bound on what steps to filter checkpoints.
  """
  # Derive separate keys for initialization and for the data pipeline; the
  # process index is folded in after init_rng has been split off.
  rng, init_rng = jax.random.split(rng)
  rng = jax.random.fold_in(rng, jax.process_index())
  rng, data_rng = jax.random.split(rng)

  initializer = initializers.get_initializer('noop')

  # Loss and metrics are fixed here rather than read from hps.
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)

  # Maybe run the initializer.
  unreplicated_params, unreplicated_batch_stats = init_utils.initialize(
      model.flax_module,
      initializer, model.loss_fn,
      hps.input_shape,
      hps.output_shape, hps, init_rng,
      None)

  # Fold the unreplicated batch_stats and rng into the loss used by
  # hessian eval.
  def batch_loss(params, batch_rng):
    # batch_rng is a (batch, rng) pair; this local rng shadows the outer one.
    batch, rng = batch_rng
    # Only the first element of training_cost's result is returned.
    return model.training_cost(
        params, batch, batch_stats=unreplicated_batch_stats, dropout_rng=rng)[0]
  batch_stats = jax_utils.replicate(unreplicated_batch_stats)

  if jax.process_index() == 0:
    utils.log_pytree_shape_and_statistics(unreplicated_params)
    logging.info('train_size: %d,', hps.train_size)
    logging.info(hps)
    # Save the hessian computation hps to the experiment directory
    exp_dir = os.path.join(checkpoint_dir, hessian_eval_config['name'])
    if not gfile.exists(exp_dir):
      gfile.mkdir(exp_dir)
    # NOTE(review): the configs are written only when min_global_step is
    # exactly 0, so nothing is saved for the default None -- confirm intended.
    if min_global_step == 0:
      hparams_fname = os.path.join(exp_dir, 'hparams.json')
      with gfile.GFile(hparams_fname, 'w') as f:
        f.write(hps.to_json())
      config_fname = os.path.join(exp_dir, 'hconfig.json')
      with gfile.GFile(config_fname, 'w') as f:
        f.write(json.dumps(hessian_eval_config))

  optimizer_init_fn, optimizer_update_fn = optimizers.get_optimizer(hps)
  unreplicated_optimizer_state = optimizer_init_fn(unreplicated_params)
  # Note that we do not use the learning rate.
  # The optimizer state is a list of all the optax transformation states, and
  # we inject the learning rate into all states that will accept it.
  for state in unreplicated_optimizer_state:
    if (isinstance(state, optax.InjectHyperparamsState) and
        'learning_rate' in state.hyperparams):
      # A constant 1.0 is injected in place of any real learning rate.
      state.hyperparams['learning_rate'] = jax_utils.replicate(1.0)
  optimizer_state = jax_utils.replicate(unreplicated_optimizer_state)
  params = jax_utils.replicate(unreplicated_params)
  data_rng = jax.random.fold_in(data_rng, 0)

  # NOTE(review): assert is stripped under -O; this guards the divisibility of
  # the batch size by the device count.
  assert hps.batch_size % (jax.device_count()) == 0
  dataset = dataset_builder(
      data_rng,
      hps.batch_size,
      eval_batch_size=hps.batch_size,  # eval iterators not used.
      hps=hps,
  )

  # pmap functions for the training loop
  evaluate_batch_pmapped = jax.pmap(model.evaluate_batch, axis_name='batch')

  if jax.process_index() == 0:
    logging.info('Starting eval!')
    logging.info('Number of hosts: %d', jax.process_count())

  # The evaluator is built once, from the freshly initialized (replicated)
  # params; per-checkpoint params are passed to its methods inside the loop.
  hessian_evaluator = hessian_eval.CurvatureEvaluator(
      params,
      hessian_eval_config,
      dataset=dataset,
      loss=batch_loss)
  # Results go under the experiment directory, in a '<min>_<max>' subpath when
  # a step range is given (the suffix is empty otherwise).
  if min_global_step is None:
    suffix = ''
  else:
    suffix = '{}_{}'.format(min_global_step, max_global_step)
  pytree_path = os.path.join(checkpoint_dir, hessian_eval_config['name'],
                             suffix)
  logger = utils.MetricLogger(pytree_path=pytree_path)
  for checkpoint_path, step in iterate_checkpoints(checkpoint_dir,
                                                   min_global_step,
                                                   max_global_step):
    # Template passed as the restore target; the initialized values and zeroed
    # counters are presumably placeholders that load_checkpoint overwrites.
    unreplicated_checkpoint_state = dict(
        params=unreplicated_params,
        optimizer_state=unreplicated_optimizer_state,
        batch_stats=unreplicated_batch_stats,
        global_step=0,
        preemption_count=0,
        sum_train_cost=0.0)
    ckpt = checkpoint.load_checkpoint(
        checkpoint_path,
        target=unreplicated_checkpoint_state)
    results, _ = checkpoint.replicate_checkpoint(
        ckpt,
        pytree_keys=['params', 'optimizer_state', 'batch_stats'])
    # From here on these names refer to the restored checkpoint's values.
    params = results['params']
    optimizer_state = results['optimizer_state']
    batch_stats = results['batch_stats']
    # pylint: disable=protected-access
    batch_stats = trainer_utils.maybe_sync_batchnorm_stats(batch_stats)
    # pylint: enable=protected-access
    report, _ = trainer.eval_metrics(params, batch_stats, dataset,
                                     eval_num_batches, eval_num_batches,
                                     evaluate_batch_pmapped)
    if jax.process_index() == 0:
      logging.info('Global Step: %d', step)
      logging.info(report)
    # One result row is accumulated per checkpoint.
    row = {}
    # grads/updates stay empty unless stats or interpolations are requested.
    grads, updates = [], []
    # These defaults are overwritten by evaluate_spectrum on the next statement.
    hess_evecs, cov_evecs = [], []
    stats, hess_evecs, cov_evecs = hessian_evaluator.evaluate_spectrum(
        params, step)
    row.update(stats)
    if hessian_eval_config[
        'compute_stats'] or hessian_eval_config['compute_interps']:
      grads, updates = hessian_evaluator.compute_dirs(
          params, optimizer_state, optimizer_update_fn)
    # Both calls below run unconditionally; presumably they consult the config
    # themselves to decide what to compute -- confirm in CurvatureEvaluator.
    row.update(hessian_evaluator.evaluate_stats(params, grads,
                                                updates, hess_evecs,
                                                cov_evecs, step))
    row.update(hessian_evaluator.compute_interpolations(params, grads,
                                                        updates, hess_evecs,
                                                        cov_evecs, step))
    # Only process 0 persists the row.
    if jax.process_index() == 0:
      logger.append_pytree(row)