def _initialize_experiment_state(
    self,
    init_rng: jnp.ndarray,
    dummy_batch: datasets.Batch,
):
  """Initialize parameters and opt state if not restoring from checkpoint."""
  dummy_graph = dummy_batch.graph
  # Cast features to float32 so that parameters are initialized with the
  # appropriate floating-point dtype.
  dummy_graph = dummy_graph._replace(
      nodes=jax.tree_map(lambda x: x.astype(np.float32), dummy_graph.nodes),
      edges=jax.tree_map(lambda x: x.astype(np.float32), dummy_graph.edges),
  )
  init_key = utils.bcast_local_devices(init_rng)
  p_init = jax.pmap(functools.partial(self.forward.init, is_training=True))
  params, network_state = p_init(init_key, dummy_graph)
  opt_init, _ = self._optimizer(
      utils.bcast_local_devices(jnp.zeros([], jnp.int32)))
  opt_state = jax.pmap(opt_init)(params)
  # For EMA decay to work correctly, params/state must be floats.
  chex.assert_type(jax.tree_leaves(params), jnp.floating)
  chex.assert_type(jax.tree_leaves(network_state), jnp.floating)
  self._params = params
  self._ema_params = params
  self._network_state = network_state
  self._ema_network_state = network_state
  self._opt_state = opt_state
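# The snippet above combines two idioms: broadcasting a single RNG key to all
# local devices, then running a pmapped `init` so every device computes (and
# holds) an identical copy of the parameters. A minimal, self-contained sketch
# of that pattern follows; `_toy_net` and its sizes are hypothetical and stand
# in for the experiment's real forward function.
import functools

import haiku as hk
import jax
import jax.numpy as jnp


def _toy_net(x, is_training):
  del is_training  # Unused in this sketch.
  return hk.nets.MLP([32, 1])(x)


toy_forward = hk.transform_with_state(_toy_net)

rng = jax.random.PRNGKey(0)
# Replicate the key to shape [num_devices] + rng.shape, one identical copy per
# local device (this is what `bcast_local_devices` does for a single leaf).
init_key = jnp.broadcast_to(rng, (jax.local_device_count(),) + rng.shape)
dummy_inputs = jnp.zeros([jax.local_device_count(), 8, 4])
p_init = jax.pmap(functools.partial(toy_forward.init, is_training=True))
params, state = p_init(init_key, dummy_inputs)
# Every parameter leaf now has a leading axis of size
# jax.local_device_count(), with identical values along it.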
def _restore_state_to_in_memory_checkpointer(restore_path):
  """Initializes experiment state from a checkpoint."""

  # Load pretrained experiment state.
  python_state_path = os.path.join(restore_path, 'checkpoint.dill')
  with open(python_state_path, 'rb') as f:
    pretrained_state = dill.load(f)
  logging.info('Restored checkpoint from %s', python_state_path)

  # Assign state to a dummy experiment instance for the in-memory
  # checkpointer, broadcasting to devices.
  dummy_experiment = Experiment(
      mode='train', init_rng=0, config=FLAGS.config.experiment_kwargs.config)
  for attribute, key in Experiment.CHECKPOINT_ATTRS.items():
    setattr(dummy_experiment, attribute,
            utils.bcast_local_devices(pretrained_state[key]))

  jaxline_state = dict(
      global_step=pretrained_state['global_step'],
      experiment_module=dummy_experiment)
  snapshot = utils.SnapshotNT(0, jaxline_state)

  # Finally, seed the jaxline `utils.InMemoryCheckpointer` global dict.
  utils.GLOBAL_CHECKPOINT_DICT['latest'] = utils.CheckpointNT(
      threading.local(), [snapshot])
def _initialize_train(self):
  self._train_input = jl_utils.py_prefetch(self._build_train_input)

  total_batch_size = self.config.training.batch_size
  steps_per_epoch = (self.config.training.images_per_epoch /
                     self.config.training.batch_size)
  total_steps = self.config.training.n_epochs * steps_per_epoch
  # Build the learning rate schedule and the optimizer that consumes it.
  self._lr_schedule = utils.get_learning_rate_schedule(
      total_batch_size, steps_per_epoch, total_steps, self.config.optimizer)
  self._optimizer = utils.make_optimizer(self.config.optimizer,
                                         self._lr_schedule)

  # Check we haven't already restored params from a checkpoint.
  if self._params is None:
    logging.info('Initializing parameters.')
    inputs = next(self._train_input)
    init_net = jax.pmap(
        lambda *a: self.forward.init(*a, is_training=True))
    init_opt = jax.pmap(self._optimizer.init)
    # Init uses the same RNG key on all hosts+devices to ensure everyone
    # computes the same initial state.
    init_rng = jl_utils.bcast_local_devices(self.init_rng)
    self._params, self._state = init_net(init_rng, inputs)
    self._opt_state = init_opt(self._params)
def test_bcast_local_devices_tree(self):
  num_devices = jax.local_device_count()
  tree = utils.bcast_local_devices({"ones": jnp.ones([]),
                                    "zeros": jnp.zeros([])})
  self.assertEqual(tree, {"ones": jnp.ones([num_devices]),
                          "zeros": jnp.zeros([num_devices])})
def _initialize_training(self, rng):
  # Initialize inputs.
  if self.config.emulated_workers > 0:
    per_device_workers, ragged = divmod(self.config.emulated_workers,
                                        jax.host_count())
    if ragged:
      raise ValueError('Number of emulated workers must be divisible by the '
                       'number of physical workers `jax.host_count()`.')
    self._repeat_batch = per_device_workers
  else:
    self._repeat_batch = 1
  self.supervised_train_input = jl_utils.py_prefetch(
      self._supervised_train_dataset)
  if self.config.training.extra_data_path is None:
    self.extra_train_input = None
  else:
    self.extra_train_input = jl_utils.py_prefetch(
        self._extra_train_dataset)
  self.normalize_fn = datasets.cifar10_normalize

  # Optimizer.
  self.optimizer = utils.sgd_momentum(self.config.training.learning_rate,
                                      momentum=.9, nesterov=True)

  # Initialize parameters.
  if self._params is None:
    logging.info('Initializing parameters randomly rather than restoring '
                 'from checkpoint.')

    # Create inputs to initialize the network state.
    images, _, _ = jax.pmap(self.concatenate)(
        next(self.supervised_train_input),
        next(self.extra_train_input)
        if self.extra_train_input is not None else None)
    images = jax.pmap(self.normalize_fn)(images)

    # Initialize weights and biases.
    init_net = jax.pmap(lambda *a: self.model.init(*a, is_training=True),
                        axis_name='i')
    init_rng = jl_utils.bcast_local_devices(rng)
    self._params, self._state = init_net(init_rng, images)

    # Setup weight averaging.
    if self.config.training.swa_decay > 0:
      self._avg_params = self._params
    else:
      self._avg_params = None

    # Initialize optimizer state.
    init_opt = jax.pmap(self.optimizer.init, axis_name='i')
    self._opt_state = init_opt(self._params)

  # Initialize step function.
  self.train_fn = jax.pmap(self._train_fn, axis_name='i',
                           donate_argnums=(0, 1, 2, 3))
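# `utils.sgd_momentum` above is repo-specific and not shown here; a plausible
# equivalent in plain optax (an assumption for illustration, not the repo's
# exact helper) would be:
import optax


def sgd_momentum_sketch(learning_rate, momentum=0.9, nesterov=True):
  # SGD with (Nesterov) momentum; `learning_rate` may be a float or an
  # optax schedule.
  return optax.sgd(learning_rate=learning_rate, momentum=momentum,
                   nesterov=nesterov)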
def restore_from_snapshot(self, snapshot_state: Mapping[Text, jnp.ndarray]):
  """Restores experiment state from a snapshot.

  Args:
    snapshot_state: A mapping from the names attributes are stored under in
      the snapshot to their stored values.
  """
  for attr_name, chk_name in self.CHECKPOINT_ATTRS.items():
    value = utils.bcast_local_devices(snapshot_state[chk_name])
    setattr(self, attr_name, value)
  for attr_name, chk_name in self.NON_BROADCAST_CHECKPOINT_ATTRS.items():
    setattr(self, attr_name, snapshot_state[chk_name])
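# For context, a hypothetical pair of attribute maps like the ones consumed
# above. `CHECKPOINT_ATTRS` holds device-replicated state (broadcast on
# restore), `NON_BROADCAST_CHECKPOINT_ATTRS` plain host-side state. The exact
# contents depend on the experiment; these entries are assumptions.
class MyExperiment:  # Sketch only.
  CHECKPOINT_ATTRS = {
      '_params': 'params',
      '_network_state': 'network_state',
      '_opt_state': 'opt_state',
  }
  NON_BROADCAST_CHECKPOINT_ATTRS = {
      '_python_step': 'python_step',
  }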
def _initialize_train(self):
  self._train_input = self._build_train_input()
  # Initialize net and EMA copy of net if no params available.
  if self._params is None:
    inputs = next(self._train_input)
    init_net = jax.pmap(lambda *a: self.net.init(*a, is_training=True),
                        axis_name='i')
    init_rng = jl_utils.bcast_local_devices(self.init_rng)
    self._params, self._state = init_net(init_rng, inputs)
    if self.config.use_ema:
      self._ema_params, self._ema_state = init_net(init_rng, inputs)
    num_params = hk.data_structures.tree_size(self._params)
    # Params are replicated, so divide by the device count to report the
    # per-copy parameter count.
    logging.info(f'Net parameters: {num_params / jax.local_device_count()}')
  self._make_opt()
def _initialize_train(self):
  self._train_input = self._build_train_input()
  if self._params is None:
    input_shape = (1, self.config.image_size, self.config.image_size, 3)
    inputs = jnp.ones(input_shape, jnp.float32)
    init_net = jax.pmap(lambda *a: self.net.init(*a, is_training=True),
                        axis_name='i')
    init_rng = jl_utils.bcast_local_devices(self.init_rng)
    self._params = init_net(init_rng, inputs)
    num_params = count_parameters(self._params)
    # Params are replicated, so divide by the device count to report the
    # per-copy parameter count.
    logging.info(f'Net params: {num_params / jax.local_device_count()}')
  self._make_opt()
  self._opt_state = self._opt.init(self._params)
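# `count_parameters` isn't defined in this snippet; a plausible
# implementation (an assumption for illustration) sums the sizes of all
# leaves in the parameter pytree:
import jax


def count_parameters(params):
  # Total number of scalar entries across all parameter arrays.
  return sum(p.size for p in jax.tree_util.tree_leaves(params))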
def _train_init(self):
  self.loss = hk.transform_with_state(self._loss)
  self._train_input = utils.py_prefetch(
      lambda: self._build_numpy_dataset_iterator('train', is_training=True))
  init_stacked_graphs = next(self._train_input)
  init_key = utils.bcast_local_devices(self.init_rng)
  p_init = jax.pmap(self.loss.init)
  self._params, self._network_state = p_init(
      init_key, **init_stacked_graphs._asdict())

  # Learning rate scheduling.
  lr_schedule = optax.warmup_cosine_decay_schedule(
      **self.config.optimizer.lr_schedule)

  self.optimizer = getattr(optax, self.config.optimizer.name)(
      learning_rate=lr_schedule, **self.config.optimizer.optimizer_kwargs)

  self._opt_state = jax.pmap(self.optimizer.init)(self._params)
  self.update_parameters = jax.pmap(self._update_parameters, axis_name='i')

  if self.config.ema:
    self._ema_params = self._params
    self._ema_network_state = self._network_state
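# For reference, the keyword arguments `optax.warmup_cosine_decay_schedule`
# expects from `config.optimizer.lr_schedule` above; the values here are
# illustrative only.
import optax

lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,       # LR at step 0.
    peak_value=1e-3,      # LR reached after the linear warmup.
    warmup_steps=1_000,   # Length of the linear warmup.
    decay_steps=100_000,  # Total steps; cosine decay runs from warmup to here.
    end_value=0.0,        # LR at the end of the decay.
)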
def write(attributes, broadcast=False):
  # Nested helper: closes over `snapshot_state` and `self` from the
  # enclosing restore function.
  for attr_name, chk_name in attributes.items():
    value = snapshot_state[chk_name]
    if broadcast:
      value = utils.bcast_local_devices(value)
    setattr(self, attr_name, value)
def evaluate(experiment_class, config, checkpointer, writer,
             jaxline_mode=None):
  """Main evaluation loop."""
  if jaxline_mode is None:
    jaxline_mode = FLAGS.jaxline_mode

  logging.info("Evaluating with config:\n%s", config)

  global_step = 0
  eval_rng = jax.random.PRNGKey(config.random_seed)
  experiment = _initialize_experiment(experiment_class, jaxline_mode,
                                      eval_rng, config.experiment_kwargs)

  if config.best_model_eval_metric and jax.host_id() == 0:
    # Initialize best state.
    best_state = checkpointer.get_experiment_state("best")
    best_state.best_eval_metric_value = float("-inf")
    best_state.best_model_eval_metric = config.best_model_eval_metric

  # Will evaluate the latest checkpoint in the directory.
  state = checkpointer.get_experiment_state("latest")
  state.global_step = global_step
  state.experiment_module = experiment
  state.train_step_rng = None

  eval_rng = jnp.broadcast_to(
      eval_rng, (jax.local_device_count(),) + eval_rng.shape)
  eval_rng = jax.pmap(functools.partial(
      utils.specialize_rng_host_device, axis_name="i",
      mode=config.random_mode_eval), axis_name="i")(eval_rng)

  if config.one_off_evaluate:
    checkpointer.restore("latest")
    global_step_devices = utils.bcast_local_devices(
        jnp.asarray(state.global_step))
    scalar_values = utils.evaluate_should_return_dict(experiment.evaluate)(
        global_step=global_step_devices, rng=eval_rng, writer=writer)
    if writer is not None:
      writer.write_scalars(state.global_step, scalar_values)
    logging.info("Evaluated specific checkpoint, exiting.")
    return

  old_checkpoint_path = None
  initial_weights_are_evaluated = False

  while True:
    checkpoint_path = checkpointer.restore_path("latest")

    if (checkpoint_path is None and config.eval_initial_weights
        and not initial_weights_are_evaluated):
      # Skip restoring a checkpoint and directly call evaluate if
      # `config.eval_initial_weights` is set, but don't do it more than once.
      initial_weights_are_evaluated = True
    else:
      if checkpoint_path in (None, old_checkpoint_path):
        logging.info("Checkpoint %s invalid or already evaluated, waiting.",
                     checkpoint_path)
        time.sleep(10)
        continue

      checkpointer.restore("latest")

    global_step_devices = utils.bcast_local_devices(
        jnp.asarray(state.global_step))
    scalar_values = utils.evaluate_should_return_dict(experiment.evaluate)(
        global_step=global_step_devices, rng=eval_rng, writer=writer)
    if writer is not None:
      writer.write_scalars(state.global_step, scalar_values)
    old_checkpoint_path = checkpoint_path

    # Decide whether to save a "best checkpoint".
    if config.best_model_eval_metric and jax.host_id() == 0:
      if config.best_model_eval_metric not in scalar_values:
        raise ValueError(f"config.best_model_eval_metric has been specified "
                         f"as {config.best_model_eval_metric}, but this key "
                         f"was not returned by the evaluate method.")
      current_eval_metric_value = scalar_values[config.best_model_eval_metric]
      old_eval_metric_value = best_state.best_eval_metric_value
      if old_eval_metric_value < current_eval_metric_value:
        logging.info("%s: %s > %s, saving new best checkpoint.",
                     config.best_model_eval_metric, current_eval_metric_value,
                     old_eval_metric_value)
        best_state = checkpointer.get_experiment_state("best")
        best_state.global_step = state.global_step
        best_state.experiment_module = experiment
        best_state.best_eval_metric_value = current_eval_metric_value
        best_state.train_step_rng = state.train_step_rng
        checkpointer.save("best")

    if state.global_step >= config.training_steps:
      logging.info("Last checkpoint (iteration %d) evaluated, exiting.",
                   state.global_step)
      break
def test_bcast_local_devices_empty_tree(self):
  self.assertIsNone(utils.bcast_local_devices(None))
  self.assertEqual(utils.bcast_local_devices({}), {})
def test_bcast_local_devices(self):
  self.assertEqual(utils.bcast_local_devices(jnp.zeros([])),
                   jnp.zeros([jax.local_device_count()]))
  self.assertEqual(utils.bcast_local_devices(jnp.ones([])),
                   jnp.ones([jax.local_device_count()]))
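# A minimal sketch of what `bcast_local_devices` could look like, consistent
# with the tests above (an assumption for illustration, not necessarily the
# library's exact implementation): replicate every leaf of a pytree onto all
# local devices, adding a leading axis of size `jax.local_device_count()`.
import jax
import jax.numpy as jnp


def bcast_local_devices_sketch(value):
  devices = jax.local_devices()

  def _replicate(x):
    # Stack one copy per device along a new leading axis; the result is a
    # sharded array with one shard per local device.
    x = jnp.asarray(x)
    return jax.device_put_sharded([x] * len(devices), devices)

  # tree_map leaves empty trees (None, {}) untouched, matching the tests.
  return jax.tree_util.tree_map(_replicate, value)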
def train(
    experiment_class,
    config,
    checkpointer: utils.Checkpointer,
    writer: Optional[utils.Writer],
    periodic_actions=(),
):
  """Main training loop."""
  logging.info("Training with config:\n%s", config)
  is_chief = jax.host_id() == 0

  rng = jax.random.PRNGKey(config.random_seed)
  with utils.log_activity("experiment init"):
    experiment = _initialize_experiment(experiment_class, "train", rng,
                                        config.experiment_kwargs)

  state = checkpointer.get_experiment_state("latest")
  state.global_step = 0
  state.experiment_module = experiment
  state.train_step_rng = utils.bcast_local_devices(rng)

  if checkpointer.can_be_restored("latest"):
    with utils.log_activity("checkpoint restore"):
      checkpointer.restore("latest")

  periodic_actions += (
      utils.PeriodicAction(
          _log_outputs,
          interval_type=config.logging_interval_type or config.interval_type,
          interval=config.log_tensors_interval),)

  if config.train_checkpoint_all_hosts or is_chief:
    if config.save_checkpoint_interval > 0:
      periodic_actions += (
          utils.PeriodicAction(
              lambda *_: checkpointer.save("latest"),
              interval_type=(config.checkpoint_interval_type
                             or config.interval_type),
              interval=config.save_checkpoint_interval,
              run_async=False),)  # run_async True would not be thread-safe.

  if is_chief:
    if writer is not None:

      def write_scalars(global_step: int, scalar_values):
        writer.write_scalars(global_step, scalar_values)

      periodic_actions += (
          utils.PeriodicAction(
              write_scalars,
              interval_type=(config.logging_interval_type
                             or config.interval_type),
              interval=config.log_train_data_interval,
              log_all_data=config.log_all_train_data),)

  for pa in periodic_actions:
    pa.update_time(time.time(), state.global_step)

  experiment.train_loop(config, state, periodic_actions, writer)

  if is_chief:
    with utils.log_activity("final checkpoint"):
      checkpointer.save("latest")

  # Join all async periodic actions that are unfinished.
  for pa in periodic_actions:
    pa.wait_to_finish()

  # We occasionally see errors when the final checkpoint is being written if
  # the other hosts exit. Here we force all hosts to participate in one final
  # collective so the non-master hosts cannot exit before the master writes
  # out the final checkpoint.
  utils.rendezvous()