Example #1
    def evaluate(self, global_step, rng, **unused_args):
        """See base class."""
        global_step = np.array(jl_utils.get_first(global_step))
        scalars = jax.device_get(self._eval_epoch(jl_utils.get_first(rng)))

        logging.info('[Step %d] Eval scalars: %s', global_step, scalars)
        return scalars
Example #2
    def step(self, global_step, rng, *unused_args, **unused_kwargs):
        # Get next inputs.
        supervised_inputs = next(self.supervised_train_input)
        if self.extra_train_input is None:
            extra_inputs = None
        else:
            extra_inputs = next(self.extra_train_input)

        # Perform step.
        (self._params, self._avg_params, self._state, self._opt_state,
         scalars) = self.train_fn(params=self._params,
                                  avg_params=self._avg_params,
                                  state=self._state,
                                  opt_state=self._opt_state,
                                  global_step=global_step,
                                  supervised_inputs=supervised_inputs,
                                  extra_inputs=extra_inputs,
                                  rng=rng)
        scalars = jl_utils.get_first(scalars)

        # Save final checkpoint.
        if self.config.save_final_checkpoint_as_npy and not self.config.dry_run:
            global_step_value = jl_utils.get_first(global_step)
            if global_step_value == FLAGS.config.get('training_steps', 1) - 1:
                f_np = lambda x: np.array(
                    jax.device_get(jl_utils.get_first(x)))
                np_params = jax.tree_map(
                    f_np, self._avg_params or self._params)
                np_state = jax.tree_map(f_np, self._state)
                path_npy = os.path.join(FLAGS.config.checkpoint_dir,
                                        'checkpoint.npy')
                with tf.io.gfile.GFile(path_npy, 'wb') as fp:
                    np.save(fp, (np_params, np_state))
                logging.info('Saved final checkpoint at %s', path_npy)

        # Run synchronous evaluation.
        if self.config.evaluation.interval <= 0:
            return scalars

        global_step_value = jl_utils.get_first(global_step)
        if (global_step_value % self.config.evaluation.interval != 0
                and global_step_value !=
                FLAGS.config.get('training_steps', 1) - 1):
            return _merge_eval_scalars(scalars, self._last_evaluation_scalars)
        logging.info('Running synchronous evaluation...')
        eval_scalars = self.evaluate(global_step, rng)
        f_list = lambda x: x.tolist() if isinstance(x, jnp.ndarray) else x
        self._last_evaluation_scalars = jax.tree_map(f_list, eval_scalars)
        logging.info('(eval) global_step: %d, %s', global_step_value,
                     self._last_evaluation_scalars)
        return _merge_eval_scalars(scalars, self._last_evaluation_scalars)
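
The `_merge_eval_scalars` helper called above is not shown. A minimal sketch of
what such a merge might look like, assuming it simply folds the cached
evaluation scalars into the training scalars dict (the name is from the call
site; the body is an assumption, not the original implementation):

def _merge_eval_scalars(train_scalars, eval_scalars):
    # Hypothetical sketch: copy the training scalars and add the cached
    # evaluation scalars under an 'eval/' prefix so both appear in the logs.
    if eval_scalars is None:
        return train_scalars
    merged = dict(train_scalars)
    for name, value in eval_scalars.items():
        merged[f'eval/{name}'] = value
    return merged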
Example #3
    def step(
        self,
        global_step: jnp.ndarray,
        rng: jnp.ndarray,
        **unused_args,
    ) -> losses.LogsDict:
        """See Jaxline base class."""
        if not self._training:
            self._train_init()

        with jax.profiler.StepTraceAnnotation('next_train_input'):
            batch = next(self._train_input)

        with jax.profiler.StepTraceAnnotation('update_step'):
            (self._params, self._ema_params, self._network_state,
             self._ema_network_state, self._opt_state,
             stats) = self._update_func(
                 self._params,
                 self._ema_params,
                 self._network_state,
                 self._ema_network_state,
                 self._opt_state,
                 global_step,
                 rng,
                 batch,
             )
            del batch  # Buffers donated to _update_func.

        with jax.profiler.StepTraceAnnotation('get_stats'):
            stats = utils.get_first(stats)
        return stats
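
The `del batch` above relies on the input buffers having been donated to
`_update_func`. A minimal sketch of how such a pmapped update function is
typically built with buffer donation (the per-device function `update_fn` and
the donated argument positions are assumptions for illustration):

import jax

# Donating the params/state/opt_state buffers and the batch buffer lets XLA
# reuse their device memory for the outputs, which is why the caller may
# safely delete `batch` after the call.
update_func = jax.pmap(
    update_fn,
    axis_name='i',
    donate_argnums=(0, 1, 2, 3, 4, 7),
)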
Example #4
    def evaluate(self, global_step, rng, **unused_kwargs):
        """See base class."""
        if not self._evaluating:
            self._eval_init()

        global_step = np.array(utils.get_first(global_step))
        ema_params = utils.get_first(self._ema_params)
        ema_network_state = utils.get_first(self._ema_network_state)
        rng = utils.get_first(rng)

        # Evaluate using the ema params.
        results, predictions = self._evaluate_with_ensemble(
            ema_params, ema_network_state, rng)
        results['global_step'] = global_step

        # Store predictions if we got a path.
        self._maybe_save_predictions(predictions, global_step)

        return results
Example #5
 def step(self, global_step, *unused_args, **unused_kwargs):
     if self._train_input is None:
         self._initialize_train()
     batch = next(self._train_input)
     out = self.train_fn(params=self._params,
                         opt_state=self._opt_state,
                         batch=batch,
                         global_step=global_step)
     self._params = out['params']
     self._opt_state = out['opt_state']
     self._step = global_step
     return jl_utils.get_first(out['metrics'])
Example #6
  def evaluate(self, global_step: jnp.ndarray, rng: jnp.ndarray,
               **unused_kwargs) -> chex.ArrayTree:
    """See Jaxline base class."""
    if self.forward is None:
      self._eval_init()

    if self.config.ema:
      params = utils.get_first(self._ema_params)
      state = utils.get_first(self._ema_network_state)
    else:
      params = utils.get_first(self._params)
      state = utils.get_first(self._network_state)
    rng = utils.get_first(rng)

    split = self.config.evaluation.split
    predictions, scalars = self._get_predictions(
        params, state, rng,
        utils.py_prefetch(
            functools.partial(
                self._build_numpy_dataset_iterator, split, is_training=False)))
    self._maybe_save_predictions(predictions, split, global_step[0])
    return scalars
Example #7
    def _eval_epoch(self, rng):
        """Evaluates an epoch."""
        num_samples = 0.
        summed_scalars = None

        params = jl_utils.get_first(self._params)
        state = jl_utils.get_first(self._state)

        for inputs in self._build_eval_input():
            num_samples += inputs['labels'].shape[0]
            scalars = self._eval_batch(params, state, inputs, rng)

            # Accumulate the sum of scalars for each step.
            scalars = jax.tree_map(lambda x: jnp.sum(x, axis=0), scalars)
            if summed_scalars is None:
                summed_scalars = scalars
            else:
                summed_scalars = jax.tree_multimap(jnp.add, summed_scalars,
                                                   scalars)

        mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
        return mean_scalars
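
Note that `jax.tree_multimap` has been deprecated and removed in newer JAX
releases; `jax.tree_map` accepts multiple trees, so the accumulation above can
equivalently be written as:

summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)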
Example #8
    def step(self, global_step: int, rng: jnp.ndarray, *unused_args,
             **unused_kwargs):
        """See base class."""

        if self._train_input is None:
            self._initialize_train()

        inputs = next(self._train_input)

        self._params, self._state, self._opt_state, scalars = (
            self._update_func(self._params, self._state, self._opt_state,
                              inputs, rng, global_step))

        scalars = jl_utils.get_first(scalars)
        return scalars
Example #9
  def step(self, global_step: jnp.ndarray, rng: jnp.ndarray, **unused_args):
    """See Jaxline base class."""
    if self.loss is None:
      self._train_init()

    graph = next(self._train_input)
    out = self.update_parameters(
        self._params,
        self._ema_params,
        self._network_state,
        self._ema_network_state,
        self._opt_state,
        global_step,
        rng,
        graph._asdict())
    (self._params, self._ema_params, self._network_state,
     self._ema_network_state, self._opt_state, scalars) = out
    return utils.get_first(scalars)
Example #10
  def snapshot_state(self) -> Mapping[str, jnp.ndarray]:
    """Takes a frozen copy of the current experiment state for checkpointing.

    Returns:
      A mapping from experiment attributes to the names they are stored under
        in the snapshot.
    """
    snapshot_state = {}
    if not self.CHECKPOINT_ATTRS and not self.NON_BROADCAST_CHECKPOINT_ATTRS:
      logging.warning(
          "Your experiment's self.CHECKPOINT_ATTRS and "
          "self.NON_BROADCAST_CHECKPOINT_ATTRS are empty. Your job will not "
          "checkpoint any state or parameters.")
    for attr_name, chk_name in self.CHECKPOINT_ATTRS.items():
      snapshot_state[chk_name] = utils.get_first(getattr(self, attr_name))
    for attr_name, chk_name in self.NON_BROADCAST_CHECKPOINT_ATTRS.items():
      snapshot_state[chk_name] = getattr(self, attr_name)
    return snapshot_state
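
For reference, a minimal sketch of how an experiment subclass might declare
these class attributes so that `snapshot_state` picks them up (the attribute
names below are illustrative, not taken from the original experiments):

class MyExperiment(experiment.AbstractExperiment):
    # Maps experiment attribute names to the keys they are stored under in
    # the checkpoint snapshot; these values are device-replicated, so they
    # go through utils.get_first before being saved.
    CHECKPOINT_ATTRS = {
        '_params': 'params',
        '_state': 'state',
        '_opt_state': 'opt_state',
    }
    # Attributes that are not replicated across devices and are saved as-is.
    NON_BROADCAST_CHECKPOINT_ATTRS = {
        '_python_step': 'python_step',
    }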
Example #11
 def step(self, global_step, rng, *unused_args, **unused_kwargs):
     if self._train_input is None:
         self._initialize_train()
     inputs = next(self._train_input)
     out = self.train_fn(params=self._params,
                         states=self._state,
                         opt_states=self._opt_state,
                         inputs=inputs,
                         rng=rng,
                         global_step=global_step,
                         ema_params=self._ema_params,
                         ema_states=self._ema_state)
     self._params, self._state = out['params'], out['states']
     self._opt_state = out['opt_states']
     self._ema_params = out['ema_params']
     self._ema_state = out['ema_states']
     self.opt.plugin(self._opt_state)
     return jl_utils.get_first(out['metrics'])
Example #12
  def snapshot_state(self) -> Mapping[str, jnp.ndarray]:
    """Takes a frozen copy of the current experiment state for checkpointing.

    Returns:
      A mapping from experiment attributes to the names they are stored under
        in the snapshot.
    """
    snapshot_state = {}
    if not self.CHECKPOINT_ATTRS and not self.NON_BROADCAST_CHECKPOINT_ATTRS:
      logging.warning(
          "Your experiment's self.CHECKPOINT_ATTRS and "
          "self.NON_BROADCAST_CHECKPOINT_ATTRS are empty. Your job will not "
          "checkpoint any state or parameters. See "
          "learning/deepmind/research/jax/pipeline/examples/imagenet/"
          "experiment.py for an example of specifying values to checkpoint.")
    for attr_name, chk_name in self.CHECKPOINT_ATTRS.items():
      snapshot_state[chk_name] = utils.get_first(getattr(self, attr_name))
    for attr_name, chk_name in self.NON_BROADCAST_CHECKPOINT_ATTRS.items():
      snapshot_state[chk_name] = getattr(self, attr_name)
    return snapshot_state
Example #13
 def eval_epoch(self, params, state, rng):
     host_id = jax.host_id()
     num_samples = 0
     batch_axis = 1
     summed_scalars = None
     # Converting to numpy here allows us to reset the generator.
     eval_input = tfds.as_numpy(self.eval_input)
     for all_inputs in eval_input:
         # The inputs are sent to multiple workers.
         inputs = jax.tree_map(lambda x: x[host_id], all_inputs)
         num_samples += (
             jax.device_count() * inputs['image'].shape[batch_axis])
         scalars = jl_utils.get_first(
             self.eval_fn(params, state, inputs, rng))
         # Accumulate the sum of scalars for each step.
         scalars = jax.tree_map(lambda x: jnp.sum(x, axis=0), scalars)
         if summed_scalars is None:
             summed_scalars = scalars
         else:
             summed_scalars = jax.tree_multimap(jnp.add, summed_scalars,
                                                scalars)
     mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
     return mean_scalars
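
A small compatibility note: `jax.host_id()` is deprecated in newer JAX
releases; the equivalent call is:

host_id = jax.process_index()  # replaces the deprecated jax.host_id()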
Example #14
def _save_state_from_in_memory_checkpointer(
    save_path, experiment_class: experiment.AbstractExperiment):
  """Saves experiment state to a checkpoint."""
  logging.info('Saving model.')
  for checkpoint_name, checkpoint in utils.GLOBAL_CHECKPOINT_DICT.items():
    if not checkpoint.history:
      logging.info('Nothing to save in "%s"', checkpoint_name)
      continue

    pickle_nest = checkpoint.history[-1].pickle_nest
    global_step = pickle_nest['global_step']

    state_dict = {'global_step': global_step}
    for attribute, key in experiment_class.CHECKPOINT_ATTRS.items():
      state_dict[key] = utils.get_first(
          getattr(pickle_nest['experiment_module'], attribute))
    save_dir = os.path.join(
        save_path, checkpoint_name, _get_step_date_label(global_step))
    python_state_path = os.path.join(save_dir, 'checkpoint.dill')
    os.makedirs(save_dir, exist_ok=True)
    with open(python_state_path, 'wb') as f:
      dill.dump(state_dict, f)
    logging.info(
        'Saved "%s" checkpoint to %s', checkpoint_name, python_state_path)