def main(argv):
  tf.config.experimental.set_visible_devices([], 'GPU')
  del argv
  logging.info('*** Starting experiment')
  gin_configs = FLAGS.gin_configs

  logging.info('*** Loading Gin configs from: %s', str(gin_configs))
  gin.parse_config_files_and_bindings(
      config_files=gin_configs,
      bindings=FLAGS.gin_bindings,
      skip_unknown=True)

  # Load configurations.
  exp_config = configs.ExperimentConfig()
  model_config = configs.ModelConfig(use_stratified_sampling=False)
  train_config = configs.TrainConfig()
  eval_config = configs.EvalConfig()

  # Get directory information.
  exp_dir = gpath.GPath(FLAGS.base_folder)
  if exp_config.subname:
    exp_dir = exp_dir / exp_config.subname
  logging.info('\texp_dir = %s', exp_dir)
  if not exp_dir.exists():
    exp_dir.mkdir(parents=True, exist_ok=True)

  summary_dir = exp_dir / 'summaries' / 'eval'
  logging.info('\tsummary_dir = %s', summary_dir)
  if not summary_dir.exists():
    summary_dir.mkdir(parents=True, exist_ok=True)

  renders_dir = exp_dir / 'renders'
  logging.info('\trenders_dir = %s', renders_dir)
  if not renders_dir.exists():
    renders_dir.mkdir(parents=True, exist_ok=True)

  checkpoint_dir = exp_dir / 'checkpoints'
  logging.info('\tcheckpoint_dir = %s', checkpoint_dir)

  logging.info('Starting process %d. There are %d processes.',
               jax.process_index(), jax.process_count())
  logging.info('Found %d accelerator devices: %s.', jax.local_device_count(),
               str(jax.local_devices()))
  logging.info('Found %d total devices: %s.', jax.device_count(),
               str(jax.devices()))

  rng = random.PRNGKey(20200823)

  devices_to_use = jax.local_devices()
  n_devices = (
      len(devices_to_use) if devices_to_use else jax.local_device_count())

  datasource_spec = exp_config.datasource_spec
  if datasource_spec is None:
    datasource_spec = {
        'type': exp_config.datasource_type,
        'data_dir': FLAGS.data_dir,
    }
  logging.info('Creating datasource: %s', datasource_spec)
  datasource = datasets.from_config(
      datasource_spec,
      image_scale=exp_config.image_scale,
      use_appearance_id=model_config.use_appearance_metadata,
      use_camera_id=model_config.use_camera_metadata,
      use_warp_id=model_config.use_warp,
      use_time=model_config.warp_metadata_encoder_type == 'time',
      random_seed=exp_config.random_seed,
      **exp_config.datasource_kwargs)

  # Get training IDs to evaluate.
  train_eval_ids = utils.strided_subset(
      datasource.train_ids, eval_config.num_train_eval)
  train_eval_iter = datasource.create_iterator(train_eval_ids, batch_size=0)
  val_eval_ids = utils.strided_subset(
      datasource.val_ids, eval_config.num_val_eval)
  val_eval_iter = datasource.create_iterator(val_eval_ids, batch_size=0)

  test_cameras = datasource.load_test_cameras(count=eval_config.num_test_eval)
  if test_cameras:
    test_dataset = datasource.create_cameras_dataset(test_cameras)
    test_eval_ids = [f'{x:03d}' for x in range(len(test_cameras))]
    test_eval_iter = datasets.iterator_from_dataset(test_dataset, batch_size=0)
  else:
    test_eval_ids = None
    test_eval_iter = None

  rng, key = random.split(rng)
  params = {}
  model, params['model'] = models.construct_nerf(
      key,
      model_config,
      batch_size=eval_config.chunk,
      appearance_ids=datasource.appearance_ids,
      camera_ids=datasource.camera_ids,
      warp_ids=datasource.warp_ids,
      near=datasource.near,
      far=datasource.far,
      use_warp_jacobian=False,
      use_weights=False)

  optimizer_def = optim.Adam(0.0)
  optimizer = optimizer_def.create(params)
  init_state = model_utils.TrainState(optimizer=optimizer)
  del params

  def _model_fn(key_0, key_1, params, rays_dict, warp_extra):
    out = model.apply({'params': params},
                      rays_dict,
                      warp_extra=warp_extra,
                      rngs={
                          'coarse': key_0,
                          'fine': key_1
                      },
                      mutable=False)
    return jax.lax.all_gather(out, axis_name='batch')

  pmodel_fn = jax.pmap(
      # Note rng_keys are useless in eval mode since there's no randomness.
      _model_fn,
      in_axes=(0, 0, 0, 0, 0),  # Only distribute the data input.
      devices=devices_to_use,
      donate_argnums=(3,),  # Donate the 'rays' argument.
      axis_name='batch',
  )

  render_fn = functools.partial(
      evaluation.render_image,
      model_fn=pmodel_fn,
      device_count=n_devices,
      chunk=eval_config.chunk)

  last_step = 0
  summary_writer = tensorboard.SummaryWriter(str(summary_dir))

  while True:
    if not checkpoint_dir.exists():
      logging.info('No checkpoints yet.')
      time.sleep(10)
      continue

    state = checkpoints.restore_checkpoint(checkpoint_dir, init_state)
    state = jax_utils.replicate(state, devices=devices_to_use)
    step = int(state.optimizer.state.step[0])
    if step <= last_step:
      logging.info('No new checkpoints (%d <= %d).', step, last_step)
      time.sleep(10)
      continue

    save_dir = renders_dir if eval_config.save_output else None
    process_iterator(
        tag='val',
        item_ids=val_eval_ids,
        iterator=val_eval_iter,
        state=state,
        rng=rng,
        step=step,
        render_fn=render_fn,
        summary_writer=summary_writer,
        save_dir=save_dir,
        datasource=datasource)
    process_iterator(
        tag='train',
        item_ids=train_eval_ids,
        iterator=train_eval_iter,
        state=state,
        rng=rng,
        step=step,
        render_fn=render_fn,
        summary_writer=summary_writer,
        save_dir=save_dir,
        datasource=datasource)
    if test_eval_iter:
      process_iterator(
          tag='test',
          item_ids=test_eval_ids,
          iterator=test_eval_iter,
          state=state,
          rng=rng,
          step=step,
          render_fn=render_fn,
          summary_writer=summary_writer,
          save_dir=save_dir,
          datasource=datasource)

    if save_dir:
      delete_old_renders(renders_dir, eval_config.max_render_checkpoints)

    if eval_config.eval_once:
      break
    if step >= train_config.max_steps:
      break
    last_step = step
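# NOTE: `delete_old_renders` is called in the loop above but is not defined in
# this listing. The following is only a minimal sketch of what such a helper
# might look like, assuming renders are written to per-step subdirectories of
# `renders_dir` and that the paths are local (a cloud-path type such as GPath
# may instead expose its own recursive delete).
import shutil


def delete_old_renders(render_dir, max_renders):
  """Keeps only the `max_renders` most recent render subdirectories."""
  render_paths = sorted(p for p in render_dir.iterdir() if p.is_dir())
  for path in render_paths[:-max_renders]:
    logging.info('Removing old render directory %s', str(path))
    shutil.rmtree(str(path))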
def main(argv):
  tf.config.experimental.set_visible_devices([], 'GPU')
  del argv
  logging.info('*** Starting experiment')
  gin_configs = FLAGS.gin_configs

  logging.info('*** Loading Gin configs from: %s', str(gin_configs))
  gin.parse_config_files_and_bindings(
      config_files=gin_configs,
      bindings=FLAGS.gin_bindings,
      skip_unknown=True)

  # Load configurations.
  exp_config = configs.ExperimentConfig()
  model_config = configs.ModelConfig()
  train_config = configs.TrainConfig()

  # Get directory information.
  exp_dir = gpath.GPath(FLAGS.base_folder)
  if exp_config.subname:
    exp_dir = exp_dir / exp_config.subname
  summary_dir = exp_dir / 'summaries' / 'train'
  checkpoint_dir = exp_dir / 'checkpoints'

  # Log and create directories if this is the main host.
  if jax.process_index() == 0:
    logging.info('exp_dir = %s', exp_dir)
    if not exp_dir.exists():
      exp_dir.mkdir(parents=True, exist_ok=True)

    logging.info('summary_dir = %s', summary_dir)
    if not summary_dir.exists():
      summary_dir.mkdir(parents=True, exist_ok=True)

    logging.info('checkpoint_dir = %s', checkpoint_dir)
    if not checkpoint_dir.exists():
      checkpoint_dir.mkdir(parents=True, exist_ok=True)

    config_str = gin.operative_config_str()
    logging.info('Configuration: \n%s', config_str)
    with (exp_dir / 'config.gin').open('w') as f:
      f.write(config_str)

  logging.info('Starting process %d. There are %d processes.',
               jax.process_index(), jax.process_count())
  logging.info('Found %d accelerator devices: %s.', jax.local_device_count(),
               str(jax.local_devices()))
  logging.info('Found %d total devices: %s.', jax.device_count(),
               str(jax.devices()))

  rng = random.PRNGKey(exp_config.random_seed)
  # Shift the numpy random seed by process_index() to shuffle data loaded by
  # different hosts.
  np.random.seed(exp_config.random_seed + jax.process_index())

  if train_config.batch_size % jax.device_count() != 0:
    raise ValueError('Batch size must be divisible by the number of devices.')

  devices = jax.local_devices()
  datasource_spec = exp_config.datasource_spec
  if datasource_spec is None:
    datasource_spec = {
        'type': exp_config.datasource_type,
        'data_dir': FLAGS.data_dir,
    }
  logging.info('Creating datasource: %s', datasource_spec)
  datasource = datasets.from_config(
      datasource_spec,
      image_scale=exp_config.image_scale,
      use_appearance_id=model_config.use_appearance_metadata,
      use_camera_id=model_config.use_camera_metadata,
      use_warp_id=model_config.use_warp,
      use_time=model_config.warp_metadata_encoder_type == 'time',
      random_seed=exp_config.random_seed,
      **exp_config.datasource_kwargs)
  train_iter = datasource.create_iterator(
      datasource.train_ids,
      flatten=True,
      shuffle=True,
      batch_size=train_config.batch_size,
      prefetch_size=3,
      shuffle_buffer_size=train_config.shuffle_buffer_size,
      devices=devices,
  )

  points_iter = None
  if train_config.use_background_loss:
    points = datasource.load_points(shuffle=True)
    points_batch_size = min(
        len(points),
        len(devices) * train_config.background_points_batch_size)
    points_batch_size -= points_batch_size % len(devices)
    points_dataset = tf.data.Dataset.from_tensor_slices(points)
    points_iter = datasets.iterator_from_dataset(
        points_dataset,
        batch_size=points_batch_size,
        prefetch_size=3,
        devices=devices)

  learning_rate_sched = schedules.from_config(train_config.lr_schedule)
  warp_alpha_sched = schedules.from_config(train_config.warp_alpha_schedule)
  time_alpha_sched = schedules.from_config(train_config.time_alpha_schedule)
  elastic_loss_weight_sched = schedules.from_config(
      train_config.elastic_loss_weight_schedule)

  rng, key = random.split(rng)
  params = {}
  model, params['model'] = models.construct_nerf(
      key,
      model_config,
      batch_size=train_config.batch_size,
      appearance_ids=datasource.appearance_ids,
      camera_ids=datasource.camera_ids,
      warp_ids=datasource.warp_ids,
      near=datasource.near,
      far=datasource.far,
      use_warp_jacobian=train_config.use_elastic_loss,
      use_weights=train_config.use_elastic_loss)

  optimizer_def = optim.Adam(learning_rate_sched(0))
  optimizer = optimizer_def.create(params)

  state = model_utils.TrainState(
      optimizer=optimizer,
      warp_alpha=warp_alpha_sched(0),
      time_alpha=time_alpha_sched(0))
  scalar_params = training.ScalarParams(
      learning_rate=learning_rate_sched(0),
      elastic_loss_weight=elastic_loss_weight_sched(0),
      warp_reg_loss_weight=train_config.warp_reg_loss_weight,
      warp_reg_loss_alpha=train_config.warp_reg_loss_alpha,
      warp_reg_loss_scale=train_config.warp_reg_loss_scale,
      background_loss_weight=train_config.background_loss_weight)
  state = checkpoints.restore_checkpoint(checkpoint_dir, state)
  init_step = state.optimizer.state.step + 1
  state = jax_utils.replicate(state, devices=devices)
  del params

  logging.info('Initializing models')

  summary_writer = None
  if jax.process_index() == 0:
    summary_writer = tensorboard.SummaryWriter(str(summary_dir))
    summary_writer.text(
        'gin/train', textdata=gin.config.markdown(config_str), step=0)

  train_step = functools.partial(
      training.train_step,
      model,
      elastic_reduce_method=train_config.elastic_reduce_method,
      elastic_loss_type=train_config.elastic_loss_type,
      use_elastic_loss=train_config.use_elastic_loss,
      use_background_loss=train_config.use_background_loss,
      use_warp_reg_loss=train_config.use_warp_reg_loss,
  )
  ptrain_step = jax.pmap(
      train_step,
      axis_name='batch',
      devices=devices,
      # rng_key, state, batch, scalar_params.
      in_axes=(0, 0, 0, None),
      # Treat use_elastic_loss as compile-time static.
      donate_argnums=(2,),  # Donate the 'batch' argument.
  )

  if devices:
    n_local_devices = len(devices)
  else:
    n_local_devices = jax.local_device_count()

  logging.info('Starting training')
  rng = rng + jax.process_index()  # Make random seed separate across hosts.
  keys = random.split(rng, n_local_devices)
  time_tracker = utils.TimeTracker()
  time_tracker.tic('data', 'total')
  for step, batch in zip(range(init_step, train_config.max_steps + 1),
                         train_iter):
    if points_iter is not None:
      batch['background_points'] = next(points_iter)
    time_tracker.toc('data')
    # pytype: disable=attribute-error
    scalar_params = scalar_params.replace(
        learning_rate=learning_rate_sched(step),
        elastic_loss_weight=elastic_loss_weight_sched(step))
    warp_alpha = jax_utils.replicate(warp_alpha_sched(step), devices)
    time_alpha = jax_utils.replicate(time_alpha_sched(step), devices)
    state = state.replace(warp_alpha=warp_alpha, time_alpha=time_alpha)

    with time_tracker.record_time('train_step'):
      state, stats, keys = ptrain_step(keys, state, batch, scalar_params)
      time_tracker.toc('total')

    if step % train_config.print_every == 0 and jax.process_index() == 0:
      logging.info('step=%d, warp_alpha=%.04f, time_alpha=%.04f, %s', step,
                   warp_alpha_sched(step), time_alpha_sched(step),
                   time_tracker.summary_str('last'))
      coarse_metrics_str = ', '.join(
          [f'{k}={v.mean():.04f}' for k, v in stats['coarse'].items()])
      logging.info('\tcoarse metrics: %s', coarse_metrics_str)
      if 'fine' in stats:
        fine_metrics_str = ', '.join(
            [f'{k}={v.mean():.04f}' for k, v in stats['fine'].items()])
        logging.info('\tfine metrics: %s', fine_metrics_str)

    if step % train_config.save_every == 0 and jax.process_index() == 0:
      training.save_checkpoint(checkpoint_dir, state)

    if step % train_config.log_every == 0 and jax.process_index() == 0:
      # Only log via host 0.
      _log_to_tensorboard(
          summary_writer,
          jax_utils.unreplicate(state),
          scalar_params,
          jax_utils.unreplicate(stats),
          time_dict=time_tracker.summary('mean'))
      time_tracker.reset()

    if step % train_config.histogram_every == 0 and jax.process_index() == 0:
      _log_histograms(summary_writer, model, jax_utils.unreplicate(state))

    time_tracker.tic('data', 'total')

  if train_config.max_steps % train_config.save_every != 0:
    training.save_checkpoint(checkpoint_dir, state)
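# The training loop above threads per-device PRNG keys through the pmapped
# step: `random.split(rng, n_local_devices)` produces one key per device, the
# step consumes its key and returns a fresh one, and the returned keys are fed
# back in on the next iteration. The toy snippet below (not part of the
# project; `_toy_step` and its shapes are made up for illustration) shows the
# same pattern in isolation using only standard JAX APIs.
import jax
import jax.numpy as jnp
from jax import random


def _toy_step(key, x):
  # Split the incoming key: use one half now, return the other for next step.
  key, subkey = random.split(key)
  return key, x + random.normal(subkey, x.shape)


p_toy_step = jax.pmap(_toy_step, axis_name='batch')

n_devices = jax.local_device_count()
keys = random.split(random.PRNGKey(0), n_devices)  # one key per device
x = jnp.zeros((n_devices, 4))                      # one shard per device
for _ in range(3):
  keys, x = p_toy_step(keys, x)                    # keys are re-threaded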
def main(argv):
  del argv
  logging.info("*** Starting experiment")
  gin_configs = FLAGS.gin_configs

  logging.info("*** Loading Gin configs from: %s", str(gin_configs))
  gin.parse_config_files_and_bindings(
      config_files=gin_configs,
      bindings=FLAGS.gin_bindings,
      skip_unknown=True)

  # Load configurations.
  exp_config = configs.ExperimentConfig()
  model_config = configs.ModelConfig(use_stratified_sampling=False)
  train_config = configs.TrainConfig()
  eval_config = configs.EvalConfig()

  # Get directory information.
  exp_dir = gpath.GPath(FLAGS.exp_dir)
  if exp_config.subname:
    exp_dir = exp_dir / exp_config.subname
  logging.info("\texp_dir = %s", exp_dir)
  if not exp_dir.exists():
    exp_dir.mkdir(parents=True, exist_ok=True)

  summary_dir = exp_dir / "summaries" / "eval"
  logging.info("\tsummary_dir = %s", summary_dir)
  if not summary_dir.exists():
    summary_dir.mkdir(parents=True, exist_ok=True)

  renders_dir = exp_dir / "renders"
  logging.info("\trenders_dir = %s", renders_dir)
  if not renders_dir.exists():
    renders_dir.mkdir(parents=True, exist_ok=True)

  checkpoint_dir = exp_dir / "checkpoints"
  logging.info("\tcheckpoint_dir = %s", checkpoint_dir)

  rng = random.PRNGKey(20200823)

  devices_to_use = jax.devices()
  n_devices = (
      len(devices_to_use) if devices_to_use else jax.local_device_count())

  datasource_spec = exp_config.datasource_spec
  if datasource_spec is None:
    datasource_spec = {
        "type": exp_config.datasource_type,
        "data_dir": FLAGS.data_dir,
    }
  logging.info("Creating datasource: %s", datasource_spec)
  datasource = datasets.from_config(
      datasource_spec,
      image_scale=exp_config.image_scale,
      use_appearance_id=model_config.use_appearance_metadata,
      use_camera_id=model_config.use_camera_metadata,
      use_warp_id=model_config.use_warp,
      random_seed=exp_config.random_seed,
  )

  # Get training IDs to evaluate.
  train_eval_ids = utils.strided_subset(datasource.train_ids,
                                        eval_config.num_train_eval)
  train_eval_iter = datasource.create_iterator(train_eval_ids, batch_size=0)
  val_eval_ids = utils.strided_subset(datasource.val_ids,
                                      eval_config.num_val_eval)
  val_eval_iter = datasource.create_iterator(val_eval_ids, batch_size=0)

  test_cameras = datasource.load_test_cameras(count=eval_config.num_test_eval)
  if test_cameras:
    test_dataset = datasource.create_cameras_dataset(test_cameras)
    test_eval_ids = [f"{x:03d}" for x in range(len(test_cameras))]
    test_eval_iter = datasets.iterator_from_dataset(test_dataset, batch_size=0)
  else:
    test_eval_ids = None
    test_eval_iter = None

  rng, key = random.split(rng)
  params = {}
  model, params["model"] = models.nerf(
      key,
      model_config,
      batch_size=eval_config.chunk,
      num_appearance_embeddings=len(datasource.appearance_ids),
      num_camera_embeddings=len(datasource.camera_ids),
      num_warp_embeddings=len(datasource.warp_ids),
      near=datasource.near,
      far=datasource.far,
      use_warp_jacobian=False,
      use_weights=False,
  )

  optimizer_def = optim.Adam(0.0)
  optimizer = optimizer_def.create(params)
  init_state = model_utils.TrainState(optimizer=optimizer, warp_alpha=0.0)
  del params

  def _model_fn(key_0, key_1, params, rays_dict, alpha):
    out = model.apply(
        {"params": params},
        rays_dict,
        warp_alpha=alpha,
        rngs={
            "coarse": key_0,
            "fine": key_1
        },
        mutable=False,
    )
    return jax.lax.all_gather(out, axis_name="batch")

  pmodel_fn = jax.pmap(
      # Note rng_keys are useless in eval mode since there's no randomness.
      _model_fn,
      in_axes=(0, 0, 0, 0, None),  # Only distribute the data input.
      devices=devices_to_use,
      donate_argnums=(3,),  # Donate the 'rays' argument.
      axis_name="batch",
  )

  render_fn = functools.partial(
      evaluation.render_image,
      model_fn=pmodel_fn,
      device_count=n_devices,
      chunk=eval_config.chunk,
  )

  last_step = 0
  summary_writer = tensorboard.SummaryWriter(str(summary_dir))
  while True:
    if not checkpoint_dir.exists():
      logging.info("No checkpoints yet.")
      time.sleep(10)
      continue

    state = checkpoints.restore_checkpoint(checkpoint_dir, init_state)
    state = jax_utils.replicate(state, devices=devices_to_use)
    step = int(state.optimizer.state.step[0])
    if step <= last_step:
      logging.info("No new checkpoints (%d <= %d).", step, last_step)
      time.sleep(10)
      continue

    save_dir = renders_dir if eval_config.save_output else None
    process_iterator(
        tag="train",
        item_ids=train_eval_ids,
        iterator=train_eval_iter,
        state=state,
        rng=rng,
        step=step,
        render_fn=render_fn,
        summary_writer=summary_writer,
        save_dir=save_dir,
        datasource=datasource,
    )
    process_iterator(
        tag="val",
        item_ids=val_eval_ids,
        iterator=val_eval_iter,
        state=state,
        rng=rng,
        step=step,
        render_fn=render_fn,
        summary_writer=summary_writer,
        save_dir=save_dir,
        datasource=datasource,
    )
    if test_eval_iter:
      process_iterator(
          tag="test",
          item_ids=test_eval_ids,
          iterator=test_eval_iter,
          state=state,
          rng=rng,
          step=step,
          render_fn=render_fn,
          summary_writer=summary_writer,
          save_dir=save_dir,
          datasource=datasource,
      )

    if eval_config.eval_once:
      break
    if step >= train_config.max_steps:
      break
    last_step = step
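# `utils.strided_subset` is used above (to pick a fixed number of evenly spaced
# train/val items to evaluate) but is not defined in this listing. A minimal
# sketch of such a helper, under the assumption that it returns at most `count`
# evenly spaced elements and the full sequence when `count` is falsy:
import numpy as np


def strided_subset(sequence, count):
  """Returns at most `count` roughly evenly spaced items from `sequence`."""
  if not count or count >= len(sequence):
    return sequence
  indices = np.linspace(0, len(sequence) - 1, num=count, dtype=int)
  return [sequence[i] for i in indices]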
def main(argv):
  del argv
  logging.info("*** Starting experiment")
  gin_configs = FLAGS.gin_configs

  logging.info("*** Loading Gin configs from: %s", str(gin_configs))
  gin.parse_config_files_and_bindings(
      config_files=gin_configs,
      bindings=FLAGS.gin_bindings,
      skip_unknown=True)

  # Load configurations.
  exp_config = configs.ExperimentConfig()
  model_config = configs.ModelConfig()
  train_config = configs.TrainConfig()

  # Get directory information.
  exp_dir = gpath.GPath(FLAGS.exp_dir)
  if exp_config.subname:
    exp_dir = exp_dir / exp_config.subname
  logging.info("exp_dir = %s", exp_dir)
  if not exp_dir.exists():
    exp_dir.mkdir(parents=True, exist_ok=True)

  summary_dir = exp_dir / "summaries" / "train"
  logging.info("summary_dir = %s", summary_dir)
  if not summary_dir.exists():
    summary_dir.mkdir(parents=True, exist_ok=True)

  checkpoint_dir = exp_dir / "checkpoints"
  logging.info("checkpoint_dir = %s", checkpoint_dir)
  if not checkpoint_dir.exists():
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

  config_str = gin.operative_config_str()
  logging.info("Configuration: \n%s", config_str)
  with (exp_dir / "config.gin").open("w") as f:
    f.write(config_str)

  rng = random.PRNGKey(exp_config.random_seed)
  # Shift the numpy random seed by host_id() to shuffle data loaded by
  # different hosts.
  np.random.seed(exp_config.random_seed + jax.host_id())

  if train_config.batch_size % jax.device_count() != 0:
    raise ValueError("Batch size must be divisible by the number of devices.")

  devices = jax.devices()
  datasource_spec = exp_config.datasource_spec
  if datasource_spec is None:
    datasource_spec = {
        "type": exp_config.datasource_type,
        "data_dir": FLAGS.data_dir,
    }
  logging.info("Creating datasource: %s", datasource_spec)
  datasource = datasets.from_config(
      datasource_spec,
      image_scale=exp_config.image_scale,
      use_appearance_id=model_config.use_appearance_metadata,
      use_camera_id=model_config.use_camera_metadata,
      use_warp_id=model_config.use_warp,
      random_seed=exp_config.random_seed,
  )
  train_dataset = datasource.create_dataset(
      datasource.train_ids, flatten=True, shuffle=True)
  train_iter = datasets.iterator_from_dataset(
      train_dataset,
      batch_size=train_config.batch_size,
      prefetch_size=3,
      devices=devices,
  )

  points_iter = None
  if train_config.use_background_loss:
    points_dataset = (
        tf.data.Dataset.from_tensor_slices(datasource.load_points())
        .repeat()
        .shuffle(65536,
                 reshuffle_each_iteration=True,
                 seed=exp_config.random_seed))
    points_iter = datasets.iterator_from_dataset(
        points_dataset,
        batch_size=len(devices) * train_config.background_points_batch_size,
        prefetch_size=3,
        devices=devices,
    )

  learning_rate_sched = schedules.from_config(train_config.lr_schedule)
  warp_alpha_sched = schedules.from_config(train_config.warp_alpha_schedule)
  elastic_loss_weight_sched = schedules.from_config(
      train_config.elastic_loss_weight_schedule)

  rng, key = random.split(rng)
  params = {}
  model, params["model"] = models.nerf(
      key,
      model_config,
      batch_size=train_config.batch_size,
      num_appearance_embeddings=len(datasource.appearance_ids),
      num_camera_embeddings=len(datasource.camera_ids),
      num_warp_embeddings=len(datasource.warp_ids),
      near=datasource.near,
      far=datasource.far,
      use_warp_jacobian=train_config.use_elastic_loss,
      use_weights=train_config.use_elastic_loss,
  )

  optimizer_def = optim.Adam(learning_rate_sched(0))
  optimizer = optimizer_def.create(params)
  state = model_utils.TrainState(
      optimizer=optimizer, warp_alpha=warp_alpha_sched(0))
  scalar_params = training.ScalarParams(
      learning_rate=learning_rate_sched(0),
      elastic_loss_weight=elastic_loss_weight_sched(0),
      background_loss_weight=train_config.background_loss_weight,
  )
  state = checkpoints.restore_checkpoint(checkpoint_dir, state)
  init_step = state.optimizer.state.step + 1
  state = jax_utils.replicate(state, devices=devices)
  del params

  logging.info("Initializing models")

  summary_writer = None
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(str(summary_dir))
    summary_writer.text(
        "gin/train",
        textdata=gin.config.markdownify_operative_config_str(config_str),
        step=0,
    )

  train_step = functools.partial(
      training.train_step,
      model,
      elastic_reduce_method=train_config.elastic_reduce_method,
      use_elastic_loss=train_config.use_elastic_loss,
      use_background_loss=train_config.use_background_loss,
  )
  ptrain_step = jax.pmap(
      train_step,
      axis_name="batch",
      devices=devices,
      # rng_key, state, batch, scalar_params.
      in_axes=(0, 0, 0, None),
      # Treat use_elastic_loss as compile-time static.
      donate_argnums=(2,),  # Donate the 'batch' argument.
  )

  if devices:
    n_local_devices = len(devices)
  else:
    n_local_devices = jax.local_device_count()

  logging.info("Starting training")
  rng = rng + jax.host_id()  # Make random seed separate across hosts.
  keys = random.split(rng, n_local_devices)
  time_tracker = utils.TimeTracker()
  time_tracker.tic("data", "total")
  for step, batch in zip(range(init_step, train_config.max_steps + 1),
                         train_iter):
    if points_iter is not None:
      batch["background_points"] = next(points_iter)
    time_tracker.toc("data")
    # See: b/162398046.
    # pytype: disable=attribute-error
    scalar_params = scalar_params.replace(
        learning_rate=learning_rate_sched(step),
        elastic_loss_weight=elastic_loss_weight_sched(step),
    )
    warp_alpha = jax_utils.replicate(warp_alpha_sched(step), devices)
    state = state.replace(warp_alpha=warp_alpha)

    with time_tracker.record_time("train_step"):
      state, stats, keys = ptrain_step(keys, state, batch, scalar_params)
      time_tracker.toc("total")

    if step % train_config.print_every == 0:
      logging.info(
          "step=%d, warp_alpha=%.04f, %s",
          step,
          warp_alpha_sched(step),
          time_tracker.summary_str("last"),
      )
      coarse_metrics_str = ", ".join(
          [f"{k}={v.mean():.04f}" for k, v in stats["coarse"].items()])
      logging.info("\tcoarse metrics: %s", coarse_metrics_str)
      if "fine" in stats:
        fine_metrics_str = ", ".join(
            [f"{k}={v.mean():.04f}" for k, v in stats["fine"].items()])
        logging.info("\tfine metrics: %s", fine_metrics_str)

    if step % train_config.save_every == 0:
      training.save_checkpoint(checkpoint_dir, state)

    if step % train_config.log_every == 0 and jax.host_id() == 0:
      # Only log via host 0.
      _log_to_tensorboard(
          summary_writer,
          jax_utils.unreplicate(state),
          scalar_params,
          jax_utils.unreplicate(stats),
          time_dict=time_tracker.summary("mean"),
      )
      time_tracker.reset()

    time_tracker.tic("data", "total")

  if train_config.max_steps % train_config.save_every != 0:
    training.save_checkpoint(checkpoint_dir, state)
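# `training.save_checkpoint` is called above but not defined in this listing.
# A minimal sketch of what such a wrapper might do, assuming it unreplicates
# the pmapped state and delegates to Flax's checkpoint utilities; the `keep`
# argument and the returned step are illustrative assumptions, not the
# project's actual signature.
from flax import jax_utils
from flax.training import checkpoints


def save_checkpoint(checkpoint_dir, state, keep=2):
  """Saves an unreplicated copy of the replicated `state`."""
  state_to_save = jax_utils.unreplicate(state)
  step = int(state_to_save.optimizer.state.step)
  checkpoints.save_checkpoint(
      str(checkpoint_dir), state_to_save, step, keep=keep)
  return step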