def restore_model_and_put_to_devices(
    config: ml_collections.ConfigDict,
    workdir: str,
    partition_specs: Sequence[PartitionSpec],
    model: flax.nn.Module,
    optimizer: flax.optim.Optimizer,
    train_iter: Iterable[Any],
    rngs: Mapping[str, jnp.ndarray],
    thread_pool: multiprocessing.pool.ThreadPool,
) -> Tuple[flax.optim.Optimizer, Iterable[Any], jnp.ndarray, Mapping[str, Any]]:
  """Restores from latest available checkpoint and puts model to devices."""
  (optimizer, train_iter, rng_state_tf, rngs,
   global_state) = train.restore_checkpoints(
       workdir=workdir,
       step=None,
       partition_specs=partition_specs,
       optimizer=optimizer,
       train_iter=train_iter,
       rng_state_tf=tf.random.get_global_generator().state.numpy(),
       rng_state_jax=rngs,
       global_state={},
       thread_pool=thread_pool)
  if global_state:
    # 1. If a checkpoint is present in the current work dir, continue training.
    logging.info("Continuing training from step %d", global_state["step"])
    # Shard parameters and optim state and put to the corresponding device.
    optimizer = core.tree_shard(optimizer)
  elif config.get("model_init_prefix"):
    # 2. Alternatively, initialize from the given model_init_prefix checkpoint.
    logging.info("Fine-tuning model from %r...", config.model_init_prefix)
    if not hasattr(model, "load"):
      # Note: Likely due to use of .partial, model may end up being e.g.
      # a flax.nn.Base.PatchTransformer instead of experts_nn.PatchTransformer.
      # This causes explicit checks for class equivalence to fail, and also
      # causes static type checking to fail. Checking for the .load attribute
      # circumvents both these issues.
      raise ValueError(f"Loaded model {model} has no load method. Are you sure"
                       " it is one of 'PatchTransformer' and 'Resformer'?")
    restored_params = model.load(
        prefix=config.model_init_prefix,
        init_params=optimizer.target,
        model_params=config.model,
        keep_head=config.get("keep_head", False),
        partition_specs=partition_specs)
    # Shard restored parameters and replicate original optimizer state.
    optimizer = optimizer.replace(
        target=core.tree_shard(restored_params),
        state=flax.jax_utils.replicate(optimizer.state))
    global_state = {"step": 0, "accum_train_time": 0.0}
  else:
    # 3. Use model initialized from scratch.
    logging.info("Initializing training from scratch...")
    optimizer = flax.jax_utils.replicate(optimizer)
    global_state = {"step": 0, "accum_train_time": 0.0}
  # Set TF's global RNG generator and JAX's per-device RNG keys.
  train.rng_tf_set_global_generator(rng_state_tf)
  rngs_per_device = jax.tree_map(train.rng_jax_fold_host_if_needed_and_shard,
                                 rngs)
  return optimizer, train_iter, rngs_per_device, global_state
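# Illustrative sketch (not part of the original trainer): the difference between
# replicating a parameter tree across local devices (what flax.jax_utils.replicate
# does for the optimizer state above) and sharding it. core.tree_shard is assumed
# here, for illustration only, to split each leaf's leading axis across local
# devices; `_shard_leading_axis` is a hypothetical helper written for this sketch
# and assumes each leading dimension is divisible by the local device count.
import jax
import jax.numpy as jnp
import flax.jax_utils


def _shard_leading_axis(tree):
  """Splits each leaf's leading axis across local devices (assumed semantics)."""
  n = jax.local_device_count()
  devices = jax.local_devices()
  return jax.tree_map(
      lambda x: jax.device_put_sharded(
          list(x.reshape((n, -1) + x.shape[1:])), devices),
      tree)


example_params = {"w": jnp.arange(8.0)}
replicated = flax.jax_utils.replicate(example_params)  # leaf shape: (n_devices, 8)
sharded = _shard_leading_axis(example_params)          # leaf shape: (n_devices, 8 // n_devices)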
def main(_):
  config = flags.FLAGS.config
  workdir = flags.FLAGS.workdir
  tf.io.gfile.makedirs(workdir)

  partition_specs = []

  # Loss to apply.
  loss_to_apply = getattr(core, config.get("loss_to_apply", "softmax_xent"))
  compute_ece = config.get("compute_ece", False)
  is_sigmoid = config.get("loss_to_apply", "softmax_xent") == "sigmoid_xent"
  if compute_ece and is_sigmoid:
    error_msg = "Inconsistent config: ECE can only be used with 'softmax_xent'."
    raise ValueError(error_msg)

  ens_size = config.get("model.transformer.ens_size", 1)
  # TODO(ghassen): enable sigmoid for ensemble.
  if ens_size > 1 and is_sigmoid:
    error_msg = "Inconsistent config: Ensemble only works with 'softmax_xent'."
    raise ValueError(error_msg)

  # The pool is used to perform misc operations such as logging in async way.
  pool = multiprocessing.pool.ThreadPool()

  # Ideally, this should make code deterministic, but for many reasons we are
  # not there yet. For instance, tf.data.map is not deterministic.
  rng_generator = tf.random.Generator.from_seed(config.get("seed", 0))
  tf.random.set_global_generator(
      rng_generator.split(jax.host_count())[jax.host_id()])

  logging.info("Number of devices: %s (host_id: %s)", jax.device_count(),
               jax.host_id())
  logging.info("Config:\n%s", str(config))

  if (config.batch_size % jax.device_count() != 0 or
      config.batch_size_eval % jax.device_count() != 0):
    raise ValueError(f"Batch sizes ({config.batch_size} and "
                     f"{config.batch_size_eval}) must be divisible by "
                     f"the number of devices ({jax.device_count()})")

  batch_size_per_host = config.batch_size // jax.host_count()
  batch_size_per_core = config.batch_size // jax.device_count()
  batch_size_per_host_eval = config.batch_size_eval // jax.host_count()

  # TODO(basilm): Remove when JFT2.6B is properly submitted.
  if config.dataset in jft_latest_pipeline.DATA_INFO:
    input_pipeline = jft_latest_pipeline
    cache = "loaded"
  else:
    input_pipeline = default_input_pipeline
    cache = "batched"

  train_ds = input_pipeline.get_data(
      dataset=config.dataset,
      data_dir=config.get("dataset_dir"),
      split=config.train_split,
      batch_size=batch_size_per_host,
      preprocess_fn=pp_builder.get_preprocess_fn(config.pp_train),
      shuffle_buffer_size=config.shuffle_buffer_size,
      cache=False)
  steps_per_epoch = input_pipeline.get_num_examples(
      config.dataset, config.train_split,
      data_dir=config.get("dataset_dir")) / config.batch_size
  total_steps = train.get_total_steps_from_config(config, steps_per_epoch)
  logging.info("Running for %d steps per epoch (%d steps total)",
               steps_per_epoch, total_steps)

  opt_def = train.get_optimizer_from_config(config, f"{BIG_VISION_DIR}.optims")

  eval_config = copy.deepcopy(config)
  if config.get("eval_overrides"):
    with eval_config.unlocked():
      eval_config.update(config.eval_overrides)
  model = getattr(ub.models, config.model_name)
  model_train = model(
      num_classes=config.num_classes, train=True, **config.model)
  model_eval = model(
      num_classes=config.num_classes, train=False, **eval_config.model)

  image_size = tuple(train_ds.element_spec["image"].shape[1:])
  logging.info("Model initialization: Starting.")
  opt, rngs = train.model_and_optim_init(
      model_train.init, opt_def,
      (batch_size_per_core * ens_size,) + image_size,
      config.get("init_head_bias"), config.get("seed", 0),
      config.get("extra_rngs", ["dropout", "gating"]))
  logging.info("Model initialization: Done.")
  # TODO(jpuigcerver): Support logging parameter count with new sharding.
if config.get("plot_grad_norm_patterns"): plot_grad_norm_name_fn = experts_utils.make_match_fn_from_prefixes( config.plot_grad_norm_patterns) else: plot_grad_norm_name_fn = None weight_decay_fn = train.get_weight_decay_function_from_config(config) batch_loss_fn = train.wrap_module_with_auxiliary_loss_fn( module=model_train, loss_fn=loss_to_apply, auxiliary_loss_weight=config.get("auxiliary_loss_weight", 0.0), ens_size=ens_size) if ens_size == 1: evaluation_fn = functools.partial( train.evaluation_fn, apply_fn=model_eval.apply, loss_fn=loss_to_apply, correct_fn=train.correct_multilabel, return_metric_args=compute_ece) else: evaluation_fn = functools.partial( ensemble.evaluation_fn, apply_fn=model_eval.apply, return_metric_args=compute_ece, ens_size=ens_size) pmap_evaluation_fn = core.pmap_sorted(evaluation_fn, axis_name="batch") update_fn = functools.partial( batchensemble_utils.update_fn_be, weight_decay_fn=weight_decay_fn, plot_grad_norm_name_fn=plot_grad_norm_name_fn, plot_grads_nan_inf=config.get("plot_grads_nan_inf", True), max_grad_norm_global=config.get("clip_grad_norm", None), frozen_vars_patterns=config.get("frozen_var_patterns", None), fast_weight_lr_multiplier=config.get("fast_weight_lr_multiplier", None)) pmap_update_fn = core.pmap_sorted( update_fn, axis_name="batch", donate_argnums=(0, 1), static_broadcasted_argnums=(5,)) # Restore parameters from checkpoints (if possible) and put to TPU devices. opt, train_iter, rngs_per_device, global_state = restore_model_and_put_to_devices( config, workdir, partition_specs, model, opt, iter(train_ds), rngs, pool) del rngs first_step = global_state["step"] accum_train_time = global_state["accum_train_time"] start_time = time.time() logging.info("Initial step for training = %d.", first_step) local_devices = sorted(jax.local_devices(), key=lambda device: device.id) if config.get("ema", {}): ema_updater = ema.ExponentialMovingAverage( target=partitioning.tree_unreplicate_using_partition_specs( jax.tree_map(np.zeros_like, opt.target), partition_specs=partition_specs, local_devices=local_devices), num_updates=0, **config.ema) else: ema_updater = None if first_step != 0 and ema_updater is not None: ema_updater = train.restore_ema_checkpoints( workdir, first_step, partition_specs, ema_updater, local_devices=local_devices, thread_pool=pool) train_iter = u.start_input_pipeline(train_iter, config.prefetch_to_device) eval_iters = train.get_dataset_eval_iters_from_config( config, batch_size_per_host_eval, cache, input_pipeline) lr_fn = u.create_learning_rate_schedule( config.batch_size, total_steps, steps_per_epoch, **config.lr) lr_iter = u.prefetch_scalar(map(lr_fn, range(first_step, total_steps)), config.get("prefetch_to_device", 1)) writer = metric_writers.create_default_writer( workdir, just_logging=jax.host_id() > 0, asynchronous_workers=128, summary_writer=config.get("write_tf_summaries", False)) checkpoint_async_results = [] log_training_first_n_steps = config.get("log_training_first_n_steps", -1) xm_work_unit.set_notes("First step compilations...") with metric_writers.ensure_flushes(writer): if jax.host_id() == 0: callback_fn = xprof.XmUrlCallbackFn( description=f"Xprof [{first_step + 171}...{first_step + 190}]", work_unit=xm_work_unit) else: callback_fn = lambda x: x # Do nothing. xprof_session = xprof.MultiStepXprofSession( profile_steps=20, # For how many steps to profile after warmup. warmup_steps=170, # For how many steps to wait before profiling. 
        stop_callback_fn=callback_fn)

    for step, lr_repl in zip(range(first_step + 1, total_steps + 1), lr_iter):
      train_batch = next(train_iter)
      with xprof_session:
        with jax.profiler.StepTraceAnnotation(name="train", step_num=step):
          opt, rngs_per_device, loss_value, aux_info = pmap_update_fn(
              opt, rngs_per_device, lr_repl, train_batch["image"],
              train_batch["labels"], batch_loss_fn)

      if (ema_updater is not None and
          step % config.get("ema", {}).get("period", 10) == 0):
        ema_updater = ema_updater.update(
            partitioning.tree_unreplicate_using_partition_specs(
                tree=opt.target,
                partition_specs=partition_specs,
                local_devices=local_devices))

      # Checkpoint saving.
      backup_checkpoints_every_n_steps = config.get("backup_checkpoint_steps")
      if (step % config.write_checkpoint_every_n_steps == 0 or
          (backup_checkpoints_every_n_steps is not None and
           step % backup_checkpoints_every_n_steps == 0) or
          step == total_steps):
        # Before writing new checkpoints, make sure that all the previous
        # checkpoint shards have been completely written (hosts are synced).
        train.wait_async_results(
            checkpoint_async_results,
            timeout_secs=config.checkpoint_write_timeout_secs)
        train.sync_all_hosts()
        # Now host 0 can remove all the checkpoints older than the previous
        # checkpointed step. The pool is used to remove files in parallel.
        if jax.host_id() == 0:
          train.remove_old_checkpoints(
              workdir,
              keep_steps_from=step - config.write_checkpoint_every_n_steps,
              keep_steps_multiple_of=backup_checkpoints_every_n_steps,
              thread_pool=pool)
        # Save checkpoint for the current step, asynchronously.
        # Note: Parameters on TPU are sliced and copied to CPU before scheduling
        # the asynchronous copy, to prevent any extra TPU memory usage.
        time_since_last_start = float(time.time() - start_time)
        checkpoint_async_results = train.save_checkpoints(
            workdir=workdir,
            step=step,
            partition_specs=partition_specs,
            optimizer=opt,
            # TODO(jpuigcerver): start_input_pipeline() does not return a
            # serializable iterator. Also, serialization of a "memory heavy"
            # tf.data.Dataset iterator may cause OOM (e.g. big shuffle buffer).
            train_iter=None,
            rng_state_tf=tf.random.get_global_generator().state.numpy(),
            rng_state_jax=rngs_per_device,
            global_state={
                # Note: "step" is automatically added to this dictionary.
                "accum_train_time": accum_train_time + time_since_last_start,
            },
            thread_pool=pool)
        if ema_updater is not None:
          checkpoint_async_results.append(train.save_ema_checkpoints(
              workdir=workdir,
              step=step,
              partition_specs=partition_specs,
              ema_updater=ema_updater,
              local_devices=local_devices,
              thread_pool=pool))

      # Report training progress
      if (jax.host_id() == 0 and config.log_training_every_n_steps > 0 and
          (step % config.log_training_every_n_steps == 0 or
           step == total_steps or step < log_training_first_n_steps)):
        time_elapsed = time.time() - start_time + accum_train_time
        img_sec_core = (
            config.batch_size * step / time_elapsed / jax.device_count())
        writer.write_scalars(step, {"learning_rate": lr_repl[0],
                                    "training_loss": np.mean(loss_value),
                                    "img/sec/core": img_sec_core,
                                    "epoch": step / steps_per_epoch})
        if aux_info:
          # Per-block info has to be dealt with specially.
          if "per_block_info" in aux_info:
            scalar_metrics_to_aggregate = config.get(
                "scalar_metrics_to_aggregate", ())
            metrics.write_info_to_metric_writer(
                metric_writer=writer,
                step=step,
                gating_info_dict=jax.tree_map(lambda x: np.mean(x, axis=0),
                                              aux_info["per_block_info"]),
                scalar_metrics_to_aggregate=scalar_metrics_to_aggregate,
                write_matrices=True)
            del aux_info["per_block_info"]
          # Plot rest of metrics as scalars.
          writer.write_scalars(
              step,
              {key: np.mean(value) for key, value in aux_info.items()})

        def progress(start_time, step, num_steps, batch_size):
          """Generates progress note."""
          time_elapsed = time.time() - start_time + accum_train_time
          steps_per_sec = step / time_elapsed
          eta_seconds = (num_steps - step) / (steps_per_sec + 1e-8)
          note = ("Steps:{:d}/{:d} [{:.1f}%]\n"
                  "Images per second:{:.1f},\n"
                  "ETA:{}, Total time:{}".format(
                      step, num_steps, 100 * step / num_steps,
                      steps_per_sec * batch_size,
                      experts_utils.htime(eta_seconds),
                      experts_utils.htime(eta_seconds + time_elapsed)))
          writer.write_scalars(
              step, {"images_per_second": steps_per_sec * batch_size})
          return note

        progress_note = progress(
            start_time, step, num_steps=total_steps,
            batch_size=config.batch_size)
        pool.apply_async(
            lambda note=progress_note: xm_work_unit.set_notes(note))

      # Run checks to detect if the model partitioning is unhealthy.
      # Global health metrics will be written to XM, and in case of problems a
      # WARNING or ERROR message will be logged.
      train.monitor_partitioning_health(
          optimizer=opt,
          partition_specs=partition_specs,
          metric_writer=writer,
          step=step,
          first_step=first_step + 1,
          every_n_steps=config.get("check_partitioning_health_every_n_steps",
                                   total_steps // 20))
      # Evaluate model on validation, test, ...
      rngs_per_device = train.run_evaluation_on_multiple_splits(
          pmap_evaluation_fn, opt.target, eval_iters, rngs_per_device,
          step / steps_per_epoch, step, total_steps,
          config.run_evaluation_every_n_steps, writer, compute_ece,
          config.get("ece_num_bins", 15), suffix="")
      if ema_updater and config.run_evaluation_every_n_steps > 0 and (
          step == first_step + 1 or
          step % config.run_evaluation_every_n_steps == 0 or
          step == total_steps):
        logging.info("Evaluation with EMA weights at step %d: started.", step)
        # Copy current parameters to CPU. Only one replica of each local
        # partition is copied to prevent redundant data transfers (e.g.
        # non-expert parameters).
        curr_params = partitioning.tree_unreplicate_using_partition_specs(
            tree=opt.target,
            partition_specs=partition_specs,
            local_devices=local_devices)
        # Block curr_params until the TPU->CPU copy has finished to prevent
        # multiple copies of the TPU parameters.
        curr_params = core.tree_block_until_ready(curr_params)
        # Allow TPU parameters to be freed.
        opt = opt.replace(target=None)
        # Copy EMA parameters to TPU and run evaluation.
        rngs_per_device = train.run_evaluation_on_multiple_splits(
            pmap_evaluation_fn,
            partitioning.tree_replicate_from_partitioned_tree(
                ema_updater.get(),
                partition_specs=partition_specs,
                local_devices=local_devices),
            eval_iters, rngs_per_device, step / steps_per_epoch, step,
            total_steps, config.run_evaluation_every_n_steps, writer,
            compute_ece, config.get("ece_num_bins", 15), suffix="_ema")
        rngs_per_device = core.tree_block_until_ready(rngs_per_device)
        # Copy current parameters back to the TPU.
        opt = opt.replace(
            target=partitioning.tree_replicate_from_partitioned_tree(
                curr_params,
                partition_specs=partition_specs,
                local_devices=local_devices))
        logging.info("Evaluation with EMA weights at step %d: finished.", step)
        del curr_params

  pool.close()
  pool.join()
def main(argv):
  del argv

  config = FLAGS.config
  workdir = FLAGS.workdir
  logging.info("Workdir: %s", workdir)

  save_checkpoint_path = None
  if config.get("checkpoint_steps"):
    tf.io.gfile.makedirs(workdir)
    save_checkpoint_path = os.path.join(workdir, "checkpoint.npz")

  # The pool is used to perform misc operations such as logging in async way.
  pool = multiprocessing.pool.ThreadPool()

  # This seed makes the Jax part of things (like model init) deterministic.
  # However, full training still won't be deterministic, for example due to the
  # tf.data pipeline not being deterministic even if we would set TF seed.
  rng = jax.random.PRNGKey(config.get("seed", 0))

  def write_note(note):
    if jax.host_id() == 0:
      logging.info("NOTE: %s", note)
  write_note("Initializing...")

  # Verify settings to make sure no checkpoints are accidentally missed.
  if config.get("keep_checkpoint_steps"):
    assert config.get("checkpoint_steps"), "Specify `checkpoint_steps`."
    assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
        f"`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should be "
        f"divisible by `checkpoint_steps` ({config.checkpoint_steps}).")

  batch_size = config.batch_size
  batch_size_eval = config.get("batch_size_eval", batch_size)
  if (batch_size % jax.device_count() != 0 or
      batch_size_eval % jax.device_count() != 0):
    raise ValueError(f"Batch sizes ({batch_size} and {batch_size_eval}) must "
                     f"be divisible by device number ({jax.device_count()})")

  local_batch_size = batch_size // jax.host_count()
  local_batch_size_eval = batch_size_eval // jax.host_count()
  logging.info(
      "Global batch size %d on %d hosts results in %d local batch size. "
      "With %d dev per host (%d dev total), that's a %d per-device batch size.",
      batch_size, jax.host_count(), local_batch_size,
      jax.local_device_count(), jax.device_count(),
      local_batch_size // jax.local_device_count())

  write_note("Initializing train dataset...")
  train_ds = input_pipeline.get_data(
      dataset=config.dataset,
      split=config.train_split,
      data_dir=fillin(config.get("dataset_dir")),
      batch_size=local_batch_size,
      preprocess_fn=pp_builder.get_preprocess_fn(config.pp_train),
      shuffle_buffer_size=config.shuffle_buffer_size,
      prefetch=config.get("prefetch_to_host", 2),
      cache=False)

  # Start prefetching already.
  train_iter = u.start_input_pipeline(
      train_ds, config.get("prefetch_to_device", 1), pad=local_batch_size)
  # We always pad to local_batch_size_eval even when less would be enough in
  # order to minimize memory fragmentation.

  write_note("Initializing val dataset(s)...")
  def _get_val_split(dataset, split, pp_eval, data_dir=None):
    # We do ceil rounding such that we include the last incomplete batch.
    nval_img = input_pipeline.get_num_examples(
        dataset, split, data_dir=fillin(data_dir))
    val_steps = int(np.ceil(nval_img / batch_size_eval))
    logging.info("Running validation for %d steps for %s, %s", val_steps,
                 dataset, split)
    val_it = input_pipeline.get_data(
        dataset=dataset,
        split=split,
        data_dir=fillin(data_dir),
        batch_size=local_batch_size_eval,
        preprocess_fn=pp_builder.get_preprocess_fn(pp_eval),
        cache=config.get("val_cache", "batched"),
        repeat_after_batching=True,
        prefetch=0,  # Save memory since we cache.
        drop_remainder=False,
        shuffle_files=False)
    val_it = u.start_input_pipeline(
        val_it, config.get("prefetch_to_device", 1), pad=local_batch_size_eval)
    return (val_it, val_steps)

  if isinstance(config.val_split, str):
    val_ds = {"val": _get_val_split(config.dataset, config.val_split,
                                    config.pp_eval, config.get("dataset_dir"))}
  else:
    val_ds = {t[0]: _get_val_split(*t[1:]) for t in config.val_split}

  ntrain_img = input_pipeline.get_num_examples(
      config.dataset, config.train_split,
      data_dir=fillin(config.get("dataset_dir")))
  steps_per_epoch = ntrain_img / batch_size
  if config.get("num_epochs"):
    total_steps = int(config.num_epochs * steps_per_epoch)
    assert not config.get("total_steps"), "Set either num_epochs or total_steps"
  else:
    total_steps = config.total_steps
  logging.info(
      "Running for %d steps, that means %f epochs and %f steps per epoch",
      total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)
  mw = u.BigVisionMetricWriter(xm_xp.id, xm_wu.id, steps_per_epoch)

  write_note(f"Initializing {config.model_name} model...")
  model_mod = importlib.import_module(f"{BASEDIR}.models.{config.model_name}")
  model = model_mod.Model(
      num_classes=config.num_classes, **config.get("model", {}))

  # We want all parameters to be created in host RAM, not on any device; they
  # will be sent there later as needed. Otherwise we have already run into two
  # situations where we allocated them twice.
  @partial(jax.jit, backend="cpu")
  def init(rng):
    image_size = tuple(train_ds.element_spec["image"].shape[1:])
    dummy_input = jnp.zeros((local_batch_size,) + image_size, jnp.float32)
    params = flax.core.unfreeze(model.init(rng, dummy_input))["params"]

    # Set bias in the head to a low value, such that loss is small initially.
    params["head"]["bias"] = jnp.full_like(
        params["head"]["bias"], config.get("init_head_bias", 0))
    return params

  rng, rng_init = jax.random.split(rng)
  params_cpu = init(rng_init)

  if jax.host_id() == 0:
    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
    parameter_overview.log_parameter_overview(params_cpu)
    mw.measure("num_params", num_params)

  @partial(jax.pmap, axis_name="batch")
  def evaluation_fn(params, images, labels, mask):
    # Ignore the entries with all zero labels for evaluation.
    mask *= labels.max(axis=1)
    logits, _ = model.apply({"params": flax.core.freeze(params)}, images)

    losses = getattr(u, config.get("loss", "sigmoid_xent"))(
        logits=logits, labels=labels, reduction=False)
    loss = jax.lax.psum(losses * mask, axis_name="batch")

    top1_idx = jnp.argmax(logits, axis=1)
    # Extracts the label at the highest logit index for each image.
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct * mask, axis_name="batch")
    n = jax.lax.psum(mask, axis_name="batch")
    return ncorrect, loss, n

  # Setup function for computing representation.
  @partial(jax.pmap, axis_name="batch")
  def representation_fn(params, images, labels, mask):
    _, outputs = model.apply({"params": flax.core.freeze(params)}, images)
    representation = outputs[config.fewshot.representation_layer]
    representation = jax.lax.all_gather(representation, "batch")
    labels = jax.lax.all_gather(labels, "batch")
    mask = jax.lax.all_gather(mask, "batch")
    return representation, labels, mask

  # Load the optimizer either from our folder or from flax.
  opt_name = config.get("optim_name", "momentum_hp")
  write_note(f"Initializing {opt_name} optimizer...")
  try:
    opt_mod = importlib.import_module(f"{BASEDIR}.optims.{opt_name}")
    opt_def = opt_mod.Optimizer(**config.get("optim", {}))
  except ModuleNotFoundError:
    opt_def = getattr(flax.optim, opt_name)(**config.get("optim", {}))

  # We jit this, such that the arrays that are created are created on the same
  # device as the input is, in this case the CPU. Else they'd be on device[0].
  opt_cpu = jax.jit(opt_def.create)(params_cpu)

  @partial(jax.pmap, axis_name="batch", donate_argnums=(0,))
  def update_fn(opt, lr, images, labels, rng):
    """Update step."""
    measurements = {}

    if config.get("mixup") and config.mixup.p:
      rng, (images, labels), _ = u.mixup(rng, images, labels, **config.mixup)

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model,
                                         jax.lax.axis_index("batch"))

    def loss_fn(params, images, labels):
      logits, _ = model.apply(
          {"params": flax.core.freeze(params)}, images,
          train=True, rngs={"dropout": rng_model_local})
      return getattr(u, config.get("loss", "sigmoid_xent"))(
          logits=logits, labels=labels)

    # Implementation considerations compared and summarized at
    # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
    l, g = u.accumulate_gradient(jax.value_and_grad(loss_fn), opt.target,
                                 images, labels,
                                 config.get("grad_accum_steps"))
    l, g = jax.lax.pmean((l, g), axis_name="batch")

    # Log the gradient norm only if we need to compute it anyways (clipping)
    # or if we don't use grad_accum_steps, as they interact badly.
    if config.get("grad_accum_steps", 1) == 1 or config.get("grad_clip_norm"):
      grads, _ = jax.tree_flatten(g)
      l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
      measurements["l2_grads"] = l2_g

    # Optionally resize the global gradient to a maximum norm. We found this
    # useful in some cases across optimizers, hence it's in the main loop.
    if config.get("grad_clip_norm"):
      g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
      g = jax.tree_map(lambda p: g_factor * p, g)
    opt = opt.apply_gradient(g, learning_rate=lr)

    decay_rules = config.get("weight_decay", []) or []
    if isinstance(decay_rules, numbers.Number):
      decay_rules = [(".*kernel.*", decay_rules)]
    sched_m = lr / config.lr.base if config.get("weight_decay_decouple") else lr
    def decay_fn(v, wd):
      return (1.0 - sched_m * wd) * v
    opt = opt.replace(target=u.tree_map_with_regex(
        decay_fn, opt.target, decay_rules, name="weight decay"))

    params, _ = jax.tree_flatten(opt.target)
    measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

    return opt, l, rng, measurements

  # Other things besides optimizer state to be stored.
  checkpoint_extra = dict(accum_train_time=0.0)

  # Decide how to initialize training. The order is important.
  # 1. Always resume from the existing checkpoint, e.g. resume a finetune job.
  # 2. Resume from a previous checkpoint, e.g. start a cooldown training job.
  # 3. Initialize model from something, e.g. start a fine-tuning job.
  # 4. Train from scratch.
  resume_checkpoint_path = None
  if save_checkpoint_path and tf.io.gfile.exists(save_checkpoint_path):
    resume_checkpoint_path = save_checkpoint_path
  elif config.get("resume"):
    resume_checkpoint_path = fillin(config.resume)
  if resume_checkpoint_path:
    write_note("Resume training from checkpoint...")
    checkpoint = {"opt": opt_cpu, "extra": checkpoint_extra}
    _, checkpoint_tree = jax.tree_flatten(checkpoint)
    loaded = u.load_checkpoint(checkpoint_tree, resume_checkpoint_path)
    # bfloat16 type gets lost when data is saved to disk, so we recover it.
    checkpoint = jax.tree_map(u.recover_dtype, loaded)
    opt_cpu, checkpoint_extra = checkpoint["opt"], checkpoint["extra"]
  elif config.get("model_init"):
    write_note(f"Initialize model from {config.model_init}...")
    loaded = model_mod.load(params_cpu, config.model_init, config.get("model"))
    opt_cpu = opt_cpu.replace(target=loaded)
    if jax.host_id() == 0:
      logging.info("Restored parameter overview:")
      parameter_overview.log_parameter_overview(loaded)

  write_note("Kicking off misc stuff...")
  first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
  chrono = u.Chrono(first_step, total_steps, batch_size,
                    checkpoint_extra["accum_train_time"])
  # Note: switch to ProfileAllHosts() if you need to profile all hosts.
  # (Xprof data become much larger and take longer to load for analysis)
  profiler = periodic_actions.Profile(
      # Create profile after every restart to analyze pre-emption related
      # problems and assure we get similar performance in every run.
      logdir=workdir, first_profile=first_step + 10)

  # Prepare the learning-rate and pre-fetch it to device to avoid delays.
  lr_fn = u.create_learning_rate_schedule(
      batch_size, total_steps, steps_per_epoch, **config.get("lr", {}))
  lr_iter = u.prefetch_scalar(map(lr_fn, range(first_step, total_steps)),
                              config.get("prefetch_to_device", 1))

  write_note(f"Replicating...\n{chrono.note}")
  opt_repl = flax_utils.replicate(opt_cpu)

  write_note(f"Initializing few-shotters...\n{chrono.note}")
  if "fewshot" in config:
    fewshotter = fewshot.FewShotEvaluator(
        representation_fn, config.fewshot,
        config.fewshot.get("batch_size") or batch_size_eval)

  rng, rng_loop = jax.random.split(rng, 2)
  rngs_loop = flax_utils.replicate(rng_loop)
  checkpoint_writer = None

  write_note(f"First step compilations...\n{chrono.note}")
  # Using a python integer for step here, because opt.state.step is allocated
  # on TPU during replication.
  for step, train_batch, lr_repl in zip(
      range(first_step + 1, total_steps + 1), train_iter, lr_iter):
    mw.step_start(step)

    with jax.profiler.TraceContext("train_step", step_num=step, _r=1):
      opt_repl, loss_value, rngs_loop, extra_measurements = update_fn(
          opt_repl, lr_repl,
          train_batch["image"], train_batch["labels"],
          rng=rngs_loop)

    if jax.host_id() == 0:
      profiler(step)

    # Checkpoint saving
    if u.itstime(step, config.get("checkpoint_steps"), total_steps, host=0):
      chrono.pause()
      u.checkpointing_timeout(checkpoint_writer,
                              config.get("checkpoint_timeout", 1))
      checkpoint_extra["accum_train_time"] = chrono.accum_train_time
      # We need to transfer the weights over now or else we risk keeping them
      # alive while they'll be updated in a future step, creating hard to debug
      # memory errors (see b/160593526). Also, takes device 0's params only.
      opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)

      # Check whether we want to keep a copy of the current checkpoint.
      copy_step = None
      if u.itstime(step, config.get("keep_checkpoint_steps"), total_steps):
        copy_step = step

      # Checkpoint should be a nested dictionary or FLAX dataclasses from
      # `flax.struct`.
      # Both can be present in a checkpoint.
      checkpoint = {"opt": opt_cpu, "extra": checkpoint_extra}
      checkpoint_writer = pool.apply_async(
          u.save_checkpoint, (checkpoint, save_checkpoint_path, copy_step))
      chrono.resume()

    # Report training progress
    if u.itstime(step, config.log_training_steps, total_steps, host=0):
      mw.measure("learning_rate", lr_repl[0])
      mw.measure("training_loss", loss_value[0])
      for name, value in extra_measurements.items():
        mw.measure(name, value[0])
      chrono.tick(step, mw.measure, write_note)

    # Report validation performance
    if u.itstime(step, config.log_eval_steps, total_steps):
      chrono.pause()
      for val_name, (val_iter, val_steps) in val_ds.items():
        ncorrect, loss, nseen = 0, 0, 0
        for _, batch in zip(range(val_steps), val_iter):
          batch_ncorrect, batch_losses, batch_n = evaluation_fn(
              opt_repl.target, batch["image"], batch["labels"], batch["mask"])
          # All results are a replicated array shaped as follows:
          # (local_devices, per_device_batch_size, elem_shape...)
          # with each local device's entry being identical as they got psum'd.
          # So let's just take the first one to the host as numpy.
          ncorrect += np.sum(np.array(batch_ncorrect[0]))
          loss += np.sum(np.array(batch_losses[0]))
          nseen += np.sum(np.array(batch_n[0]))
        mw.measure(f"{val_name}_prec@1", ncorrect / nseen)
        mw.measure(f"{val_name}_loss", loss / nseen)
      chrono.resume()

    if "fewshot" in config:
      # Compute few-shot on-the-fly evaluation.
      if u.itstime(step, config.fewshot.log_steps, total_steps):
        chrono.pause()
        write_note(f"Few-shot evaluation...\n{chrono.note}")
        r = fewshotter.run_all(opt_repl.target, config.fewshot.datasets)
        fewshotter.walk_results(mw.measure, *r)
        chrono.resume()

    mw.step_end()

  write_note(f"Done!\n{chrono.note}")
  pool.close()
  pool.join()
  mw.close()
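# Illustrative sketch (not part of the trainer): how the masked top-1 counting in
# evaluation_fn above behaves on a toy batch. Labels are multi-hot; padded
# examples have all-zero labels, get mask 0, and do not contribute to the counts.
import jax.numpy as jnp

example_labels = jnp.array([[0., 1., 0.],    # true class 1
                            [1., 0., 0.],    # true class 0
                            [0., 0., 0.]])   # padded example (all-zero labels)
example_logits = jnp.array([[0.1, 2.0, -1.0],
                            [0.3, 1.5, 0.2],
                            [9.0, 0.0, 0.0]])
example_mask = jnp.ones(3) * example_labels.max(axis=1)  # -> [1., 1., 0.]
top1_idx = jnp.argmax(example_logits, axis=1)            # -> [1, 1, 0]
# Extract the label value at the highest-logit index for each image.
top1_correct = jnp.take_along_axis(
    example_labels, top1_idx[:, None], axis=1)[:, 0]     # -> [1., 0., 0.]
ncorrect = jnp.sum(top1_correct * example_mask)          # -> 1.0
n = jnp.sum(example_mask)                                # -> 2.0 examples counted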
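# Illustrative sketch (not part of the trainer): how a scalar `weight_decay`
# config value expands into (regex, value) rules and how they are applied to
# matching parameters, approximating what u.tree_map_with_regex is assumed to do
# inside update_fn above. `apply_decay_rules` is a hypothetical helper written
# only for this example.
import re

import jax.numpy as jnp
from flax import traverse_util


def apply_decay_rules(params, decay_rules, sched):
  """Scales every parameter whose '/'-joined path matches a rule's regex."""
  flat = traverse_util.flatten_dict(params)
  out = {}
  for path, value in flat.items():
    name = "/".join(path)
    match = next(
        (wd for pattern, wd in decay_rules if re.fullmatch(pattern, name)),
        None)
    out[path] = (1.0 - sched * match) * value if match is not None else value
  return traverse_util.unflatten_dict(out)


example_params = {"head": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros((2,))}}
example_rules = [(".*kernel.*", 1e-4)]  # A bare number in the config becomes this rule.
decayed = apply_decay_rules(example_params, example_rules, sched=0.1)
# Only head/kernel is scaled by (1 - 0.1 * 1e-4); head/bias is left untouched.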