def load_model(self):
    if self.network:
        logger.info('Attempting to reload model when model is loaded. Returning')
        return
    with self.lock:
        logger.info('Loading Model')
        start = timer()
        logger.info(f"JAX Devices: {jax.device_count()}")
        logger.info(f"JAX Runtime Initialized in {timer(start):.06} secs")
        mesh_shape = (jax.device_count() // self.params['cores_per_replica'],
                      self.params['cores_per_replica'])
        self.devices = np.array(jax.devices()).reshape(mesh_shape)
        self.total_batch = (self.params['per_replica_batch'] * jax.device_count()
                            // self.params['cores_per_replica'] * 8)
        maps.thread_resources.env = maps.ResourceEnv(
            maps.Mesh(self.devices, ('dp', 'mp')))
        network = CausalTransformer(self.params)
        logger.info(f'Loading Checkpoint')
        network.state = read_ckpt(network.state, "/app/model/",
                                  self.devices.shape[1])
        logger.info(f"GPTJ Network loaded in {timer(start):.06} secs. Total Batch Size: {self.total_batch}")
        del network.state["opt_state"]
        network.state = network.move_xmap(
            network.state, np.zeros(self.params['cores_per_replica']))
        self.network = network
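# Worked example of the sizes computed in load_model() above (assumed config,
# purely illustrative): with 8 TPU cores, cores_per_replica=8 and
# per_replica_batch=1:
#   mesh_shape  = (8 // 8, 8)      -> (1, 8), i.e. one data-parallel replica
#                                     spread over 8 model-parallel cores
#   total_batch = 1 * 8 // 8 * 8   -> 8 items padded/served per infer_batch call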
def background(self):
    logger.info(f'Init Background')
    maps.thread_resources.env = maps.ResourceEnv(
        maps.Mesh(self.devices, ('dp', 'mp')))
    while True:
        batch, qids = [], []
        while len(batch) <= self.total_batch:
            try:
                req = self.queue.get(block=False)
                logger.info(f'Got Queue Item: {req}')
                batch.append(req['item'])
                qids.append(req['qidx'])
            except Empty:
                if len(batch):
                    break
                else:
                    time.sleep(0.01)
        batch_size = len(batch)
        logger.info(f'Working on Batch: {batch_size} - {qids}')
        while len(batch) < self.total_batch:
            batch.append(self.placeholder_item)
        start = timer()
        results = self.infer_batch(batch)
        for res, qid in zip(results, qids):
            self.queue_ids[qid].put(res)
        logger.info(
            f'Completed Current Batch of {batch_size} Items in {timer(start):.2f} secs')
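# Hypothetical client-side sketch (not from the original source): background()
# drains `self.queue`, pads the batch up to `total_batch`, and routes each
# result back through a per-request queue keyed by `qidx`. A caller could
# therefore look roughly like this, assuming `self.queue_ids[qidx]` is a Queue
# created per request and `new_request_id()` is an assumed helper:
#
#   qidx = new_request_id()
#   self.queue_ids[qidx] = Queue()
#   self.queue.put({'item': prompt, 'qidx': qidx})
#   result = self.queue_ids[qidx].get()   # blocks until background() responds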
def _pjit(inp):
    if isinstance(inp, GlobalDeviceArray):
        if inp.is_fully_replicated:
            return inp.local_data(0).to_py()
        global_mesh = inp.mesh
        in_axis_resources = FROM_GDA
    else:
        # DA/SDA/np.array will be sharded based on global_mesh.local_mesh.
        # Shape of local_mesh will always be (1, local_device_count())
        devices = np.array(jax.devices()).reshape(jax.process_count(),
                                                  jax.local_device_count())
        global_mesh = maps.Mesh(devices, ('processes', 'local_devices'))
        in_axis_resources = P('processes')
        if inp.ndim == 0 or not tiled:
            inp = np.expand_dims(inp, axis=0)

    with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
        out = pjit(lambda x: x, in_axis_resources=in_axis_resources,
                   out_axis_resources=None)(inp)

    return out.local_data(0).to_py()
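# Hedged usage note for _pjit() above (behavior inferred from the code shown,
# not from additional documentation): a fully replicated GlobalDeviceArray
# short-circuits to a host-local numpy copy of shard 0; anything else is passed
# through an identity pjit under a (processes, local_devices) mesh so that
# every host ends up with the globally assembled value, e.g.
#
#   gathered = _pjit(np.asarray(local_values))   # np.ndarray in, np.ndarray out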
def test_pjit_inherits_effects(self):
    if jax.default_backend() not in {'gpu', 'tpu'}:
        raise unittest.SkipTest("pjit only supports GPU and TPU backends")

    def f(x):
        effect_p.bind(effect='foo')
        effect_p.bind(effect='bar')
        return x

    f = pjit.pjit(f, in_axis_resources=pjit.PartitionSpec('x'),
                  out_axis_resources=pjit.PartitionSpec('x'))
    with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
        with maps.Mesh(np.array(jax.devices()), ['x']):
            jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
def test_unordered_print_with_xmap(self):

    def f(x):
        debug_print("{}", x, ordered=False)

    f = maps.xmap(f, in_axes=['a'], out_axes=None, backend='cpu',
                  axis_resources={'a': 'dev'})
    with maps.Mesh(np.array(jax.devices(backend='cpu')), ['dev']):
        with capture_stdout() as output:
            f(jnp.arange(40))
            jax.effects_barrier()
        lines = [f"{i}\n" for i in range(40)]
        self._assertLinesEqual(output(), "".join(lines))
def test_pjit_inherits_effects(self):

    def f(x):
        effect_p.bind(effect='foo')
        effect_p.bind(effect='bar')
        return x

    f = pjit.pjit(f, in_axis_resources=pjit.PartitionSpec('x'),
                  out_axis_resources=pjit.PartitionSpec('x'))
    with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
        with maps.Mesh(np.array(jax.devices()), ['x']):
            jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
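# Minimal sketch of the Mesh-as-context pattern exercised by the tests above.
# This assumes the older jax.experimental.maps / jax.experimental.pjit API used
# throughout these snippets (in_axis_resources/out_axis_resources); newer JAX
# releases expose Mesh under jax.sharding instead.
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import maps
from jax.experimental.pjit import pjit, PartitionSpec as P

devices = np.array(jax.devices())          # 1-D device list -> 1-D mesh
double = pjit(lambda a: a * 2,
              in_axis_resources=P('x'),    # shard the input along mesh axis 'x'
              out_axis_resources=P('x'))
with maps.Mesh(devices, ('x',)):           # axis name matches the PartitionSpec
    out = double(jnp.arange(float(jax.device_count() * 4)))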
"seq": 2048, "cores_per_replica": 1, # only running on one GPU "per_replica_batch": 1, } per_replica_batch = params["per_replica_batch"] cores_per_replica = params["cores_per_replica"] seq = params["seq"] params["sampler"] = nucleaus_sample # here we "remove" the optimizer parameters from the model (as we don't need them for inference) params["optimizer"] = optax.scale(0) devices = np.array([jax.devices()[0]]).reshape((1, 1)) maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp'))) tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') network = CausalTransformer(params) start = time.time() # here we load a checkpoint which was written with 8 shards into 1 shard network.state = read_ckpt(network.state, "step_383500/", 8, shards_out=cores_per_replica) def infer(context, top_p=0.9, temp=1.0, gen_len=512):
def decode_once_spmd_model(
    task_p: InstantiableParams,
    input_p: Sequence[InstantiableParams],
    job_log_dir: Optional[str],
    checkpoint_type: CheckpointType,
    restore_checkpoint_dir: str,
    restore_checkpoint_step: Optional[int],
) -> None:
    """Runs the decoding once on the entire decoder datasets for SPMD model.

    Args:
        task_p: Params for the task that encapsulates an SPMD model.
        input_p: List of input params to be decoded.
        job_log_dir: Directory for the job logs.
        checkpoint_type: Type of model checkpointing method to use.
        restore_checkpoint_dir: The directory from which to restore checkpoint.
        restore_checkpoint_step: If set, the checkpoint step to restore. If
            unset, try to restore from the latest checkpoint if any.
    """
    # TODO(bf-jax): Retrieve the seeds from the model definition instead.
    prng_key = jax.random.PRNGKey(1234)
    prng_key, init_key = jax.random.split(prng_key)

    if restore_checkpoint_dir:
        restore_checkpoint_parent_dir = restore_checkpoint_dir
        if checkpoint_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX:
            # TODO(zhouwk): add sanity check on number of subdirs and number of
            # processes and fail early if unequal.
            restore_checkpoint_dir = os.path.join(restore_checkpoint_dir,
                                                  f'{jax.process_index():03d}')

    multi_host_checkpointing = bool(checkpoint_type in {
        CheckpointType.CHECKPOINT_MULTI_HOST_FLAX, CheckpointType.CHECKPOINT_GDA
    })

    sample_inputs = input_p[0].Instantiate().get_next()
    inputs_shape = tf.nest.map_structure(py_utils.get_global_input_shape_dtype,
                                         sample_inputs)

    model_p = task_p.model
    # TODO(b/198356509): This is a hack for now as we need to change some
    # annotations for mode='decode'. A future cl will move this logic
    # to a more generic model_p.update_sharding_params_v1(mode='decode').
    model_p.lm = model_p.lm.cls.set_sharding_params_v1(
        model_p.lm,
        replica_axis=model_p.lm.mesh_axis_names[0],
        data_axis=model_p.lm.mesh_axis_names[1],
        mdl_axis=model_p.lm.mesh_axis_names[2],
        device_ids_mesh=model_p.lm.device_mesh,
        mesh_axis_names=model_p.lm.mesh_axis_names,
        mode='decode')

    mesh_shape = model_p.device_mesh.shape
    device_mesh = mesh_utils.create_device_mesh(mesh_shape)
    logging.info('device_mesh: %s', device_mesh)
    jax_task = task_p.Instantiate()
    global_mesh = maps.Mesh(device_mesh, model_p.mesh_axis_names)
    with global_mesh:
        if restore_checkpoint_dir:
            model = jax_task.model
            model.instantiate_variable_configs()
            # Get the metadata from variables instead of actually instantiating
            # them.
            partitioned_specs = jax_task.create_train_state_partition_specs(
                model.vars, is_eval=True)
            # Instantiate the TrainState directly from the checkpoint.
            partitioned_train_state = checkpoints.restore_checkpoint(
                None,
                restore_checkpoint_dir,
                global_mesh=global_mesh,
                checkpoint_type=checkpoint_type,
                state_specs=partitioned_specs,
                step=restore_checkpoint_step)
            if multi_host_checkpointing:
                py_utils.sync_global_devices(
                    f'checkpointer:restored:{restore_checkpoint_parent_dir}')
            decode_step_fn, inputs_partition_spec = (
                trainer_lib.get_partitioned_spmd_model_decode_fn(
                    jax_task, init_key, partitioned_train_state,
                    partitioned_specs, inputs_shape))
        else:
            # When restore is not specified, randomly initialize the
            # train_state.
            (partitioned_train_state, inputs_partition_spec, partitioned_specs,
             decode_step_fn) = trainer_lib.partition_spmd_model_decode(
                 task_p, init_key, inputs_shape)
        logging.info('partitioned_train_state: %s',
                     jax.tree_map(lambda x: x.shape, partitioned_train_state))
        # We do not fold in jax.process_index in contrast to the pmap version
        # and use a single global key instead to rely on pjit to split for
        # different replicas.
        logging.info('root prng_key: %s', prng_key)
        prng_key, decode_key = jax.random.split(prng_key)
        logging.info('eval prng_key: %s', decode_key)
        spmd_decode_step_fn = functools.partial(
            decode_step_fn, partitioned_train_state.mdl_vars, decode_key,
            partitioned_train_state.step)

        num_steps = [
            -1 if p.reset_for_eval else p.eval_loop_num_batches for p in input_p
        ]
        inputs = [p.Instantiate() for p in input_p]
        decodes = [list() for _ in input_p]
        process_id = jax.process_index()

        for split, num_split_steps in enumerate(num_steps):
            logging.info('Start decoding on input %s', input_p[split].name)
            step_num = 0
            while num_split_steps < 0 or step_num < num_split_steps:
                step_num += 1
                try:
                    batch = inputs[split].get_next()
                except (tf.errors.OutOfRangeError, StopIteration):
                    break
                if jax.config.jax_parallel_functions_output_gda:
                    batch = py_utils.create_gda(batch, inputs_shape,
                                                global_mesh,
                                                inputs_partition_spec)
                _, out = spmd_decode_step_fn(batch)
                # Output is fully replicated now, so it's ok to unreplicate it
                # by retrieving from device 0 only.
                out = py_utils.maybe_unreplicate_gda(out)
                global_batch_size = next(iter(out.values())).shape[0]
                logging.info('Finished decoding input batch %d with %d examples',
                             step_num, global_batch_size)
                # Manually shard the output per each jax process.
                # We require that all fields in the output are batch major.
                if global_batch_size % jax.process_count() != 0:
                    raise ValueError(
                        f'Global batch size {global_batch_size} must divide '
                        f'jax process count {jax.process_count()}')
                for k, v in out.items():
                    if v.shape[0] != global_batch_size:
                        raise ValueError(
                            'We require that all fields in the decode output '
                            'to have batch size as the first dim, got shape='
                            f'{v.shape} with key={k}, expect batch size = '
                            f'{global_batch_size}')
                per_process_batch_size = global_batch_size // jax.process_count()

                def shard(x, per_process_batch_size=per_process_batch_size):
                    return x[(process_id * per_process_batch_size):(
                        (process_id + 1) * per_process_batch_size)]

                out = jax.tree_map(shard, out)
                _, processed = jax_task.model.process_decode_out(
                    inputs[split], out)
                decodes[split].extend(processed)
                logging.info('Finished processing decoded input batch %d',
                             step_num)

        basedir = os.path.join(job_log_dir, 'decoder_out')
        dirnames = _get_dir_names(input_p)
        filename = _get_filename(
            py_utils.maybe_unreplicate_gda(partitioned_train_state.step))
        for s in dirnames:
            dir_path = os.path.join(basedir, s)
            if not tf.io.gfile.exists(dir_path):
                tf.io.gfile.makedirs(dir_path)
        filenames = [os.path.join(basedir, s, filename) for s in dirnames]
        for split, output_file in enumerate(filenames):
            logging.info('Writing decoder output to %s with %d entries',
                         output_file, len(decodes[split]))
            io_utils.WriteKeyValuePairs(output_file, decodes[split])
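# Tiny worked example of the per-process output sharding in the decode loop
# above (assumed sizes, purely illustrative): with 4 JAX processes and a global
# batch of 8, each process keeps its contiguous slice of every batch-major field.
import numpy as np

global_batch = np.arange(8)                  # stands in for one decoded field
process_count = 4
per_process_batch_size = 8 // process_count  # -> 2
for process_id in range(process_count):
    shard = global_batch[process_id * per_process_batch_size:
                         (process_id + 1) * per_process_batch_size]
    # process 0 -> [0 1], process 1 -> [2 3], process 2 -> [4 5], process 3 -> [6 7]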
def evaluate_spmd_model(
    task_p: InstantiableParams,
    eval_input_p: Sequence[InstantiableParams],
    job_log_dir: Optional[str],
    checkpoint_type: CheckpointType,
) -> None:
    """Runs the evaluation loop on the entire test dataset for SPMD model.

    Args:
        task_p: Params of the task encapsulating an SPMD model.
        eval_input_p: List of Params for the eval data pipelines.
        job_log_dir: Directory for the job logs.
        checkpoint_type: Type of model checkpointing method to use.
    """
    logging.info('Using SPMD sharding for model parallelism.')
    eval_input_pipelines = [input_p.Instantiate() for input_p in eval_input_p]
    # TODO(bf-jax): Retrieve the seeds from the model definition instead.
    prng_key = jax.random.PRNGKey(1234)
    prng_key, init_key = jax.random.split(prng_key)
    checkpoint_dir = os.path.join(job_log_dir, 'checkpoints')
    # Note that GDA checkpoint requires all processes to participate in
    # checkpointing but it does not require a separate checkpoint_dir per
    # process.
    if checkpoint_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX:
        checkpoint_task_dir = os.path.join(checkpoint_dir,
                                           f'{jax.process_index():03d}')
    else:
        checkpoint_task_dir = checkpoint_dir
    multi_host_checkpointing = bool(checkpoint_type in {
        CheckpointType.CHECKPOINT_MULTI_HOST_FLAX, CheckpointType.CHECKPOINT_GDA
    })

    def get_shape_dtype(x):
        y = jax.ShapeDtypeStruct(x.shape, x.dtype)
        return y

    # Do not use eval_input_pipelines[0] directly.
    sample_model_inputs = eval_input_p[0].Instantiate().get_next()
    inputs_shape = tf.nest.map_structure(get_shape_dtype, sample_model_inputs)

    model_p = task_p.model
    mesh_shape = model_p.device_mesh.shape
    device_mesh = mesh_utils.create_device_mesh(mesh_shape)
    logging.info('device_mesh: %s', device_mesh)
    global_mesh = maps.Mesh(device_mesh, model_p.mesh_axis_names)
    with global_mesh:
        (partitioned_train_state, partitioned_specs,
         eval_inputs_partition_specs, _, eval_step, _) = (
             trainer_lib.partition_spmd_model(task_p, init_key, inputs_shape))
        partitioned_train_state = checkpoints.restore_checkpoint(
            partitioned_train_state,
            checkpoint_task_dir,
            global_mesh=global_mesh,
            checkpoint_type=checkpoint_type,
            state_specs=partitioned_specs)
        logging.info('partitioned_train_state: %s',
                     jax.tree_map(lambda x: x.shape, partitioned_train_state))
        if multi_host_checkpointing:
            py_utils.sync_global_devices(
                f'checkpointer:restored:{checkpoint_dir}')
        # We do not fold in jax.process_index in contrast to the pmap version
        # and use a single global key instead to rely on pjit to split for
        # different replicas.
        logging.info('root prng_key: %s', prng_key)
        prng_key, eval_key = jax.random.split(prng_key)
        logging.info('eval prng_key: %s', eval_key)
        logging.info('Evaluation loop starting...')
        summary_base_dir = os.path.join(job_log_dir, 'summaries')
        summary_eval_dirs = [
            os.path.join(summary_base_dir, f'eval_{split}')
            for split, _ in enumerate(eval_input_p)
        ]

        num_steps = [-1 if p.reset_for_eval else 1 for p in eval_input_p]
        last_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
        with contextlib.ExitStack() as exit_stack:
            eval_summary_writers = [
                exit_stack.enter_context(summary_utils.get_summary_writer(d))
                for d in summary_eval_dirs
            ]
            while True:
                step_i = int(jax.device_get(partitioned_train_state.step))
                eval_step_fn = functools.partial(
                    eval_step, partitioned_train_state.mdl_vars, eval_key,
                    partitioned_train_state.step)
                # Run the eval loop.
                model_utils.run_eval_loop_over_test_splits(
                    num_steps,
                    eval_step_fn,
                    eval_summary_writers,
                    step_i,
                    eval_input_pipelines,
                    eval_inputs_partition_specs,
                    inputs_shape,
                    global_mesh,
                    reshard_inputs=False)
                # If the last checkpoint evaluated matches max train steps,
                # exit.
                if last_checkpoint is not None:
                    last_ckpt_step = checkpoints.get_step_from_checkpoint_asset(
                        last_checkpoint)
                    exceeded_ckpt = last_ckpt_step + task_p.train.save_interval_steps
                    if exceeded_ckpt >= task_p.train.num_train_steps:
                        break
                new_checkpoint = checkpoints.latest_checkpoint(checkpoint_dir)
                while new_checkpoint == last_checkpoint:
                    # Sleep for a minute.
                    time.sleep(60)
                    new_checkpoint = checkpoints.latest_checkpoint(
                        checkpoint_dir)
                # There must be a new checkpoint here.
                logging.info('Found new checkpoint: %s', new_checkpoint)
                partitioned_train_state = checkpoints.restore_checkpoint(
                    partitioned_train_state,
                    checkpoint_task_dir,
                    global_mesh=global_mesh,
                    checkpoint_type=checkpoint_type,
                    state_specs=partitioned_specs)
                if multi_host_checkpointing:
                    py_utils.sync_global_devices(
                        f'checkpointer:restored:{checkpoint_dir}')
                last_checkpoint = new_checkpoint
def train_and_evaluate_spmd_model(
    task_p: InstantiableParams,
    train_input_p: InstantiableParams,
    job_log_dir: Optional[str],
    checkpoint_manager: checkpoint_managers.CheckpointManager,
    checkpoint_type: CheckpointType,
    restore_checkpoint_dir: Optional[str],
    restore_checkpoint_step: Optional[int],
    eval_input_p: Optional[Sequence[InstantiableParams]]) -> None:
    """Runs the training and evaluation loop.

    Args:
        task_p: Params for task encapsulating the SPMD model.
        train_input_p: Params for the train data pipeline.
        job_log_dir: Directory for the job logs.
        checkpoint_manager: A checkpoint manager controlling how often to save
            and delete checkpoints.
        checkpoint_type: The type of checkpoint to use.
        restore_checkpoint_dir: If set, the directory from which to restore
            checkpoint. If unset, use job_log_dir's `checkpoints` subdirectory
            instead.
        restore_checkpoint_step: If set, the checkpoint step to restore. If
            unset, try to restore from the latest checkpoint if any.
        eval_input_p: Optional list of params for the eval input pipelines.
    """
    logging.info('Using SPMD sharding for model parallelism.')
    model_p = task_p.model

    if eval_input_p is not None:
        eval_input_pipelines = [
            input_p.Instantiate() for input_p in eval_input_p
        ]
        # Do not mutate eval_input_pipelines itself. Instantiate a new one
        # to get sample input.
        sample_eval_model_inputs = eval_input_p[0].Instantiate().get_next()
        eval_test_inputs_shape = tf.nest.map_structure(
            py_utils.get_global_input_shape_dtype, sample_eval_model_inputs)
        eval_test_inputs_pspecs = trainer_lib.get_input_partition_specs(
            model_p.mesh_axis_names, eval_test_inputs_shape)

    # TODO(bf-jax): Retrieve the seeds from the model definition instead.
    prng_key = jax.random.PRNGKey(1234)
    prng_key, init_key = jax.random.split(prng_key)

    checkpoint_dir = _checkpoint_dir(job_log_dir)
    restore_checkpoint_dir = restore_checkpoint_dir or checkpoint_dir
    # Note that GDA checkpoint requires all processes to participate in
    # checkpointing but it does not require a separate checkpoint_dir per
    # process.
    if checkpoint_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX:
        checkpoint_task_dir = os.path.join(checkpoint_dir,
                                           f'{jax.process_index():03d}')
        restore_checkpoint_task_dir = os.path.join(
            restore_checkpoint_dir, f'{jax.process_index():03d}')
    else:
        checkpoint_task_dir = checkpoint_dir
        restore_checkpoint_task_dir = restore_checkpoint_dir

    multi_host_checkpointing = bool(checkpoint_type in {
        CheckpointType.CHECKPOINT_MULTI_HOST_FLAX, CheckpointType.CHECKPOINT_GDA
    })

    if jax.process_index() == 0:
        tf.io.gfile.makedirs(checkpoint_dir)
    if multi_host_checkpointing:
        # Block all hosts until directory is ready.
        py_utils.sync_global_devices(f'checkpointer:makedirs:{checkpoint_dir}')

    logging.info('Retrieving model inputs for shape info.')
    train_input_for_shape = train_input_p.Instantiate()
    model_inputs_for_shape = train_input_for_shape.get_next()
    inputs_shape = tf.nest.map_structure(py_utils.get_global_input_shape_dtype,
                                         model_inputs_for_shape)

    mesh_shape = model_p.device_mesh.shape
    device_mesh = mesh_utils.create_device_mesh(mesh_shape)
    logging.info('device_mesh: %s', device_mesh)
    global_mesh = maps.Mesh(device_mesh, model_p.mesh_axis_names)
    with global_mesh:
        (partitioned_train_state, train_state_pspecs, inputs_pspecs, train_step,
         eval_step, total_num_params) = trainer_lib.partition_spmd_model(
             task_p, init_key, inputs_shape)
        partitioned_train_state = checkpoints.restore_checkpoint(
            partitioned_train_state,
            restore_checkpoint_task_dir,
            global_mesh=global_mesh,
            checkpoint_type=checkpoint_type,
            state_specs=train_state_pspecs,
            step=restore_checkpoint_step)
        logging.info(
            'partitioned_train_state shapes '
            '(global shape for GDA, host-local shape for non-GDA): %s',
            jax.tree_map(lambda x: x.shape, partitioned_train_state))
        if multi_host_checkpointing:
            py_utils.sync_global_devices(
                f'checkpointer:restored:{checkpoint_dir}')

        train_p = task_p.train
        initial_global_step = int(
            jax.device_get(
                py_utils.maybe_unreplicate_gda(partitioned_train_state.step)))
        logging.info('Model initial global_step=%d', initial_global_step)
        _update_latest_model_step(train_input_p, initial_global_step,
                                  train_p.eval_interval_steps)
        train_input_pipeline = train_input_p.Instantiate()

        # We do not fold in jax.process_index in contrast to the pmap version
        # and use a single global key instead to rely on pjit to split for
        # different replicas.
        logging.info('root prng_key: %s', prng_key)
        prng_key, train_key, eval_key = jax.random.split(prng_key, 3)
        logging.info('train prng_key: %s', train_key)
        logging.info('eval prng_key: %s', eval_key)

        logging.info('Training loop starting...')
        summary_base_dir = os.path.join(job_log_dir, 'summaries')
        summary_train_dir = os.path.join(summary_base_dir, 'train')
        summary_eval_dir = os.path.join(summary_base_dir, 'eval_train')
        summary_writer = summary_utils.get_summary_writer
        if eval_input_p is not None:
            summary_eval_test_dirs = [
                os.path.join(summary_base_dir, f'eval_test_{split}')
                for split, _ in enumerate(eval_input_p)
            ]
            # We either run p.eval_loop_num_batches steps or one epoch (when
            # supported by a resettable input) per eval loop during training.
            # When p.reset_for_eval is set to True, we run the eval loop until
            # tf.errors.OutOfRangeError (or StopIteration) is raised, which can
            # be triggered either because the input pipeline has reached the
            # end of the input sequence, or because a pre-determined
            # num_batches has been reached.
            eval_num_steps = [
                -1 if p.reset_for_eval else p.eval_loop_num_batches
                for p in eval_input_p
            ]
        else:
            summary_eval_test_dirs = []

        with contextlib.ExitStack() as exit_stack:
            train_summary_writer = exit_stack.enter_context(
                summary_writer(summary_train_dir))
            eval_summary_writer = exit_stack.enter_context(
                summary_writer(summary_eval_dir))
            eval_test_summary_writers = [
                exit_stack.enter_context(summary_writer(d))
                for d in summary_eval_test_dirs
            ]

            # This only prints the view from the first host machine.
            summary_utils.write_model_structure(
                train_summary_writer,
                partitioned_train_state,
                is_vars_replicated=False)
            summary_utils.write_total_num_params(train_summary_writer,
                                                 total_num_params)

            summary_last_time = time.time()
            summary_last_step = None

            step_i = int(
                jax.device_get(
                    py_utils.maybe_unreplicate_gda(
                        partitioned_train_state.step)))
            # Start the train loop. Make sure all hosts start at the same step.
            py_utils.sync_global_devices(
                f'Start training loop from step: {step_i}')
            while True:
                logging.debug('step=`%d`: Beginning', step_i)
                if step_i >= train_p.num_train_steps:
                    logging.info(
                        'Training loop completed (step (`%d`) greater than '
                        'num_train_steps (`%d`)).', step_i,
                        train_p.num_train_steps)
                    break

                if summary_last_step is None:
                    summary_last_step = step_i - 1

                if checkpoint_manager.should_save(step_i):
                    logging.info('Saving a ckpt at step: %d', step_i)
                    if multi_host_checkpointing:
                        py_utils.sync_global_devices(
                            f'checkpointer:saving:{checkpoint_dir}:step-{step_i}')
                    if multi_host_checkpointing or jax.process_index() == 0:
                        checkpoints.save_checkpoint(
                            partitioned_train_state,
                            checkpoint_task_dir,
                            checkpoint_type=checkpoint_type,
                            state_specs=train_state_pspecs,
                            unreplicate=False)
                    checkpoint_manager.save_metadata(global_step_id=step_i)
                    if multi_host_checkpointing:
                        py_utils.sync_global_devices(
                            f'checkpointer:saved:{checkpoint_dir}:step-{step_i}')

                # Get new model inputs.
                if step_i <= _N_STEPS_WARMUP_LOGGING:
                    logging.info('step=`%d`: Retrieving model inputs.', step_i)
                logging.debug(' Retrieving inputs.')
                model_inputs = train_input_pipeline.get_next()

                if jax.config.jax_parallel_functions_output_gda:
                    if step_i <= _N_STEPS_WARMUP_LOGGING:
                        start = time.time()
                    py_utils.assert_same_shape_and_dtype(
                        inputs_shape,
                        tf.nest.map_structure(
                            py_utils.get_global_input_shape_dtype,
                            model_inputs))
                    model_inputs = py_utils.create_gda(model_inputs,
                                                       inputs_shape,
                                                       global_mesh,
                                                       inputs_pspecs)
                    if step_i <= _N_STEPS_WARMUP_LOGGING:
                        logging.info('GDA train batch input creation time %s',
                                     time.time() - start)
                logging.debug(' Retrieved inputs.')

                logging.debug(' Performing train_step().')
                with jax.profiler.StepTraceAnnotation('train', step_num=step_i):
                    (partitioned_train_state, loss, metrics, per_example_out,
                     summary_tensors) = train_step(partitioned_train_state,
                                                   train_key, model_inputs)
                logging.debug(' Completed train_step().')

                logging.debug(' Writing summaries (attempt).')
                if summary_utils.write_summary_every_n_steps(
                        partitioned_train_state,
                        train_summary_writer,
                        step_i,
                        train_p.summary_interval_steps,
                        loss,
                        metrics,
                        per_example_out,
                        summary_tensors,
                        train_p.norm_summary_interval_steps,
                        summary_last_time,
                        summary_last_step,
                        unreplicate_mdl_vars=False,
                        unreplicate_metrics=False):
                    summary_last_time = time.time()
                    summary_last_step = step_i
                    step_i = int(
                        py_utils.maybe_unreplicate_gda(
                            partitioned_train_state.step))
                else:
                    # Increment train step locally to avoid an explicit device
                    # sync.
                    step_i += 1
                logging.debug(' Wrote summaries (attempted).')

                # Run eval at regular step interval.
                if step_i % train_p.eval_interval_steps == 0:
                    logging.debug(' Starting eval_step().')
                    logging.debug(' Retrieving eval model_inputs.')
                    eval_inputs = train_input_pipeline.get_next()
                    if jax.config.jax_parallel_functions_output_gda:
                        eval_inputs = py_utils.create_gda(
                            eval_inputs, inputs_shape, global_mesh,
                            inputs_pspecs)
                    logging.debug(' Retrieved eval model_inputs.')
                    logging.debug(
                        ' Performing eval_step() runs on training split.')
                    eval_step_fn = functools.partial(
                        eval_step, partitioned_train_state.mdl_vars, eval_key,
                        partitioned_train_state.step)
                    loss, mean_metrics, summary_tensors = (
                        model_utils.run_eval_one_step(
                            eval_inputs, eval_step_fn, reshard_inputs=False))
                    logging.debug(
                        ' Completed eval_step() runs on training split.')
                    logging.info('step=`%d`', step_i)
                    logging.info(' eval loss: %s', loss)
                    logging.info(' mean_metrics: %s', mean_metrics)
                    logging.info(' summary_tensors: %s', summary_tensors)
                    if step_i % train_p.summary_interval_steps == 0:
                        logging.debug(' Writing eval summaries.')
                        summary_utils.write_summary_entry(
                            eval_summary_writer,
                            step_i,
                            loss,
                            mean_metrics,
                            summary_tensors,
                            unreplicate_metrics=False)
                        logging.debug(' Wrote eval summaries.')
                    # If we have eval test then also evaluate on test.
                    if eval_input_p is not None:
                        logging.debug(
                            ' Performing eval_step() runs on test splits.')
                        model_utils.run_eval_loop_over_test_splits(
                            eval_num_steps,
                            eval_step_fn,
                            eval_test_summary_writers,
                            step_i,
                            eval_input_pipelines,
                            eval_test_inputs_pspecs,
                            eval_test_inputs_shape,
                            global_mesh,
                            reshard_inputs=False)
                        logging.debug(
                            ' Completed eval_step() runs on test splits.')
                logging.debug('step=`%d`: End', step_i - 1)