def _initialize_training(self, rng):
  # Initialize inputs.
  if self.config.emulated_workers > 0:
    per_device_workers, ragged = divmod(
        self.config.emulated_workers, jax.host_count())
    if ragged:
      raise ValueError(
          'Number of emulated workers must be divisible by the '
          'number of physical workers `jax.host_count()`.')
    self._repeat_batch = per_device_workers
  else:
    self._repeat_batch = 1
  self.supervised_train_input = jl_utils.py_prefetch(
      self._supervised_train_dataset)
  if self.config.training.extra_data_path is None:
    self.extra_train_input = None
  else:
    self.extra_train_input = jl_utils.py_prefetch(
        self._extra_train_dataset)
  self.normalize_fn = datasets.cifar10_normalize

  # Optimizer.
  self.optimizer = utils.sgd_momentum(self.config.training.learning_rate,
                                      momentum=.9,
                                      nesterov=True)

  # Initialize parameters.
  if self._params is None:
    logging.info(
        'Initializing parameters randomly rather than restoring '
        'from checkpoint.')
    # Create inputs to initialize the network state.
    images, _, _ = jax.pmap(self.concatenate)(
        next(self.supervised_train_input),
        next(self.extra_train_input)
        if self.extra_train_input is not None else None)
    images = jax.pmap(self.normalize_fn)(images)
    # Initialize weights and biases.
    init_net = jax.pmap(
        lambda *a: self.model.init(*a, is_training=True), axis_name='i')
    init_rng = jl_utils.bcast_local_devices(rng)
    self._params, self._state = init_net(init_rng, images)
    # Setup weight averaging.
    if self.config.training.swa_decay > 0:
      self._avg_params = self._params
    else:
      self._avg_params = None
    # Initialize optimizer state.
    init_opt = jax.pmap(self.optimizer.init, axis_name='i')
    self._opt_state = init_opt(self._params)

  # Initialize step function.
  self.train_fn = jax.pmap(self._train_fn,
                           axis_name='i',
                           donate_argnums=(0, 1, 2, 3))
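# `utils.sgd_momentum` above is not defined in this section. A minimal sketch,
# assuming it simply wraps optax.sgd with (Nesterov) momentum; the helper name
# and this implementation are assumptions, not the original code:
import optax


def sgd_momentum(learning_rate, momentum=0.9, nesterov=False):
  # optax.sgd accepts a fixed learning rate or a schedule, and already
  # supports heavy-ball and Nesterov momentum via these keyword arguments.
  return optax.sgd(learning_rate, momentum=momentum, nesterov=nesterov)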
def testBadFunction(self):

  def _bad_function():
    raise ValueError

  iterable = utils.py_prefetch(_bad_function)
  with self.assertRaises(ValueError):
    next(iterable)
def _initialize_train(self):
  self._train_input = jl_utils.py_prefetch(self._build_train_input)

  total_batch_size = self.config.training.batch_size
  steps_per_epoch = (self.config.training.images_per_epoch /
                     self.config.training.batch_size)
  total_steps = self.config.training.n_epochs * steps_per_epoch
  # Scale by the (negative) learning rate.
  self._lr_schedule = utils.get_learning_rate_schedule(
      total_batch_size, steps_per_epoch, total_steps, self.config.optimizer)
  self._optimizer = utils.make_optimizer(self.config.optimizer,
                                         self._lr_schedule)

  # Check we haven't already restored params
  if self._params is None:
    logging.info('Initializing parameters.')

    inputs = next(self._train_input)

    init_net = jax.pmap(lambda *a: self.forward.init(*a, is_training=True))
    init_opt = jax.pmap(self._optimizer.init)

    # Init uses the same RNG key on all hosts+devices to ensure everyone
    # computes the same initial state.
    init_rng = jl_utils.bcast_local_devices(self.init_rng)

    self._params, self._state = init_net(init_rng, inputs)
    self._opt_state = init_opt(self._params)
def testBadFunctionIteration(self):

  def _bad_iterable():
    yield 1
    raise ValueError

  iterable = utils.py_prefetch(_bad_iterable)
  self.assertEqual(next(iterable), 1)
  with self.assertRaises(ValueError):
    next(iterable)
def _train_init(self):
  iterator = self._build_numpy_dataset_iterator('train', is_training=True)
  self._train_input = utils.py_prefetch(lambda: iterator)
  dummy_batch = next(self._train_input)
  if self._params is None:
    self._initialize_experiment_state(self.init_rng, dummy_batch)
  self._update_func = jax.pmap(
      self._update_func,
      axis_name='i',
      donate_argnums=3,
  )
  self._training = True
def evaluate(self, global_step: jnp.ndarray, rng: jnp.ndarray,
             **unused_kwargs) -> chex.ArrayTree:
  """See Jaxline base class."""
  if self.forward is None:
    self._eval_init()

  if self.config.ema:
    params = utils.get_first(self._ema_params)
    state = utils.get_first(self._ema_network_state)
  else:
    params = utils.get_first(self._params)
    state = utils.get_first(self._network_state)
  rng = utils.get_first(rng)

  split = self.config.evaluation.split
  predictions, scalars = self._get_predictions(
      params, state, rng,
      utils.py_prefetch(
          functools.partial(
              self._build_numpy_dataset_iterator, split, is_training=False)))
  self._maybe_save_predictions(predictions, split, global_step[0])
  return scalars
def _train_init(self):
  self.loss = hk.transform_with_state(self._loss)
  self._train_input = utils.py_prefetch(
      lambda: self._build_numpy_dataset_iterator('train', is_training=True))
  init_stacked_graphs = next(self._train_input)
  init_key = utils.bcast_local_devices(self.init_rng)
  p_init = jax.pmap(self.loss.init)
  self._params, self._network_state = p_init(init_key,
                                             **init_stacked_graphs._asdict())

  # Learning rate scheduling.
  lr_schedule = optax.warmup_cosine_decay_schedule(
      **self.config.optimizer.lr_schedule)

  self.optimizer = getattr(optax, self.config.optimizer.name)(
      learning_rate=lr_schedule, **self.config.optimizer.optimizer_kwargs)

  self._opt_state = jax.pmap(self.optimizer.init)(self._params)
  self.update_parameters = jax.pmap(self._update_parameters, axis_name='i')
  if self.config.ema:
    self._ema_params = self._params
    self._ema_network_state = self._network_state
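# The EMA branch above only initializes `self._ema_params` and
# `self._ema_network_state`; the update itself happens elsewhere. A minimal
# sketch of the usual exponential-moving-average step, applied leaf-wise over
# the parameter pytree (the function name and decay value are hypothetical,
# not taken from the original code):
import jax


def ema_update(ema_params, new_params, decay=0.999):
  # Leaf-wise update: ema <- decay * ema + (1 - decay) * new.
  return jax.tree_util.tree_map(
      lambda e, p: decay * e + (1. - decay) * p, ema_params, new_params)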
def testBaseCase(self):
  self.assertEqual(list(utils.py_prefetch(lambda: range(100))),
                   list(range(100)))
def testEmpty(self):
  self.assertEqual(list(utils.py_prefetch(lambda: ())), [])
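# The tests above pin down the py_prefetch contract: elements come out in
# order, an empty iterable yields nothing, and an exception raised by the
# producer resurfaces in the consumer on next(). A minimal sketch of an
# implementation satisfying that contract, assuming a bounded queue and a
# daemon producer thread (the buffer size and sentinel handling are
# assumptions; the real jaxline utility may differ):
import queue
import threading


def py_prefetch(iterable_function, buffer_size=5):
  buffer = queue.Queue(maxsize=buffer_size)
  end_of_iterable = object()  # Sentinel marking normal exhaustion.

  def producer():
    try:
      for item in iterable_function():
        buffer.put(item)
    except Exception as e:  # pylint: disable=broad-except
      # Ship the exception across the thread boundary so the consumer
      # re-raises it on its next call to next().
      buffer.put(e)
      return
    buffer.put(end_of_iterable)

  threading.Thread(target=producer, daemon=True).start()
  while True:
    item = buffer.get()
    if item is end_of_iterable:
      return
    if isinstance(item, Exception):
      raise item
    yield item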