def _sweep(self, global_step_id: int) -> None:
  """Deletes or preserves managed checkpoints."""
  if not self._max_to_keep:
    return
  kept_checkpoints = []
  maybe_delete_checkpoints = []
  for checkpoint in self._checkpoint_history.checkpoints:
    if (self._last_kept_checkpoint_datetime is not None and
        (from_timestamp(checkpoint.timestamp_sec) <=
         self._last_kept_checkpoint_datetime)):
      kept_checkpoints.append(checkpoint)
    else:
      maybe_delete_checkpoints.append(checkpoint)

  if not self.use_multi_host_flax:
    py_utils.sync_global_devices(
        'checkpoint_manager:begin_delete_checkpoints:'
        f'step_{global_step_id}')
  while len(maybe_delete_checkpoints) > self._max_to_keep:
    checkpoint = maybe_delete_checkpoints.pop(0)
    if (self._keep_interval_timedelta and
        (not kept_checkpoints or
         self._last_kept_checkpoint_datetime is None or
         ((from_timestamp(checkpoint.timestamp_sec) -
           self._last_kept_checkpoint_datetime) >=
          self._keep_interval_timedelta))):
      kept_checkpoints.append(checkpoint)
      self._last_kept_checkpoint_datetime = from_timestamp(
          checkpoint.timestamp_sec)
      continue
    self._delete_checkpoint(checkpoint)
  if not self.use_multi_host_flax:
    py_utils.sync_global_devices(
        'checkpoint_manager:end_delete_checkpoints:'
        f'step_{global_step_id}')

  self._checkpoint_history = self._create_checkpoint_history()
  for c in itertools.chain(kept_checkpoints, maybe_delete_checkpoints):
    logging.info('Keeping checkpoint: (step: %d, timestamp: %s)',
                 c.global_step_id, c.timestamp_sec)
    self._checkpoint_history.checkpoints.add().CopyFrom(c)
def _save_checkpoint_file(self, global_step_id: int, key: str = '') -> None:
  """Saves the checkpoint file with the latest checkpoint metadata.

  Note: This method overwrites the previous version of the checkpoint file.

  Args:
    global_step_id: The current global step.
    key: A key to add to the synchronization string.
  """
  py_utils.sync_global_devices(
      f'checkpoint_manager:begin_save_checkpoint_file:{key}_{global_step_id}'
  )
  if self.use_multi_host_flax or jax.process_index() == 0:
    with tf.io.gfile.GFile(self.checkpoint_filename, 'wb') as writer:
      writer.write(self._checkpoint_history.SerializeToString())
  py_utils.sync_global_devices(
      f'checkpoint_manager:end_save_checkpoint_file:{key}_{global_step_id}'
  )
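# --- Illustrative sketch (not part of the original source) ---
# Reading the metadata written by _save_checkpoint_file back from disk. The
# proto module name `checkpoints_pb2` and class `CheckpointHistory` are
# assumptions for illustration; the real module performs this through its own
# self._read_checkpoint_file() helper.
def _read_checkpoint_file_sketch(checkpoint_filename: str):
  history = checkpoints_pb2.CheckpointHistory()  # Assumed proto class.
  with tf.io.gfile.GFile(checkpoint_filename, 'rb') as reader:
    history.ParseFromString(reader.read())
  return history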
def _restore_checkpoint_gda(
    train_state: Optional[train_states.TrainState],
    checkpoint_dir: str,
    global_mesh: Optional[maps.Mesh],
    state_specs: Optional[train_states.TrainState],
    step: Optional[int] = None) -> train_states.TrainState:
  """Restores a checkpoint using JAX GDA deserialization mechanism."""
  if not tf.io.gfile.exists(checkpoint_dir) or not tf.io.gfile.listdir(
      checkpoint_dir):
    if train_state is not None and step is None:
      logging.info(
          'GDA checkpoint restore did not find checkpoint_dir %s; '
          'Return train_state passed in', checkpoint_dir)
      return train_state
    raise FileNotFoundError(
        f'No checkpoint found for restore in {checkpoint_dir}')

  if step is None:
    checkpoint_dirnames = tf.io.gfile.listdir(checkpoint_dir)
    tmp_checkpoint_dirnames = [
        x for x in checkpoint_dirnames if _is_tmp_checkpoint_asset(x)
    ]
    if tmp_checkpoint_dirnames:
      logging.warn('Found incompletely saved checkpoints %s; skipping them',
                   tmp_checkpoint_dirnames)
    sorted_dirnames = sorted(
        [x for x in checkpoint_dirnames if _is_checkpoint_asset(x)])
    if not sorted_dirnames:
      raise FileNotFoundError(
          f'No checkpoint found for restore in {checkpoint_dir}')
    latest_checkpoint_dirname = sorted_dirnames[-1]
    step = get_step_from_checkpoint_asset(latest_checkpoint_dirname)
    checkpoint_step_dir = _make_checkpoint_step_dir(checkpoint_dir, step)
    logging.info('Found latest checkpoint: %s', checkpoint_step_dir)
  else:
    checkpoint_step_dir = _make_checkpoint_step_dir(checkpoint_dir, step)
    if not tf.io.gfile.exists(checkpoint_step_dir) or not tf.io.gfile.listdir(
        checkpoint_step_dir):
      raise FileNotFoundError(
          f'No checkpoint found for restore in {checkpoint_step_dir}')

  logging.info('GDA checkpoint restore started...')
  if train_state is not None:
    leaves, treedef = jax.tree_util.tree_flatten(train_state)
    partition_spec_leaves, _ = jax.tree_util.tree_flatten(state_specs)
    nested_names = _extract_nested_prefix_names(train_state)
    global_shapes = jax.tree_map(lambda x: x.shape, leaves)
  else:
    partition_spec_leaves, treedef = jax.tree_util.tree_flatten(state_specs)
    nested_names = _extract_nested_prefix_names(state_specs)
    global_shapes = None

  flattened_nested_names, _ = jax.tree_util.tree_flatten(nested_names)

  ckpt_paths = [
      os.path.join(checkpoint_step_dir, x).rstrip('/')
      for x in flattened_nested_names
  ]
  tspecs = jax.tree_map(gda_serialization.get_tensorstore_spec, ckpt_paths)

  train_state_gda = gda_serialization.run_deserialization(
      [global_mesh] * len(tspecs),
      partition_spec_leaves,
      tspecs,
      global_shapes=global_shapes)
  restored_train_state = jax.tree_util.tree_unflatten(treedef, train_state_gda)
  # Barrier across all processes to ensure all restore finish.
  py_utils.sync_global_devices('Wait for checkpoint restore from '
                               f'{checkpoint_step_dir} to finish.')
  logging.info('Successfully restored GDA checkpoint at %s!',
               checkpoint_step_dir)
  return restored_train_state
def _save_checkpoint_gda(train_state: train_states.TrainState,
                         checkpoint_dir: str, overwrite: bool,
                         step: int) -> None:
  """Saves a checkpoint using JAX GDA serialization mechanism.

  Note that all JAX processes must call _save_checkpoint_gda in sync because
  each process may only have a slice of the global data.

  Args:
    train_state: A partitioned train_state that is a Pytree of
      GlobalDeviceArray.
    checkpoint_dir: Full path to parent checkpoint_dir.
    overwrite: Whether to allow overwriting an existing target directory.
    step: Step to save checkpoint for.
  """
  if not overwrite:
    # Does not contain directory path, only dirname is returned.
    checkpoint_dirnames = tf.io.gfile.listdir(checkpoint_dir)
    # Delete tmp directories if any.
    if jax.process_index() == 0:
      tmp_checkpoint_dirnames = [
          x for x in checkpoint_dirnames if _is_tmp_checkpoint_asset(x)
      ]
      if tmp_checkpoint_dirnames:
        logging.warn('Found incompletely saved checkpoints %s; deleting them',
                     tmp_checkpoint_dirnames)
        for x in tmp_checkpoint_dirnames:
          tf.io.gfile.rmtree(os.path.join(checkpoint_dir, x))
    # Note we must barrier across all processes after the tmp directory delete.
    py_utils.sync_global_devices('Wait for checkpoint tmp dir deletions to '
                                 'finish.')

    sorted_dirnames = sorted(
        [x for x in checkpoint_dirnames if _is_checkpoint_asset(x)])
    if sorted_dirnames:
      latest_checkpoint_dirname = sorted_dirnames[-1]
      previous_step = get_step_from_checkpoint_asset(latest_checkpoint_dirname)
      if previous_step >= step:
        logging.warning(
            'A more recent checkpoint `%d` has already been saved compared '
            'to the current timestep `%d`. Skip saving a checkpoint.',
            previous_step, step)
        return

  checkpoint_step_dir = _make_checkpoint_step_dir(checkpoint_dir, step)
  checkpoint_step_tmp_dir = _make_tmp_checkpoint_dir(
      checkpoint_dir, step, sync_timestamp=True)
  logging.info('Saving to a tmp checkpoint dir %s', checkpoint_step_tmp_dir)

  nested_names = _extract_nested_prefix_names(train_state)
  flattened_nested_names, _ = jax.tree_util.tree_flatten(nested_names)

  if jax.process_index() == 0:
    # Create the tmp parent dir.
    tf.io.gfile.makedirs(checkpoint_step_tmp_dir)

  with futures.ThreadPoolExecutor() as executor:
    ckpt_paths = list(
        executor.map(_mkdir_path, flattened_nested_names,
                     [checkpoint_step_tmp_dir] * len(flattened_nested_names)))
  py_utils.sync_global_devices(
      'Wait for checkpoint tmp dir and subdirs '
      f'creation {checkpoint_step_tmp_dir} to finish.')

  tspecs = jax.tree_map(gda_serialization.get_tensorstore_spec, ckpt_paths)

  leaves, _ = jax.tree_util.tree_flatten(train_state)

  gda_serialization.run_serialization(leaves, tspecs)

  # Note we must barrier across all processes before the directory rename.
  py_utils.sync_global_devices('Wait for checkpoint chunk writes to '
                               f'{checkpoint_step_tmp_dir} to finish.')

  if jax.process_index() == 0:
    # Rename temporary checkpoint directory to its final location.
    logging.info('Renaming %s to %s', checkpoint_step_tmp_dir,
                 checkpoint_step_dir)
    tf.io.gfile.rename(checkpoint_step_tmp_dir, checkpoint_step_dir)

  logging.info('Finished saving GDA checkpoint for step `%s` to `%s`.', step,
               checkpoint_step_dir)
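# --- Illustrative sketch (not part of the original source) ---
# A minimal save/restore round trip using the two GDA helpers above. The
# train_state, partition specs, and global mesh are placeholders that a caller
# would obtain from its own partitioning logic; only the call pattern is shown.
def _gda_checkpoint_round_trip_sketch(train_state, state_specs, global_mesh,
                                      checkpoint_dir: str, step: int):
  # All JAX processes must make this call together, since each process holds
  # only a shard of the global arrays.
  _save_checkpoint_gda(train_state, checkpoint_dir, overwrite=False, step=step)
  # Passing the existing train_state lets the restore fall back to it when the
  # directory is empty; passing None instead requires a checkpoint to exist.
  return _restore_checkpoint_gda(
      train_state, checkpoint_dir, global_mesh, state_specs, step=step)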
def _save_checkpoint_flax(train_state: train_states.TrainState,
                          checkpoint_dir: str, overwrite: bool,
                          unreplicate: bool, step: int,
                          use_multi_host: bool) -> None:
  """Saves a checkpoint using Flax serialization mechanism."""
  if not overwrite:
    previous_filename = latest_checkpoint(checkpoint_dir)
    if previous_filename:
      previous_step = int(previous_filename.rsplit('_', 1)[-1])
      if previous_step >= step:
        logging.warning(
            'A more recent checkpoint `%d` has already been saved compared '
            'to the current timestep `%d`. Skip saving a checkpoint.',
            previous_step, step)
        return

  # Assume data parallel-only model for now and retrieve train states
  # from the first replica only.
  def maybe_unreplicate(data):
    if unreplicate:
      return jax.device_get(jax_utils.unreplicate(data))
    else:
      return jax.device_get(data)

  # Extract/flatten data structure to store to disk. Flax requires a flattened
  # data structure to be passed to the checkpointer.
  flattened_state, pytree_state = jax.tree_flatten(
      maybe_unreplicate(train_state))
  checkpoint_target = {
      'flattened_state': flattened_state,
      # Saves a serialized version of the pytree structure to detect potential
      # mismatch caused by different versions of saver/restorer.
      'str_pytree_state': str(pytree_state),
  }

  prefix = CHECKPOINT_PREFIX
  if use_multi_host:
    # Notes:
    # 1. We currently don't broadcast / synchronize the timestamp across
    #    all the JAX processes.
    # 2. Flax checkpointing already saves the checkpoint file into a temporary
    #    file that is ultimately moved. We, hence, don't need to add a second
    #    layer of temporary files for single-host checkpointing.
    timestamp = _to_timestamp(datetime.datetime.utcnow())
    prefix = f'{TMP_PREFIX}{timestamp}.{prefix}'

  checkpoints.save_checkpoint(
      checkpoint_dir,
      checkpoint_target,
      step,
      prefix=prefix,
      keep=_MAX_CHECKPOINT_FLAX,
      overwrite=overwrite)

  if use_multi_host:
    py_utils.sync_global_devices(
        f'Renaming temporary checkpoint files at step {step} into their final '
        'destination.')
    tmp_filename = os.path.join(checkpoint_dir, f'{prefix}{step}')
    new_filename = os.path.join(checkpoint_dir, f'{CHECKPOINT_PREFIX}{step}')
    logging.debug('Renaming %s to %s', tmp_filename, new_filename)
    tf.io.gfile.rename(tmp_filename, new_filename)
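# --- Illustrative sketch (not part of the original source) ---
# Example call pattern for the Flax saver above in a multi-host job: each
# process writes to its own per-process subdirectory, mirroring the
# CHECKPOINT_MULTI_HOST_FLAX layout used elsewhere in this codebase, and
# unreplicate=False keeps the per-device values as-is.
def _save_flax_multi_host_sketch(train_state, checkpoint_dir: str, step: int):
  process_checkpoint_dir = os.path.join(checkpoint_dir,
                                        f'{jax.process_index():03d}')
  _save_checkpoint_flax(
      train_state,
      process_checkpoint_dir,
      overwrite=False,
      unreplicate=False,
      step=step,
      use_multi_host=True)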
def decode_once_spmd_model(
    task_p: InstantiableParams,
    input_p: Sequence[InstantiableParams],
    job_log_dir: Optional[str],
    checkpoint_type: CheckpointType,
    restore_checkpoint_dir: str,
    restore_checkpoint_step: Optional[int],
) -> None:
  """Runs the decoding once on the entire decoder datasets for SPMD model.

  Args:
    task_p: Params for the task that encapsulates an SPMD model.
    input_p: List of input params to be decoded.
    job_log_dir: Directory for the job logs.
    checkpoint_type: Type of model checkpointing method to use.
    restore_checkpoint_dir: The directory from which to restore checkpoint.
    restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
      try to restore from the latest checkpoint if any.
  """
  # TODO(bf-jax): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)

  if restore_checkpoint_dir:
    restore_checkpoint_parent_dir = restore_checkpoint_dir
    if checkpoint_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX:
      # TODO(zhouwk): add sanity check on number of subdirs and number of
      # processes and fail early if unequal.
      restore_checkpoint_dir = os.path.join(restore_checkpoint_dir,
                                            f'{jax.process_index():03d}')

  multi_host_checkpointing = bool(checkpoint_type in {
      CheckpointType.CHECKPOINT_MULTI_HOST_FLAX, CheckpointType.CHECKPOINT_GDA
  })

  sample_inputs = input_p[0].Instantiate().get_next()
  inputs_shape = tf.nest.map_structure(py_utils.get_global_input_shape_dtype,
                                       sample_inputs)

  model_p = task_p.model
  # TODO(b/198356509): This is a hack for now as we need to change some
  # annotations for mode='decode'. A future cl will move this logic
  # to a more generic model_p.update_sharding_params_v1(mode='decode').
  model_p.lm = model_p.lm.cls.set_sharding_params_v1(
      model_p.lm,
      replica_axis=model_p.lm.mesh_axis_names[0],
      data_axis=model_p.lm.mesh_axis_names[1],
      mdl_axis=model_p.lm.mesh_axis_names[2],
      device_ids_mesh=model_p.lm.device_mesh,
      mesh_axis_names=model_p.lm.mesh_axis_names,
      mode='decode')

  mesh_shape = model_p.device_mesh.shape
  device_mesh = mesh_utils.create_device_mesh(mesh_shape)
  logging.info('device_mesh: %s', device_mesh)
  jax_task = task_p.Instantiate()
  global_mesh = maps.Mesh(device_mesh, model_p.mesh_axis_names)
  with global_mesh:
    if restore_checkpoint_dir:
      model = jax_task.model
      model.instantiate_variable_configs()
      # Get the metadata from variables instead of actually instantiating them.
      partitioned_specs = jax_task.create_train_state_partition_specs(
          model.vars, is_eval=True)
      # Instantiate the TrainState directly from the checkpoint.
      partitioned_train_state = checkpoints.restore_checkpoint(
          None,
          restore_checkpoint_dir,
          global_mesh=global_mesh,
          checkpoint_type=checkpoint_type,
          state_specs=partitioned_specs,
          step=restore_checkpoint_step)
      if multi_host_checkpointing:
        py_utils.sync_global_devices(
            f'checkpointer:restored:{restore_checkpoint_parent_dir}')
      decode_step_fn, inputs_partition_spec = (
          trainer_lib.get_partitioned_spmd_model_decode_fn(
              jax_task, init_key, partitioned_train_state, partitioned_specs,
              inputs_shape))
    else:
      # When restore is not specified, randomly initialize the train_state.
      (partitioned_train_state, inputs_partition_spec, partitioned_specs,
       decode_step_fn) = trainer_lib.partition_spmd_model_decode(
           task_p, init_key, inputs_shape)
    logging.info('partitioned_train_state: %s',
                 jax.tree_map(lambda x: x.shape, partitioned_train_state))
    # We do not fold in jax.process_index in contrast to the pmap version and
    # use a single global key instead to rely on pjit to split for different
    # replicas.
    logging.info('root prng_key: %s', prng_key)
    prng_key, decode_key = jax.random.split(prng_key)
    logging.info('eval prng_key: %s', decode_key)
    spmd_decode_step_fn = functools.partial(decode_step_fn,
                                            partitioned_train_state.mdl_vars,
                                            decode_key,
                                            partitioned_train_state.step)

    num_steps = [
        -1 if p.reset_for_eval else p.eval_loop_num_batches for p in input_p
    ]
    inputs = [p.Instantiate() for p in input_p]
    decodes = [list() for _ in input_p]
    process_id = jax.process_index()

    for split, num_split_steps in enumerate(num_steps):
      logging.info('Start decoding on input %s', input_p[split].name)
      step_num = 0
      while num_split_steps < 0 or step_num < num_split_steps:
        step_num += 1
        try:
          batch = inputs[split].get_next()
        except (tf.errors.OutOfRangeError, StopIteration):
          break
        if jax.config.jax_parallel_functions_output_gda:
          batch = py_utils.create_gda(batch, inputs_shape, global_mesh,
                                      inputs_partition_spec)
        _, out = spmd_decode_step_fn(batch)
        # Output is fully replicated now, so it's ok to unreplicate it by
        # retrieving from device 0 only.
        out = py_utils.maybe_unreplicate_gda(out)
        global_batch_size = next(iter(out.values())).shape[0]
        logging.info('Finished decoding input batch %d with %d examples',
                     step_num, global_batch_size)
        # Manually shard the output for each jax process.
        # We require that all fields in the output are batch-major.
        if global_batch_size % jax.process_count() != 0:
          raise ValueError(
              f'Global batch size {global_batch_size} must be divisible by '
              f'the jax process count {jax.process_count()}')
        for k, v in out.items():
          if v.shape[0] != global_batch_size:
            raise ValueError('We require all fields in the decode output '
                             'to have batch size as the first dim, got shape='
                             f'{v.shape} with key={k}, expect batch size = '
                             f'{global_batch_size}')
        per_process_batch_size = global_batch_size // jax.process_count()

        def shard(x, per_process_batch_size=per_process_batch_size):
          return x[(process_id * per_process_batch_size):(
              (process_id + 1) * per_process_batch_size)]

        out = jax.tree_map(shard, out)
        _, processed = jax_task.model.process_decode_out(inputs[split], out)
        decodes[split].extend(processed)
        logging.info('Finished processing decoded input batch %d', step_num)

  basedir = os.path.join(job_log_dir, 'decoder_out')
  dirnames = _get_dir_names(input_p)
  filename = _get_filename(
      py_utils.maybe_unreplicate_gda(partitioned_train_state.step))
  for s in dirnames:
    dir_path = os.path.join(basedir, s)
    if not tf.io.gfile.exists(dir_path):
      tf.io.gfile.makedirs(dir_path)
  filenames = [os.path.join(basedir, s, filename) for s in dirnames]
  for split, output_file in enumerate(filenames):
    logging.info('Writing decoder output to %s with %d entries', output_file,
                 len(decodes[split]))
    io_utils.WriteKeyValuePairs(output_file, decodes[split])
def evaluate_spmd_model(
    task_p: InstantiableParams,
    eval_input_p: Sequence[InstantiableParams],
    job_log_dir: Optional[str],
    checkpoint_type: CheckpointType,
) -> None:
  """Runs the evaluation loop on the entire test dataset for SPMD model.

  Args:
    task_p: Params of the task encapsulating an SPMD model.
    eval_input_p: List of Params for the eval data pipelines.
    job_log_dir: Directory for the job logs.
    checkpoint_type: Type of model checkpointing method to use.
  """
  logging.info('Using SPMD sharding for model parallelism.')
  eval_input_pipelines = [input_p.Instantiate() for input_p in eval_input_p]
  # TODO(bf-jax): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)

  checkpoint_dir = os.path.join(job_log_dir, 'checkpoints')
  # Note that GDA checkpoint requires all processes to participate in
  # checkpointing, but it does not require a separate checkpoint_dir per
  # process.
  if checkpoint_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX:
    checkpoint_task_dir = os.path.join(checkpoint_dir,
                                       f'{jax.process_index():03d}')
  else:
    checkpoint_task_dir = checkpoint_dir
  multi_host_checkpointing = bool(checkpoint_type in {
      CheckpointType.CHECKPOINT_MULTI_HOST_FLAX, CheckpointType.CHECKPOINT_GDA
  })

  def get_shape_dtype(x):
    y = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return y

  # Do not use eval_input_pipelines[0] directly.
  sample_model_inputs = eval_input_p[0].Instantiate().get_next()
  inputs_shape = tf.nest.map_structure(get_shape_dtype, sample_model_inputs)

  model_p = task_p.model
  mesh_shape = model_p.device_mesh.shape
  device_mesh = mesh_utils.create_device_mesh(mesh_shape)
  logging.info('device_mesh: %s', device_mesh)
  global_mesh = maps.Mesh(device_mesh, model_p.mesh_axis_names)
  with global_mesh:
    partitioned_train_state, partitioned_specs, eval_inputs_partition_specs, _, eval_step, _ = (
        trainer_lib.partition_spmd_model(task_p, init_key, inputs_shape))
    partitioned_train_state = checkpoints.restore_checkpoint(
        partitioned_train_state,
        checkpoint_task_dir,
        global_mesh=global_mesh,
        checkpoint_type=checkpoint_type,
        state_specs=partitioned_specs)
    logging.info('partitioned_train_state: %s',
                 jax.tree_map(lambda x: x.shape, partitioned_train_state))
    if multi_host_checkpointing:
      py_utils.sync_global_devices(f'checkpointer:restored:{checkpoint_dir}')

    # We do not fold in jax.process_index in contrast to the pmap version and
    # use a single global key instead to rely on pjit to split for different
    # replicas.
    logging.info('root prng_key: %s', prng_key)
    prng_key, eval_key = jax.random.split(prng_key)
    logging.info('eval prng_key: %s', eval_key)

    logging.info('Evaluation loop starting...')
    summary_base_dir = os.path.join(job_log_dir, 'summaries')
    summary_eval_dirs = [
        os.path.join(summary_base_dir, f'eval_{split}')
        for split, _ in enumerate(eval_input_p)
    ]

    num_steps = [-1 if p.reset_for_eval else 1 for p in eval_input_p]
    last_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
    with contextlib.ExitStack() as exit_stack:
      eval_summary_writers = [
          exit_stack.enter_context(summary_utils.get_summary_writer(d))
          for d in summary_eval_dirs
      ]

      while True:
        step_i = int(jax.device_get(partitioned_train_state.step))
        eval_step_fn = functools.partial(eval_step,
                                         partitioned_train_state.mdl_vars,
                                         eval_key, partitioned_train_state.step)
        # Run the eval loop.
        model_utils.run_eval_loop_over_test_splits(
            num_steps,
            eval_step_fn,
            eval_summary_writers,
            step_i,
            eval_input_pipelines,
            eval_inputs_partition_specs,
            inputs_shape,
            global_mesh,
            reshard_inputs=False)
        # If the last checkpoint evaluated matches max train steps, exit.
        if last_checkpoint is not None:
          last_ckpt_step = checkpoints.get_step_from_checkpoint_asset(
              last_checkpoint)
          exceeded_ckpt = last_ckpt_step + task_p.train.save_interval_steps
          if exceeded_ckpt >= task_p.train.num_train_steps:
            break
        new_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
        while new_checkpoint == last_checkpoint:
          # Sleep for a minute.
          time.sleep(60)
          new_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
        # There must be a new checkpoint here.
        logging.info('Found new checkpoint: %s', new_checkpoint)
        partitioned_train_state = checkpoints.restore_checkpoint(
            partitioned_train_state,
            checkpoint_task_dir,
            global_mesh=global_mesh,
            checkpoint_type=checkpoint_type,
            state_specs=partitioned_specs)
        if multi_host_checkpointing:
          py_utils.sync_global_devices(
              f'checkpointer:restored:{checkpoint_dir}')
        last_checkpoint = new_checkpoint
def test_sync_global_devices(self):
  py_utils.sync_global_devices('sync')
def _init_checkpoint_history(self) -> None:
  """Initializes the checkpoint history and sets related class attributes."""
  self._last_saved_checkpoint_step: Optional[int] = None
  self._last_kept_checkpoint_datetime: Optional[datetime.datetime] = None
  if not tf.io.gfile.exists(self.checkpoint_filename):
    self._checkpoint_history = self._create_checkpoint_history()
    return

  # Read the previous checkpoints file and perform a sanity check.
  self._checkpoint_history = self._read_checkpoint_file()

  last_saved_timestamp = (
      self._checkpoint_history.checkpoints[-1].timestamp_sec)
  current_datetime = datetime.datetime.utcnow()
  if current_datetime < from_timestamp(last_saved_timestamp):
    # Time seems to have reversed itself.
    logging.warning(
        'datetime.datetime.utcnow() returned a value `%s` behind the last '
        'saved checkpoint timestamp.',
        from_timestamp(last_saved_timestamp) - current_datetime)

  # Add a few of the checkpoints to the `kept` list.
  kept_checkpoints = []
  if self._keep_interval_timedelta is None:
    maybe_delete_checkpoints = list(self._checkpoint_history.checkpoints)
  else:
    maybe_delete_checkpoints = []
    oldest_kept_timestamp = None
    for checkpoint in self._checkpoint_history.checkpoints:
      if (oldest_kept_timestamp is None or
          ((from_timestamp(oldest_kept_timestamp) -
            from_timestamp(checkpoint.timestamp_sec)) >=
           self._keep_interval_timedelta)):
        oldest_kept_timestamp = checkpoint.timestamp_sec
        kept_checkpoints.append(checkpoint)
        if self._last_kept_checkpoint_datetime is None:
          self._last_kept_checkpoint_datetime = (
              from_timestamp(checkpoint.timestamp_sec))
      else:
        maybe_delete_checkpoints.append(checkpoint)

  # Only keep at most `max_to_keep` non-kept checkpoints. Delete the old ones.
  if not self.use_multi_host_flax:
    py_utils.sync_global_devices(
        'checkpoint_manager:begin_delete_checkpoints:'
        f'init_{self._checkpoint_history.checkpoints[-1].global_step_id}')
  for i, checkpoint in enumerate(reversed(maybe_delete_checkpoints)):
    if self._max_to_keep is None or i < self._max_to_keep:
      kept_checkpoints.append(checkpoint)
    else:
      self._delete_checkpoint(checkpoint)
  if not self.use_multi_host_flax:
    py_utils.sync_global_devices(
        'checkpoint_manager:end_delete_checkpoints:'
        f'init_{self._checkpoint_history.checkpoints[-1].global_step_id}')

  # Finally create a new CheckpointHistory and save a new checkpoint file.
  kept_checkpoints = sorted(
      kept_checkpoints, key=lambda c: from_timestamp(c.timestamp_sec))
  latest_global_step = kept_checkpoints[-1].global_step_id
  self._last_saved_checkpoint_step = latest_global_step
  self._checkpoint_history = self._create_checkpoint_history()
  for c in kept_checkpoints:
    self._checkpoint_history.checkpoints.add().CopyFrom(c)
  self._save_checkpoint_file(latest_global_step)
def train_and_evaluate_spmd_model(
    task_p: InstantiableParams, train_input_p: InstantiableParams,
    job_log_dir: Optional[str],
    checkpoint_manager: checkpoint_managers.CheckpointManager,
    checkpoint_type: CheckpointType, restore_checkpoint_dir: Optional[str],
    restore_checkpoint_step: Optional[int],
    eval_input_p: Optional[Sequence[InstantiableParams]]) -> None:
  """Runs the training and evaluation loop.

  Args:
    task_p: Params for task encapsulating the SPMD model.
    train_input_p: Params for the train data pipeline.
    job_log_dir: Directory for the job logs.
    checkpoint_manager: A checkpoint manager controlling how often to save and
      delete checkpoints.
    checkpoint_type: The type of checkpoint to use.
    restore_checkpoint_dir: If set, the directory from which to restore
      checkpoint. If unset, use job_log_dir's `checkpoints` subdirectory
      instead.
    restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
      try to restore from the latest checkpoint if any.
    eval_input_p: Optional list of params for the eval input pipelines.
  """
  logging.info('Using SPMD sharding for model parallelism.')
  model_p = task_p.model

  if eval_input_p is not None:
    eval_input_pipelines = [input_p.Instantiate() for input_p in eval_input_p]
    # Do not mutate eval_input_pipelines itself. Instantiate a new one
    # to get sample input.
    sample_eval_model_inputs = eval_input_p[0].Instantiate().get_next()
    eval_test_inputs_shape = tf.nest.map_structure(
        py_utils.get_global_input_shape_dtype, sample_eval_model_inputs)
    eval_test_inputs_pspecs = trainer_lib.get_input_partition_specs(
        model_p.mesh_axis_names, eval_test_inputs_shape)

  # TODO(bf-jax): Retrieve the seeds from the model definition instead.
  prng_key = jax.random.PRNGKey(1234)
  prng_key, init_key = jax.random.split(prng_key)

  checkpoint_dir = _checkpoint_dir(job_log_dir)
  restore_checkpoint_dir = restore_checkpoint_dir or checkpoint_dir
  # Note that GDA checkpoint requires all processes to participate in
  # checkpointing, but it does not require a separate checkpoint_dir per
  # process.
  if checkpoint_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX:
    checkpoint_task_dir = os.path.join(checkpoint_dir,
                                       f'{jax.process_index():03d}')
    restore_checkpoint_task_dir = os.path.join(restore_checkpoint_dir,
                                               f'{jax.process_index():03d}')
  else:
    checkpoint_task_dir = checkpoint_dir
    restore_checkpoint_task_dir = restore_checkpoint_dir

  multi_host_checkpointing = bool(checkpoint_type in {
      CheckpointType.CHECKPOINT_MULTI_HOST_FLAX, CheckpointType.CHECKPOINT_GDA
  })

  if jax.process_index() == 0:
    tf.io.gfile.makedirs(checkpoint_dir)
  if multi_host_checkpointing:
    # Block all hosts until directory is ready.
    py_utils.sync_global_devices(f'checkpointer:makedirs:{checkpoint_dir}')

  logging.info('Retrieving model inputs for shape info.')
  train_input_for_shape = train_input_p.Instantiate()
  model_inputs_for_shape = train_input_for_shape.get_next()
  inputs_shape = tf.nest.map_structure(py_utils.get_global_input_shape_dtype,
                                       model_inputs_for_shape)

  mesh_shape = model_p.device_mesh.shape
  device_mesh = mesh_utils.create_device_mesh(mesh_shape)
  logging.info('device_mesh: %s', device_mesh)
  global_mesh = maps.Mesh(device_mesh, model_p.mesh_axis_names)
  with global_mesh:
    (partitioned_train_state, train_state_pspecs, inputs_pspecs, train_step,
     eval_step, total_num_params) = trainer_lib.partition_spmd_model(
         task_p, init_key, inputs_shape)
    partitioned_train_state = checkpoints.restore_checkpoint(
        partitioned_train_state,
        restore_checkpoint_task_dir,
        global_mesh=global_mesh,
        checkpoint_type=checkpoint_type,
        state_specs=train_state_pspecs,
        step=restore_checkpoint_step)
    logging.info(
        'partitioned_train_state shapes '
        '(global shape for GDA, host-local shape for non-GDA): %s',
        jax.tree_map(lambda x: x.shape, partitioned_train_state))
    if multi_host_checkpointing:
      py_utils.sync_global_devices(f'checkpointer:restored:{checkpoint_dir}')

    train_p = task_p.train
    initial_global_step = int(
        jax.device_get(
            py_utils.maybe_unreplicate_gda(partitioned_train_state.step)))
    logging.info('Model initial global_step=%d', initial_global_step)
    _update_latest_model_step(train_input_p, initial_global_step,
                              train_p.eval_interval_steps)
    train_input_pipeline = train_input_p.Instantiate()

    # We do not fold in jax.process_index in contrast to the pmap version and
    # use a single global key instead to rely on pjit to split for different
    # replicas.
    logging.info('root prng_key: %s', prng_key)
    prng_key, train_key, eval_key = jax.random.split(prng_key, 3)
    logging.info('train prng_key: %s', train_key)
    logging.info('eval prng_key: %s', eval_key)

    logging.info('Training loop starting...')
    summary_base_dir = os.path.join(job_log_dir, 'summaries')
    summary_train_dir = os.path.join(summary_base_dir, 'train')
    summary_eval_dir = os.path.join(summary_base_dir, 'eval_train')
    summary_writer = summary_utils.get_summary_writer
    if eval_input_p is not None:
      summary_eval_test_dirs = [
          os.path.join(summary_base_dir, f'eval_test_{split}')
          for split, _ in enumerate(eval_input_p)
      ]
      # We either run p.eval_loop_num_batches steps or one epoch (when
      # supported by a resettable input) per eval loop during training. When
      # p.reset_for_eval is set to True, we run the eval loop until
      # tf.errors.OutOfRangeError (or StopIteration) is raised, which can be
      # triggered either because the input pipeline has reached the end of the
      # input sequence, or because a pre-determined num_batches has been
      # reached.
      eval_num_steps = [
          -1 if p.reset_for_eval else p.eval_loop_num_batches
          for p in eval_input_p
      ]
    else:
      summary_eval_test_dirs = []

    with contextlib.ExitStack() as exit_stack:
      train_summary_writer = exit_stack.enter_context(
          summary_writer(summary_train_dir))
      eval_summary_writer = exit_stack.enter_context(
          summary_writer(summary_eval_dir))
      eval_test_summary_writers = [
          exit_stack.enter_context(summary_writer(d))
          for d in summary_eval_test_dirs
      ]

      # This only prints the view from the first host machine.
      summary_utils.write_model_structure(
          train_summary_writer, partitioned_train_state,
          is_vars_replicated=False)
      summary_utils.write_total_num_params(train_summary_writer,
                                           total_num_params)

      summary_last_time = time.time()
      summary_last_step = None

      step_i = int(
          jax.device_get(
              py_utils.maybe_unreplicate_gda(partitioned_train_state.step)))

      # Start the train loop. Make sure all processes are at the same step.
      py_utils.sync_global_devices(f'Start training loop from step: {step_i}')
      while True:
        logging.debug('step=`%d`: Beginning', step_i)
        if step_i >= train_p.num_train_steps:
          logging.info(
              'Training loop completed (step (`%d`) greater than '
              'num_train_steps (`%d`)).', step_i, train_p.num_train_steps)
          break

        if summary_last_step is None:
          summary_last_step = step_i - 1

        if checkpoint_manager.should_save(step_i):
          logging.info('Saving a ckpt at step: %d', step_i)
          if multi_host_checkpointing:
            py_utils.sync_global_devices(
                f'checkpointer:saving:{checkpoint_dir}:step-{step_i}')
          if multi_host_checkpointing or jax.process_index() == 0:
            checkpoints.save_checkpoint(
                partitioned_train_state,
                checkpoint_task_dir,
                checkpoint_type=checkpoint_type,
                state_specs=train_state_pspecs,
                unreplicate=False)
          checkpoint_manager.save_metadata(global_step_id=step_i)
          if multi_host_checkpointing:
            py_utils.sync_global_devices(
                f'checkpointer:saved:{checkpoint_dir}:step-{step_i}')

        # Get new model inputs.
        if step_i <= _N_STEPS_WARMUP_LOGGING:
          logging.info('step=`%d`: Retrieving model inputs.', step_i)
        logging.debug('  Retrieving inputs.')
        model_inputs = train_input_pipeline.get_next()

        if jax.config.jax_parallel_functions_output_gda:
          if step_i <= _N_STEPS_WARMUP_LOGGING:
            start = time.time()
          py_utils.assert_same_shape_and_dtype(
              inputs_shape,
              tf.nest.map_structure(py_utils.get_global_input_shape_dtype,
                                    model_inputs))
          model_inputs = py_utils.create_gda(model_inputs, inputs_shape,
                                             global_mesh, inputs_pspecs)
          if step_i <= _N_STEPS_WARMUP_LOGGING:
            logging.info('GDA train batch input creation time %s',
                         time.time() - start)

        logging.debug('  Retrieved inputs.')

        logging.debug('  Performing train_step().')
        with jax.profiler.StepTraceAnnotation('train', step_num=step_i):
          (partitioned_train_state, loss, metrics, per_example_out,
           summary_tensors) = train_step(partitioned_train_state, train_key,
                                         model_inputs)
        logging.debug('  Completed train_step().')

        logging.debug('  Writing summaries (attempt).')
        if summary_utils.write_summary_every_n_steps(
            partitioned_train_state,
            train_summary_writer,
            step_i,
            train_p.summary_interval_steps,
            loss,
            metrics,
            per_example_out,
            summary_tensors,
            train_p.norm_summary_interval_steps,
            summary_last_time,
            summary_last_step,
            unreplicate_mdl_vars=False,
            unreplicate_metrics=False):
          summary_last_time = time.time()
          summary_last_step = step_i
          step_i = int(
              py_utils.maybe_unreplicate_gda(partitioned_train_state.step))
        else:
          # Increment train step locally to avoid an explicit device sync.
          step_i += 1
        logging.debug('  Wrote summaries (attempted).')

        # Run eval at regular step interval.
        if step_i % train_p.eval_interval_steps == 0:
          logging.debug('  Starting eval_step().')
          logging.debug('  Retrieving eval model_inputs.')
          eval_inputs = train_input_pipeline.get_next()
          if jax.config.jax_parallel_functions_output_gda:
            eval_inputs = py_utils.create_gda(eval_inputs, inputs_shape,
                                              global_mesh, inputs_pspecs)
          logging.debug('  Retrieved eval model_inputs.')
          logging.debug('  Performing eval_step() runs on training split.')
          eval_step_fn = functools.partial(eval_step,
                                           partitioned_train_state.mdl_vars,
                                           eval_key,
                                           partitioned_train_state.step)
          loss, mean_metrics, summary_tensors = model_utils.run_eval_one_step(
              eval_inputs, eval_step_fn, reshard_inputs=False)
          logging.debug('  Completed eval_step() runs on training split.')
          logging.info('step=`%d`', step_i)
          logging.info('  eval loss: %s', loss)
          logging.info('  mean_metrics: %s', mean_metrics)
          logging.info('  summary_tensors: %s', summary_tensors)
          if step_i % train_p.summary_interval_steps == 0:
            logging.debug('  Writing eval summaries.')
            summary_utils.write_summary_entry(
                eval_summary_writer,
                step_i,
                loss,
                mean_metrics,
                summary_tensors,
                unreplicate_metrics=False)
            logging.debug('  Wrote eval summaries.')
          # If we have eval test then also evaluate on test.
          if eval_input_p is not None:
            logging.debug('  Performing eval_step() runs on test splits.')
            model_utils.run_eval_loop_over_test_splits(
                eval_num_steps,
                eval_step_fn,
                eval_test_summary_writers,
                step_i,
                eval_input_pipelines,
                eval_test_inputs_pspecs,
                eval_test_inputs_shape,
                global_mesh,
                reshard_inputs=False)
            logging.debug('  Completed eval_step() runs on test splits.')
        logging.debug('step=`%d`: End', step_i - 1)