Example #1
    def _initialize_training(self, rng):
        # Initialize inputs.
        if self.config.emulated_workers > 0:
            per_device_workers, ragged = divmod(self.config.emulated_workers,
                                                jax.host_count())
            if ragged:
                raise ValueError(
                    'Number of emulated workers must be divisible by the '
                    'number of physical workers `jax.host_count()`.')
            self._repeat_batch = per_device_workers
        else:
            self._repeat_batch = 1
        self.supervised_train_input = jl_utils.py_prefetch(
            self._supervised_train_dataset)
        if self.config.training.extra_data_path is None:
            self.extra_train_input = None
        else:
            self.extra_train_input = jl_utils.py_prefetch(
                self._extra_train_dataset)
        self.normalize_fn = datasets.cifar10_normalize

        # Optimizer.
        self.optimizer = utils.sgd_momentum(self.config.training.learning_rate,
                                            momentum=.9,
                                            nesterov=True)

        # Initialize parameters.
        if self._params is None:
            logging.info(
                'Initializing parameters randomly rather than restoring '
                'from checkpoint.')
            # Create inputs to initialize the network state.
            images, _, _ = jax.pmap(self.concatenate)(
                next(self.supervised_train_input),
                (next(self.extra_train_input)
                 if self.extra_train_input is not None else None))
            images = jax.pmap(self.normalize_fn)(images)
            # Initialize weights and biases.
            init_net = jax.pmap(
                lambda *a: self.model.init(*a, is_training=True),
                axis_name='i')
            init_rng = jl_utils.bcast_local_devices(rng)
            self._params, self._state = init_net(init_rng, images)
            # Setup weight averaging.
            if self.config.training.swa_decay > 0:
                self._avg_params = self._params
            else:
                self._avg_params = None
            # Initialize optimizer state.
            init_opt = jax.pmap(self.optimizer.init, axis_name='i')
            self._opt_state = init_opt(self._params)

        # Initialize step function.
        self.train_fn = jax.pmap(self._train_fn,
                                 axis_name='i',
                                 donate_argnums=(0, 1, 2, 3))
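Example #1 above, like the other experiment initializers further down this listing, follows a single pattern: wrap a zero-argument dataset-building callable in py_prefetch, pull host batches with next() to build the initial parameters, then pmap the step function. A minimal sketch of just that pattern (the import path assumes jaxline's utils module, and build_train_dataset is a hypothetical stand-in for the real dataset builders):

from jaxline import utils as jl_utils

def build_train_dataset():
    # Hypothetical generator standing in for a real tf.data / numpy pipeline.
    for step in range(4):
        yield {'images': step, 'labels': step}

# py_prefetch takes the callable itself, not the iterator it returns.
train_input = jl_utils.py_prefetch(build_train_dataset)
first_batch = next(train_input)  # Blocks until the background producer yields.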
Example #2
    def testBadFunction(self):
        def _bad_function():
            raise ValueError

        iterable = utils.py_prefetch(_bad_function)
        with self.assertRaises(ValueError):
            next(iterable)
Example #3
    def _initialize_train(self):
        self._train_input = jl_utils.py_prefetch(self._build_train_input)

        total_batch_size = self.config.training.batch_size
        steps_per_epoch = (self.config.training.images_per_epoch /
                           self.config.training.batch_size)
        total_steps = self.config.training.n_epochs * steps_per_epoch
        # Learning rate schedule.
        self._lr_schedule = utils.get_learning_rate_schedule(
            total_batch_size, steps_per_epoch, total_steps,
            self.config.optimizer)

        self._optimizer = utils.make_optimizer(self.config.optimizer,
                                               self._lr_schedule)

        # Check we haven't already restored params
        if self._params is None:
            logging.info('Initializing parameters.')

            inputs = next(self._train_input)

            init_net = jax.pmap(
                lambda *a: self.forward.init(*a, is_training=True))
            init_opt = jax.pmap(self._optimizer.init)

            # Init uses the same RNG key on all hosts+devices to ensure everyone
            # computes the same initial state.
            init_rng = jl_utils.bcast_local_devices(self.init_rng)

            self._params, self._state = init_net(init_rng, inputs)
            self._opt_state = init_opt(self._params)
Example #4
    def testBadFunctionIteration(self):
        def _bad_iterable():
            yield 1
            raise ValueError

        iterable = utils.py_prefetch(_bad_iterable)
        self.assertEqual(next(iterable), 1)
        with self.assertRaises(ValueError):
            next(iterable)
Example #5
    def _train_init(self):
        iterator = self._build_numpy_dataset_iterator('train',
                                                      is_training=True)
        self._train_input = utils.py_prefetch(lambda: iterator)
        dummy_batch = next(self._train_input)

        if self._params is None:
            self._initialize_experiment_state(self.init_rng, dummy_batch)
        self._update_func = jax.pmap(
            self._update_func,
            axis_name='i',
            donate_argnums=3,
        )
        self._training = True
Example #6
  def evaluate(self, global_step: jnp.ndarray, rng: jnp.ndarray,
               **unused_kwargs) -> chex.ArrayTree:
    """See Jaxline base class."""
    if self.forward is None:
      self._eval_init()

    if self.config.ema:
      params = utils.get_first(self._ema_params)
      state = utils.get_first(self._ema_network_state)
    else:
      params = utils.get_first(self._params)
      state = utils.get_first(self._network_state)
    rng = utils.get_first(rng)

    split = self.config.evaluation.split
    predictions, scalars = self._get_predictions(
        params, state, rng,
        utils.py_prefetch(
            functools.partial(
                self._build_numpy_dataset_iterator, split, is_training=False)))
    self._maybe_save_predictions(predictions, split, global_step[0])
    return scalars
Example #7
  def _train_init(self):
    self.loss = hk.transform_with_state(self._loss)
    self._train_input = utils.py_prefetch(
        lambda: self._build_numpy_dataset_iterator('train', is_training=True))
    init_stacked_graphs = next(self._train_input)
    init_key = utils.bcast_local_devices(self.init_rng)
    p_init = jax.pmap(self.loss.init)
    self._params, self._network_state = p_init(init_key,
                                               **init_stacked_graphs._asdict())

    # Learning rate scheduling.
    lr_schedule = optax.warmup_cosine_decay_schedule(
        **self.config.optimizer.lr_schedule)

    self.optimizer = getattr(optax, self.config.optimizer.name)(
        learning_rate=lr_schedule, **self.config.optimizer.optimizer_kwargs)

    self._opt_state = jax.pmap(self.optimizer.init)(self._params)
    self.update_parameters = jax.pmap(self._update_parameters, axis_name='i')
    if self.config.ema:
      self._ema_params = self._params
      self._ema_network_state = self._network_state
Example #8
    def testBaseCase(self):
        self.assertEqual(list(utils.py_prefetch(lambda: range(100))),
                         list(range(100)))
Example #9
    def testEmpty(self):
        self.assertEqual(list(utils.py_prefetch(lambda: ())), [])
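
The tests in Examples #2, #4, #8 and #9 pin down the contract that the experiment code above relies on: py_prefetch takes a zero-argument callable returning an iterable, yields its items in order, and any exception raised by the producer resurfaces from next() on the consuming side. Below is a simplified sketch of that behaviour, written from scratch for illustration; the buffer size and shutdown handling are assumptions, not jaxline's actual implementation.

import queue
import threading

def simple_py_prefetch(iterable_function, buffer_size=5):
    """Yields items produced by iterable_function() on a background thread."""
    buffer = queue.Queue(maxsize=buffer_size)
    end = object()

    def producer():
        try:
            for item in iterable_function():
                buffer.put(item)
        except BaseException as e:  # Hand the error over to the consumer.
            buffer.put(e)
            return
        buffer.put(end)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buffer.get()
        if item is end:
            return
        if isinstance(item, BaseException):
            raise item  # Mirrors testBadFunction / testBadFunctionIteration.
        yield item

# Behaves the way testBaseCase and testEmpty expect:
assert list(simple_py_prefetch(lambda: range(100))) == list(range(100))
assert list(simple_py_prefetch(lambda: ())) == []

Prefetching on a separate thread is what lets the host-side input pipeline run ahead of the device computation in the experiment examples above.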