Example 1
    def _initialize_experiment_state(
        self,
        init_rng: jnp.ndarray,
        dummy_batch: datasets.Batch,
    ):
        """Initialize parameters and opt state if not restoring from checkpoint."""
        dummy_graph = dummy_batch.graph

        # Cast features to float32 so that parameters are initialized
        # with the appropriate dtype.
        dummy_graph = dummy_graph._replace(
            nodes=jax.tree_map(lambda x: x.astype(np.float32),
                               dummy_graph.nodes),
            edges=jax.tree_map(lambda x: x.astype(np.float32),
                               dummy_graph.edges),
        )
        init_key = utils.bcast_local_devices(init_rng)
        p_init = jax.pmap(
            functools.partial(self.forward.init, is_training=True))
        params, network_state = p_init(init_key, dummy_graph)
        opt_init, _ = self._optimizer(
            utils.bcast_local_devices(jnp.zeros([], jnp.int32)))
        opt_state = jax.pmap(opt_init)(params)

        # For EMA decay to work correctly, params/state must be floats.
        chex.assert_type(jax.tree_leaves(params), jnp.floating)
        chex.assert_type(jax.tree_leaves(network_state), jnp.floating)

        self._params = params
        self._ema_params = params
        self._network_state = network_state
        self._ema_network_state = network_state
        self._opt_state = opt_state
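
A note on the float assertions above: the EMA update interpolates each parameter towards its new value, which would silently truncate on integer leaves. A minimal sketch of such an update (the helper name and decay value are assumptions, not part of the snippet above):

import jax

def ema_update(ema_params, new_params, decay=0.999):
    # Exponential moving average: ema <- decay * ema + (1 - decay) * new.
    # Requires floating-point leaves, which is what the asserts check.
    return jax.tree_map(
        lambda ema, new: decay * ema + (1. - decay) * new,
        ema_params, new_params)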
Example 2
def _restore_state_to_in_memory_checkpointer(restore_path):
  """Initializes experiment state from a checkpoint."""

  # Load pretrained experiment state.
  python_state_path = os.path.join(restore_path, 'checkpoint.dill')
  with open(python_state_path, 'rb') as f:
    pretrained_state = dill.load(f)
  logging.info('Restored checkpoint from %s', python_state_path)

  # Assign state to a dummy experiment instance for the in-memory checkpointer,
  # broadcasting to devices.
  dummy_experiment = Experiment(
      mode='train', init_rng=0, config=FLAGS.config.experiment_kwargs.config)
  for attribute, key in Experiment.CHECKPOINT_ATTRS.items():
    setattr(dummy_experiment, attribute,
            utils.bcast_local_devices(pretrained_state[key]))

  jaxline_state = dict(
      global_step=pretrained_state['global_step'],
      experiment_module=dummy_experiment)
  snapshot = utils.SnapshotNT(0, jaxline_state)

  # Finally, seed the jaxline `utils.InMemoryCheckpointer` global dict.
  utils.GLOBAL_CHECKPOINT_DICT['latest'] = utils.CheckpointNT(
      threading.local(), [snapshot])
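
The loop above relies on `Experiment.CHECKPOINT_ATTRS` mapping instance attribute names to the keys they are stored under in the checkpoint. A plausible shape for that mapping (illustrative only; the concrete entries depend on the experiment class):

CHECKPOINT_ATTRS = {
    '_params': 'params',
    '_network_state': 'state',
    '_opt_state': 'opt_state',
}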
Example 3
    def _initialize_train(self):
        self._train_input = jl_utils.py_prefetch(self._build_train_input)

        total_batch_size = self.config.training.batch_size
        steps_per_epoch = (self.config.training.images_per_epoch /
                           self.config.training.batch_size)
        total_steps = self.config.training.n_epochs * steps_per_epoch
        # Build the learning rate schedule.
        self._lr_schedule = utils.get_learning_rate_schedule(
            total_batch_size, steps_per_epoch, total_steps,
            self.config.optimizer)

        self._optimizer = utils.make_optimizer(self.config.optimizer,
                                               self._lr_schedule)

        # Check we haven't already restored params
        if self._params is None:
            logging.info('Initializing parameters.')

            inputs = next(self._train_input)

            init_net = jax.pmap(
                lambda *a: self.forward.init(*a, is_training=True))
            init_opt = jax.pmap(self._optimizer.init)

            # Init uses the same RNG key on all hosts+devices to ensure everyone
            # computes the same initial state.
            init_rng = jl_utils.bcast_local_devices(self.init_rng)

            self._params, self._state = init_net(init_rng, inputs)
            self._opt_state = init_opt(self._params)
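
The comment about using the same RNG key on all hosts and devices is the crux of this pattern: broadcasting one key and pmapping the init function yields bit-identical parameters on every device, which the pmapped update steps then keep in sync. A self-contained sketch under that assumption (`init_fn` and the shapes are toy stand-ins, not the experiment's network):

import jax
import jax.numpy as jnp

def init_fn(rng, x):
    # Toy stand-in for `forward.init`: one dense layer's weights.
    return {'w': jax.random.normal(rng, (x.shape[-1], 8))}

n = jax.local_device_count()
# Replicate one key so that every device initializes identically.
init_rng = jax.device_put_replicated(
    jax.random.PRNGKey(0), jax.local_devices())
inputs = jnp.ones((n, 4, 16))  # leading axis = device count, as pmap expects
params = jax.pmap(init_fn)(init_rng, inputs)
assert params['w'].shape == (n, 16, 8)  # one identical copy per device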
Example 4
 def test_bcast_local_devices_tree(self):
     num_devices = jax.local_device_count()
     tree = utils.bcast_local_devices({
         "ones": jnp.ones([]),
         "zeros": jnp.zeros([])
     })
     self.assertEqual(tree, {
         "ones": jnp.ones([num_devices]),
         "zeros": jnp.zeros([num_devices])
     })
Example 5
    def _initialize_training(self, rng):
        # Initialize inputs.
        if self.config.emulated_workers > 0:
            per_device_workers, ragged = divmod(self.config.emulated_workers,
                                                jax.host_count())
            if ragged:
                raise ValueError(
                    'Number of emulated workers must be divisible by the '
                    'number of physical workers `jax.host_count()`.')
            self._repeat_batch = per_device_workers
        else:
            self._repeat_batch = 1
        self.supervised_train_input = jl_utils.py_prefetch(
            self._supervised_train_dataset)
        if self.config.training.extra_data_path is None:
            self.extra_train_input = None
        else:
            self.extra_train_input = jl_utils.py_prefetch(
                self._extra_train_dataset)
        self.normalize_fn = datasets.cifar10_normalize

        # Optimizer.
        self.optimizer = utils.sgd_momentum(self.config.training.learning_rate,
                                            momentum=.9,
                                            nesterov=True)

        # Initialize parameters.
        if self._params is None:
            logging.info(
                'Initializing parameters randomly rather than restoring '
                'from checkpoint.')
            # Create inputs to initialize the network state.
            images, _, _ = jax.pmap(self.concatenate)(
                next(self.supervised_train_input),
                next(self.extra_train_input)
                if self.extra_train_input is not None else None)
            images = jax.pmap(self.normalize_fn)(images)
            # Initialize weights and biases.
            init_net = jax.pmap(
                lambda *a: self.model.init(*a, is_training=True),
                axis_name='i')
            init_rng = jl_utils.bcast_local_devices(rng)
            self._params, self._state = init_net(init_rng, images)
            # Setup weight averaging.
            if self.config.training.swa_decay > 0:
                self._avg_params = self._params
            else:
                self._avg_params = None
            # Initialize optimizer state.
            init_opt = jax.pmap(self.optimizer.init, axis_name='i')
            self._opt_state = init_opt(self._params)

        # Initialize step function.
        self.train_fn = jax.pmap(self._train_fn,
                                 axis_name='i',
                                 donate_argnums=(0, 1, 2, 3))
Example 6
    def restore_from_snapshot(self,
                              snapshot_state: Mapping[Text, jnp.ndarray]):
        """Restores experiment state from a snapshot.

        Args:
          snapshot_state: A mapping from checkpoint attribute names to the
            values to restore from the snapshot.
        """
        for attr_name, chk_name in self.CHECKPOINT_ATTRS.items():
            value = utils.bcast_local_devices(snapshot_state[chk_name])
            setattr(self, attr_name, value)
        for attr_name, chk_name in self.NON_BROADCAST_CHECKPOINT_ATTRS.items():
            setattr(self, attr_name, snapshot_state[chk_name])
Example 7
 def _initialize_train(self):
     self._train_input = self._build_train_input()
     # Initialize net and EMA copy of net if no params available.
     if self._params is None:
         inputs = next(self._train_input)
         init_net = jax.pmap(lambda *a: self.net.init(*a, is_training=True),
                             axis_name='i')
         init_rng = jl_utils.bcast_local_devices(self.init_rng)
         self._params, self._state = init_net(init_rng, inputs)
         if self.config.use_ema:
             self._ema_params, self._ema_state = init_net(init_rng, inputs)
         # Params are replicated per device, so divide by the local device
         # count to report the size of a single copy.
         num_params = hk.data_structures.tree_size(self._params)
         logging.info(
             f'Net parameters: {num_params / jax.local_device_count()}')
     self._make_opt()
Example 8
 def _initialize_train(self):
     self._train_input = self._build_train_input()
     if self._params is None:
         input_shape = (1, self.config.image_size, self.config.image_size, 3)
         inputs = jnp.ones(input_shape, jnp.float32)
         init_net = jax.pmap(lambda *a: self.net.init(*a, is_training=True),
                             axis_name='i')
         init_rng = jl_utils.bcast_local_devices(self.init_rng)
         self._params = init_net(init_rng, inputs)
         num_params = count_parameters(self._params)
         logging.info(
             f'Net params: {num_params / jax.local_device_count()}')
         self._make_opt()
         self._opt_state = self._opt.init(self._params)
Example 9
  def _train_init(self):
    self.loss = hk.transform_with_state(self._loss)
    self._train_input = utils.py_prefetch(
        lambda: self._build_numpy_dataset_iterator('train', is_training=True))
    init_stacked_graphs = next(self._train_input)
    init_key = utils.bcast_local_devices(self.init_rng)
    p_init = jax.pmap(self.loss.init)
    self._params, self._network_state = p_init(init_key,
                                               **init_stacked_graphs._asdict())

    # Learning rate scheduling.
    lr_schedule = optax.warmup_cosine_decay_schedule(
        **self.config.optimizer.lr_schedule)

    self.optimizer = getattr(optax, self.config.optimizer.name)(
        learning_rate=lr_schedule, **self.config.optimizer.optimizer_kwargs)

    self._opt_state = jax.pmap(self.optimizer.init)(self._params)
    self.update_parameters = jax.pmap(self._update_parameters, axis_name='i')
    if self.config.ema:
      self._ema_params = self._params
      self._ema_network_state = self._network_state
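
For reference, `optax.warmup_cosine_decay_schedule` takes its warmup and decay horizon explicitly; the config above is unpacked straight into it. A sketch with illustrative values (not the experiment's actual settings):

import optax

lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,       # learning rate at step 0
    peak_value=1e-3,      # learning rate reached after warmup
    warmup_steps=1_000,
    decay_steps=100_000,  # length of the cosine decay
    end_value=1e-5)       # learning rate at the end of the schedule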
Example 10
 def write(attributes, broadcast=False):
   for attr_name, chk_name in attributes.items():
     value = snapshot_state[chk_name]
     if broadcast:
       value = utils.bcast_local_devices(value)
     setattr(self, attr_name, value)
Example 11
def evaluate(experiment_class,
             config,
             checkpointer,
             writer,
             jaxline_mode=None):
    """Main evaluation loop."""
    if jaxline_mode is None:
        jaxline_mode = FLAGS.jaxline_mode

    logging.info("Evaluating with config:\n%s", config)

    global_step = 0
    eval_rng = jax.random.PRNGKey(config.random_seed)
    experiment = _initialize_experiment(experiment_class, jaxline_mode,
                                        eval_rng, config.experiment_kwargs)

    if config.best_model_eval_metric and jax.host_id() == 0:
        # Initialize best state.
        best_state = checkpointer.get_experiment_state("best")
        best_state.best_eval_metric_value = float("-inf")
        best_state.best_model_eval_metric = config.best_model_eval_metric

    # Will evaluate the latest checkpoint in the directory.
    state = checkpointer.get_experiment_state("latest")
    state.global_step = global_step
    state.experiment_module = experiment
    state.train_step_rng = None

    eval_rng = jnp.broadcast_to(eval_rng,
                                (jax.local_device_count(), ) + eval_rng.shape)
    eval_rng = jax.pmap(functools.partial(utils.specialize_rng_host_device,
                                          axis_name="i",
                                          mode=config.random_mode_eval),
                        axis_name="i")(eval_rng)

    if config.one_off_evaluate:
        checkpointer.restore("latest")
        global_step_devices = utils.bcast_local_devices(
            jnp.asarray(state.global_step))
        scalar_values = utils.evaluate_should_return_dict(experiment.evaluate)(
            global_step=global_step_devices, rng=eval_rng, writer=writer)
        if writer is not None:
            writer.write_scalars(state.global_step, scalar_values)
        logging.info("Evaluated specific checkpoint, exiting.")
        return

    old_checkpoint_path = None
    initial_weights_are_evaluated = False

    while True:
        checkpoint_path = checkpointer.restore_path("latest")

        if (checkpoint_path is None and config.eval_initial_weights
                and not initial_weights_are_evaluated):
            # Skip restoring a checkpoint and directly call evaluate if
            # `config.eval_initial_weights`, but don't do it more than once.
            initial_weights_are_evaluated = True
        else:
            if checkpoint_path in (None, old_checkpoint_path):
                logging.info(
                    "Checkpoint %s invalid or already evaluated, waiting.",
                    checkpoint_path)
                time.sleep(10)
                continue

            checkpointer.restore("latest")

        global_step_devices = utils.bcast_local_devices(
            jnp.asarray(state.global_step))
        scalar_values = utils.evaluate_should_return_dict(experiment.evaluate)(
            global_step=global_step_devices, rng=eval_rng, writer=writer)
        if writer is not None:
            writer.write_scalars(state.global_step, scalar_values)
        old_checkpoint_path = checkpoint_path
        # Decide whether to save a "best checkpoint".
        if config.best_model_eval_metric and jax.host_id() == 0:
            if config.best_model_eval_metric not in scalar_values:
                raise ValueError(
                    f"config.best_model_eval_metric has been specified "
                    f"as {config.best_model_eval_metric}, but this key "
                    f"was not returned by the evaluate method")
            current_eval_metric_value = scalar_values[
                config.best_model_eval_metric]
            old_eval_metric_value = best_state.best_eval_metric_value
            if old_eval_metric_value < current_eval_metric_value:
                logging.info("%s: %s > %s, saving new best checkpoint.",
                             config.best_model_eval_metric,
                             current_eval_metric_value, old_eval_metric_value)
                best_state = checkpointer.get_experiment_state("best")
                best_state.global_step = state.global_step
                best_state.experiment_module = experiment
                best_state.best_eval_metric_value = current_eval_metric_value
                best_state.train_step_rng = state.train_step_rng
                checkpointer.save("best")

        if state.global_step >= config.training_steps:
            logging.info("Last checkpoint (iteration %d) evaluated, exiting.",
                         state.global_step)
            break
Example 12
 def test_bcast_local_devices_empty_tree(self):
     self.assertIsNone(utils.bcast_local_devices(None))
     self.assertEqual(utils.bcast_local_devices({}), {})
Example 13
    def test_bcast_local_devices(self):
        self.assertEqual(utils.bcast_local_devices(jnp.zeros([])),
                         jnp.zeros([jax.local_device_count()]))

        self.assertEqual(utils.bcast_local_devices(jnp.ones([])),
                         jnp.ones([jax.local_device_count()]))
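
Taken together, these tests pin down the contract of `bcast_local_devices`: every leaf gains a leading axis of size `jax.local_device_count()`, while `None` and empty trees pass through unchanged. One way to satisfy that contract (a sketch, not necessarily jaxline's actual implementation):

import jax

def bcast_local_devices(value):
    # Replicate each leaf onto every local device. jax.tree_map skips
    # `None` and empty containers, matching the tests above.
    devices = jax.local_devices()
    return jax.tree_map(
        lambda x: jax.device_put_replicated(x, devices), value)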
Example 14
def train(
        experiment_class,
        config,
        checkpointer: utils.Checkpointer,
        writer: Optional[utils.Writer],
        periodic_actions=(),
):
    """Main training loop."""
    logging.info("Training with config:\n%s", config)
    is_chief = jax.host_id() == 0

    rng = jax.random.PRNGKey(config.random_seed)
    with utils.log_activity("experiment init"):
        experiment = _initialize_experiment(experiment_class, "train", rng,
                                            config.experiment_kwargs)

    state = checkpointer.get_experiment_state("latest")
    state.global_step = 0
    state.experiment_module = experiment
    state.train_step_rng = utils.bcast_local_devices(rng)

    if checkpointer.can_be_restored("latest"):
        with utils.log_activity("checkpoint restore"):
            checkpointer.restore("latest")

    periodic_actions += (utils.PeriodicAction(
        _log_outputs,
        interval_type=config.logging_interval_type or config.interval_type,
        interval=config.log_tensors_interval), )

    if config.train_checkpoint_all_hosts or is_chief:
        if config.save_checkpoint_interval > 0:
            periodic_actions += (utils.PeriodicAction(
                lambda *_: checkpointer.save("latest"),
                interval_type=(config.checkpoint_interval_type
                               or config.interval_type),
                interval=config.save_checkpoint_interval,
                run_async=False), )  # run_async True would not be thread-safe.

    if is_chief:
        if writer is not None:

            def write_scalars(global_step: int, scalar_values):
                writer.write_scalars(global_step, scalar_values)

            periodic_actions += (utils.PeriodicAction(
                write_scalars,
                interval_type=(config.logging_interval_type
                               or config.interval_type),
                interval=config.log_train_data_interval,
                log_all_data=config.log_all_train_data), )

    for pa in periodic_actions:
        pa.update_time(time.time(), state.global_step)

    experiment.train_loop(config, state, periodic_actions, writer)

    if is_chief:
        with utils.log_activity("final checkpoint"):
            checkpointer.save("latest")

    # Join all async periodic actions that are unfinished.
    for pa in periodic_actions:
        pa.wait_to_finish()

    # We occasionally see errors when the final checkpoint is being written if
    # the other hosts exit. Here we force all hosts to participate in one final
    # collective so the non-master hosts cannot exit before the master writes out
    # the final checkpoint.
    utils.rendezvous()
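
The final collective referred to in the comment can be as simple as a pmapped psum: every host must contribute its devices to the all-reduce, so no host can exit before the others reach it. A sketch of such a barrier (assuming this is roughly what `utils.rendezvous` does):

import jax
import jax.numpy as jnp

def rendezvous():
    # All-reduce across every device on every host; acts as a barrier.
    x = jnp.ones([jax.local_device_count()])
    psum = jax.pmap(lambda v: jax.lax.psum(v, 'i'), axis_name='i')
    psum(x).block_until_ready()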