def train_and_evaluate(config, workdir): """Runs a training and evaluation loop. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. """ if config.dataset.batch_size % jax.device_count() != 0: raise ValueError( "Batch size must be divisible by the number of devices.") tf.io.gfile.makedirs(workdir) # Deterministic training. rng = jax.random.PRNGKey(config.seed) # Shift the numpy random seed by process_index() to shuffle data loaded # by different hosts np.random.seed(20201473 + jax.process_index()) #---------------------------------------------------------------------------- # Build input pipeline. rng, data_rng = jax.random.split(rng) data_rng = jax.random.fold_in(data_rng, jax.process_index()) config.dataset.data_dir = os.path.join(config.dataset.base_dir, config.dataset.scene) train_ds, eval_ds = datasets.create_dataset(config) example_batch = train_ds.peek() #---------------------------------------------------------------------------- # Learning rate schedule. num_train_steps = config.train.max_steps if num_train_steps == -1: num_train_steps = train_ds.size() steps_per_epoch = num_train_steps // config.train.num_epochs logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps, steps_per_epoch) learning_rate_fn = train_utils.create_learning_rate_fn(config) #---------------------------------------------------------------------------- # Initialize model. rng, model_rng = jax.random.split(rng) model, state = models.create_train_state( config, model_rng, learning_rate_fn=learning_rate_fn, example_batch=example_batch, ) #---------------------------------------------------------------------------- # Set up checkpointing of the model and the input pipeline. state = checkpoints.restore_checkpoint(workdir, state) initial_step = int(state.step) + 1 #---------------------------------------------------------------------------- # Distribute training. state = flax_utils.replicate(state) p_train_step = jax.pmap( functools.partial( train_step, model=model, learning_rate_fn=learning_rate_fn, weight_decay=config.train.weight_decay, config=config, ), axis_name="batch", ) # Get distributed rendering function render_pfn = render_utils.get_render_function( model=model, config=config, randomized=False, # No randomization for evaluation. ) #---------------------------------------------------------------------------- # Prepare Metric Writers writer = metric_writers.create_default_writer( workdir, just_logging=jax.process_index() > 0) if initial_step == 1: writer.write_hparams(dict(config)) logging.info("Starting training loop at step %d.", initial_step) hooks = [] report_progress = periodic_actions.ReportProgress( num_train_steps=num_train_steps, writer=writer) if jax.process_index() == 0: hooks += [ report_progress, ] train_metrics = None # Prefetch_buffer_size = 6 x batch_size ptrain_ds = flax.jax_utils.prefetch_to_device(train_ds, 6) n_local_devices = jax.local_device_count() rng = rng + jax.process_index() # Make random seed separate across hosts. keys = jax.random.split(rng, n_local_devices) # For pmapping RNG keys. with metric_writers.ensure_flushes(writer): for step in range(initial_step, num_train_steps + 1): # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU # devices. 
is_last_step = step == num_train_steps with jax.profiler.StepTraceAnnotation("train", step_num=step): batch = next(ptrain_ds) state, metrics_update, keys = p_train_step(rng=keys, state=state, batch=batch) metric_update = flax_utils.unreplicate(metrics_update) train_metrics = (metric_update if train_metrics is None else train_metrics.merge(metric_update)) # Quick indication that training is happening. logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step) for h in hooks: h(step) if step % config.train.log_loss_every_steps == 0 or is_last_step: writer.write_scalars(step, train_metrics.compute()) train_metrics = None if step % config.train.render_every_steps == 0 or is_last_step: test_batch = next(eval_ds) test_pixels = model_utils.uint2float( test_batch.target_view.rgb) # extract for evaluation with report_progress.timed("eval"): pred_color, pred_disp, pred_acc = eval_step( state, keys[0], test_batch, render_pfn, config) #------------------------------------------------------------------ # Log metrics and images for host 0 #------------------------------------------------------------------ if jax.process_index() == 0: psnr = model_utils.compute_psnr( ((pred_color - test_pixels)**2).mean()) ssim = skmetrics.structural_similarity( pred_color.astype(np.float32), test_pixels.astype(np.float32), win_size=11, multichannel=True, gaussian_weight=True) writer.write_scalars( step, { "train_eval/test_psnr": psnr, "train_eval/test_ssim": ssim, }) writer.write_images( step, { "test_pred_color": pred_color[None, :], "test_target": test_pixels[None, :] }) if pred_disp is not None: writer.write_images( step, {"test_pred_disp": pred_disp[None, :]}) if pred_acc is not None: writer.write_images( step, {"test_pred_acc": pred_acc[None, :]}) #------------------------------------------------------------------ if (jax.process_index() == 0) and (step % config.train.checkpoint_every_steps == 0 or is_last_step): # Write final metrics to file with file_utils.open_file( os.path.join(workdir, "train_logs.json"), "w") as f: log_dict = metric_update.compute() for k, v in log_dict.items(): log_dict[k] = v.item() f.write(json.dumps(log_dict)) with report_progress.timed("checkpoint"): state_to_save = jax.device_get( jax.tree_map(lambda x: x[0], state)) checkpoints.save_checkpoint(workdir, state_to_save, step, keep=100) logging.info("Finishing training at step %d", num_train_steps)
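

# ----------------------------------------------------------------------------
# Illustrative sketch, not part of the original script: the PSNR logged in the
# eval block above is assumed to follow the standard -10 * log10(MSE)
# definition for images scaled to [0, 1]; `compute_psnr_from_mse` below is a
# hypothetical stand-in for `model_utils.compute_psnr`.
# ----------------------------------------------------------------------------
import numpy as np


def compute_psnr_from_mse(mse):
  """Peak signal-to-noise ratio (dB) for images in [0, 1], given their MSE."""
  return float(-10.0 * np.log10(mse))


# Example: a mean squared error of 0.01 corresponds to a PSNR of 20 dB.
# compute_psnr_from_mse(0.01) -> 20.0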
def main(config, output_dir): seed = config.get('seed', 0) rng = jax.random.PRNGKey(seed) tf.random.set_seed(seed) if config.get('data_dir'): logging.info('data_dir=%s', config.data_dir) logging.info('Output dir: %s', output_dir) save_checkpoint_path = None if config.get('checkpoint_steps'): gfile.makedirs(output_dir) save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz') # Create an asynchronous multi-metric writer. writer = metric_writers.create_default_writer( output_dir, just_logging=jax.process_index() > 0) # The pool is used to perform misc operations such as logging in async way. pool = multiprocessing.pool.ThreadPool() def write_note(note): if jax.host_id() == 0: logging.info('NOTE: %s', note) write_note('Initializing...') # Verify settings to make sure no checkpoints are accidentally missed. if config.get('keep_checkpoint_steps'): assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.' assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, ( f'`keep_checkpoint_steps` ({config.checkpoint_steps}) should be' f'divisible by `checkpoint_steps ({config.checkpoint_steps}).`') batch_size = config.batch_size batch_size_eval = config.get('batch_size_eval', batch_size) if (batch_size % jax.device_count() != 0 or batch_size_eval % jax.device_count() != 0): raise ValueError( f'Batch sizes ({batch_size} and {batch_size_eval}) must ' f'be divisible by device number ({jax.device_count()})') local_batch_size = batch_size // jax.host_count() local_batch_size_eval = batch_size_eval // jax.host_count() logging.info( 'Global batch size %d on %d hosts results in %d local batch size. ' 'With %d devices per host (%d devices total), that\'s a %d per-device ' 'batch size.', batch_size, jax.host_count(), local_batch_size, jax.local_device_count(), jax.device_count(), local_batch_size // jax.local_device_count()) write_note('Initializing train dataset...') rng, train_ds_rng = jax.random.split(rng) train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index()) train_ds = input_utils.get_data( dataset=config.dataset, split=config.train_split, rng=train_ds_rng, process_batch_size=local_batch_size, preprocess_fn=preprocess_spec.parse( spec=config.pp_train, available_ops=preprocess_utils.all_ops()), shuffle_buffer_size=config.shuffle_buffer_size, prefetch_size=config.get('prefetch_to_host', 2), data_dir=config.get('data_dir')) # Start prefetching already. train_iter = input_utils.start_input_pipeline( train_ds, config.get('prefetch_to_device', 1)) write_note('Initializing val dataset(s)...') def _get_val_split(dataset, split, pp_eval, data_dir=None): # We do ceil rounding such that we include the last incomplete batch. 
nval_img = input_utils.get_num_examples( dataset, split=split, process_batch_size=local_batch_size_eval, drop_remainder=False, data_dir=data_dir) val_steps = int(np.ceil(nval_img / batch_size_eval)) logging.info('Running validation for %d steps for %s, %s', val_steps, dataset, split) if isinstance(pp_eval, str): pp_eval = preprocess_spec.parse( spec=pp_eval, available_ops=preprocess_utils.all_ops()) val_ds = input_utils.get_data(dataset=dataset, split=split, rng=None, process_batch_size=local_batch_size_eval, preprocess_fn=pp_eval, cache=config.get('val_cache', 'batched'), num_epochs=1, repeat_after_batching=True, shuffle=False, prefetch_size=config.get( 'prefetch_to_host', 2), drop_remainder=False, data_dir=data_dir) return val_ds val_ds_splits = { 'val': _get_val_split(config.dataset, split=config.val_split, pp_eval=config.pp_eval, data_dir=config.get('data_dir')) } if config.get('test_split'): val_ds_splits.update({ 'test': _get_val_split(config.dataset, split=config.test_split, pp_eval=config.pp_eval, data_dir=config.get('data_dir')) }) if config.get('eval_on_cifar_10h'): cifar10_to_cifar10h_fn = data_uncertainty_utils.create_cifar10_to_cifar10h_fn( config.get('data_dir', None)) preprocess_fn = preprocess_spec.parse( spec=config.pp_eval_cifar_10h, available_ops=preprocess_utils.all_ops()) pp_eval = lambda ex: preprocess_fn(cifar10_to_cifar10h_fn(ex)) val_ds_splits['cifar_10h'] = _get_val_split( 'cifar10', split=config.get('cifar_10h_split') or 'test', pp_eval=pp_eval, data_dir=config.get('data_dir')) elif config.get('eval_on_imagenet_real'): imagenet_to_real_fn = data_uncertainty_utils.create_imagenet_to_real_fn( ) preprocess_fn = preprocess_spec.parse( spec=config.pp_eval_imagenet_real, available_ops=preprocess_utils.all_ops()) pp_eval = lambda ex: preprocess_fn(imagenet_to_real_fn(ex)) val_ds_splits['imagenet_real'] = _get_val_split( 'imagenet2012_real', split=config.get('imagenet_real_split') or 'validation', pp_eval=pp_eval, data_dir=config.get('data_dir')) ood_ds = {} if config.get('ood_datasets') and config.get('ood_methods'): if config.get( 'ood_methods'): # config.ood_methods is not a empty list logging.info('loading OOD dataset = %s', config.get('ood_datasets')) ood_ds, ood_ds_names = ood_utils.load_ood_datasets( config.dataset, config.ood_datasets, config.ood_split, config.pp_eval, config.pp_eval_ood, config.ood_methods, config.train_split, config.get('data_dir'), _get_val_split, ) ntrain_img = input_utils.get_num_examples( config.dataset, split=config.train_split, process_batch_size=local_batch_size, data_dir=config.get('data_dir')) steps_per_epoch = int(ntrain_img / batch_size) if config.get('num_epochs'): total_steps = int(config.num_epochs * steps_per_epoch) assert not config.get( 'total_steps'), 'Set either num_epochs or total_steps' else: total_steps = config.total_steps logging.info('Total train data points: %d', ntrain_img) logging.info( 'Running for %d steps, that means %f epochs and %d steps per epoch', total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch) write_note('Initializing model...') logging.info('config.model = %s', config.get('model')) model = ub.models.bit_resnet(num_classes=config.num_classes, **config.get('model', {})) # We want all parameters to be created in host RAM, not on any device, they'll # be sent there later as needed, otherwise we already encountered two # situations where we allocate them twice. 
@partial(jax.jit, backend='cpu') def init(rng): image_size = tuple(train_ds.element_spec['image'].shape[2:]) logging.info('image_size = %s', image_size) dummy_input = jnp.zeros((local_batch_size, ) + image_size, jnp.float32) params = flax.core.unfreeze(model.init(rng, dummy_input, train=False))['params'] # Set bias in the head to a low value, such that loss is small initially. params['head']['bias'] = jnp.full_like(params['head']['bias'], config.get('init_head_bias', 0)) # init head kernel to all zeros for fine-tuning if config.get('model_init'): params['head']['kernel'] = jnp.full_like(params['head']['kernel'], 0) return params rng, rng_init = jax.random.split(rng) params_cpu = init(rng_init) if jax.host_id() == 0: num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0]) parameter_overview.log_parameter_overview(params_cpu) writer.write_scalars(step=0, scalars={'num_params': num_params}) @partial(jax.pmap, axis_name='batch') def evaluation_fn(params, images, labels, mask): # Ignore the entries with all zero labels for evaluation. mask *= labels.max(axis=1) logits, out = model.apply({'params': flax.core.freeze(params)}, images, train=False) losses = getattr(train_utils, config.get('loss', 'sigmoid_xent'))(logits=logits, labels=labels, reduction=False) loss = jax.lax.psum(losses * mask, axis_name='batch') top1_idx = jnp.argmax(logits, axis=1) # Extracts the label at the highest logit index for each image. top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0] ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch') n = jax.lax.psum(mask, axis_name='batch') metric_args = jax.lax.all_gather( [logits, labels, out['pre_logits'], mask], axis_name='batch') return ncorrect, loss, n, metric_args @partial(jax.pmap, axis_name='batch') def cifar_10h_evaluation_fn(params, images, labels, mask): logits, out = model.apply({'params': flax.core.freeze(params)}, images, train=False) losses = getattr(train_utils, config.get('loss', 'softmax_xent'))(logits=logits, labels=labels, reduction=False) loss = jax.lax.psum(losses, axis_name='batch') top1_idx = jnp.argmax(logits, axis=1) # Extracts the label at the highest logit index for each image. one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)] top1_correct = jnp.take_along_axis(one_hot_labels, top1_idx[:, None], axis=1)[:, 0] ncorrect = jax.lax.psum(top1_correct, axis_name='batch') n = jax.lax.psum(one_hot_labels, axis_name='batch') metric_args = jax.lax.all_gather( [logits, labels, out['pre_logits'], mask], axis_name='batch') return ncorrect, loss, n, metric_args # Setup function for computing representation. @partial(jax.pmap, axis_name='batch') def representation_fn(params, images, labels, mask): _, outputs = model.apply({'params': flax.core.freeze(params)}, images, train=False) representation = outputs[config.fewshot.representation_layer] representation = jax.lax.all_gather(representation, 'batch') labels = jax.lax.all_gather(labels, 'batch') mask = jax.lax.all_gather(mask, 'batch') return representation, labels, mask # Load the optimizer from flax. opt_name = config.get('optim_name') write_note(f'Initializing {opt_name} optimizer...') opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {})) # We jit this, such that the arrays that are created are created on the same # device as the input is, in this case the CPU. Else they'd be on device[0]. 
opt_cpu = jax.jit(opt_def.create)(params_cpu) @partial(jax.pmap, axis_name='batch', donate_argnums=(0, )) def update_fn(opt, lr, images, labels, rng): """Update step.""" measurements = {} # Get device-specific loss rng. rng, rng_model = jax.random.split(rng, 2) rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index('batch')) def loss_fn(params, images, labels): logits, _ = model.apply({'params': flax.core.freeze(params)}, images, train=True, rngs={'dropout': rng_model_local}) accuracy = jnp.mean( jnp.equal(jnp.argmax(logits, axis=-1), jnp.argmax(labels, axis=-1))) return getattr(train_utils, config.get('loss', 'sigmoid_xent'))(logits=logits, labels=labels), accuracy grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (l, train_accuracy), g = grad_fn(opt.target, images, labels) l, g = jax.lax.pmean((l, g), axis_name='batch') measurements['accuracy'] = train_accuracy # Log the gradient norm only if we need to compute it anyways (clipping) # or if we don't use grad_accum_steps, as they interact badly. if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'): grads, _ = jax.tree_flatten(g) l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads])) measurements['l2_grads'] = l2_g # Optionally resize the global gradient to a maximum norm. We found this # useful in some cases across optimizers, hence it's in the main loop. if config.get('grad_clip_norm'): g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g) g = jax.tree_util.tree_map(lambda p: g_factor * p, g) opt = opt.apply_gradient(g, learning_rate=lr) decay_rules = config.get('weight_decay', []) or [] if isinstance(decay_rules, numbers.Number): decay_rules = [('.*kernel.*', decay_rules)] sched_m = lr / config.lr.base if config.get( 'weight_decay_decouple') else lr def decay_fn(v, wd): return (1.0 - sched_m * wd) * v opt = opt.replace(target=train_utils.tree_map_with_regex( decay_fn, opt.target, decay_rules)) params, _ = jax.tree_flatten(opt.target) measurements['l2_params'] = jnp.sqrt( sum([jnp.vdot(p, p) for p in params])) return opt, l, rng, measurements # Other things besides optimizer state to be stored. rng, rng_loop = jax.random.split(rng, 2) rngs_loop = flax_utils.replicate(rng_loop) checkpoint_extra = dict(accum_train_time=0.0, rngs_loop=rngs_loop) # Decide how to initialize training. The order is important. # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job. # 2. Resume from a previous checkpoint, e.g. start a cooldown training job. # 3. Initialize model from something, e,g, start a fine-tuning job. # 4. Train from scratch. resume_checkpoint_path = None if save_checkpoint_path and gfile.exists(save_checkpoint_path): resume_checkpoint_path = save_checkpoint_path elif config.get('resume'): resume_checkpoint_path = config.resume if resume_checkpoint_path: write_note('Resume training from checkpoint...') checkpoint_tree = {'opt': opt_cpu, 'extra': checkpoint_extra} checkpoint = checkpoint_utils.load_checkpoint(checkpoint_tree, resume_checkpoint_path) opt_cpu, checkpoint_extra = checkpoint['opt'], checkpoint['extra'] rngs_loop = checkpoint_extra['rngs_loop'] elif config.get('model_init'): write_note(f'Initialize model from {config.model_init}...') reinit_params = config.get('model_reinit_params', ('head/kernel', 'head/bias')) logging.info('Reinitializing these parameters: %s', reinit_params) # We only support "no head" fine-tuning for now. 
loaded_params = checkpoint_utils.load_checkpoint( tree=None, path=config.model_init) loaded = checkpoint_utils.restore_from_pretrained_params( params_cpu, loaded_params, model_representation_size=None, model_classifier=None, reinit_params=reinit_params) opt_cpu = opt_cpu.replace(target=loaded) if jax.host_id() == 0: logging.info('Restored parameter overview:') parameter_overview.log_parameter_overview(loaded) write_note('Kicking off misc stuff...') first_step = int(opt_cpu.state.step) # Might be a DeviceArray type. if first_step == 0 and jax.host_id() == 0: writer.write_hparams(dict(config)) chrono = train_utils.Chrono(first_step, total_steps, batch_size, checkpoint_extra['accum_train_time']) # Note: switch to ProfileAllHosts() if you need to profile all hosts. # (Xprof data become much larger and take longer to load for analysis) profiler = periodic_actions.Profile( # Create profile after every restart to analyze pre-emption related # problems and assure we get similar performance in every run. logdir=output_dir, first_profile=first_step + 10) # Prepare the learning-rate and pre-fetch it to device to avoid delays. lr_fn = train_utils.create_learning_rate_schedule(total_steps, **config.get('lr', {})) # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be # necessary for TPUs. lr_iter = train_utils.prefetch_scalar(map(lr_fn, range(total_steps)), config.get('prefetch_to_device', 1)) write_note(f'Replicating...\n{chrono.note}') opt_repl = flax_utils.replicate(opt_cpu) write_note(f'Initializing few-shotters...\n{chrono.note}') fewshotter = None if 'fewshot' in config and fewshot is not None: fewshotter = fewshot.FewShotEvaluator( representation_fn, config.fewshot, config.fewshot.get('batch_size') or batch_size_eval) checkpoint_writer = None # Note: we return the train loss, val loss, and fewshot best l2s for use in # reproducibility unit tests. train_loss = -jnp.inf val_loss = {val_name: -jnp.inf for val_name, _ in val_ds_splits.items()} fewshot_results = {'dummy': {(0, 1): -jnp.inf}} write_note(f'First step compilations...\n{chrono.note}') logging.info('first_step = %s', first_step) # Advance the iterators if we are restarting from an earlier checkpoint. # TODO(dusenberrymw): Look into checkpointing dataset state instead. if first_step > 0: write_note('Advancing iterators after resuming from a checkpoint...') lr_iter = itertools.islice(lr_iter, first_step, None) train_iter = itertools.islice(train_iter, first_step, None) # Using a python integer for step here, because opt.state.step is allocated # on TPU during replication. for step, train_batch, lr_repl in zip( range(first_step + 1, total_steps + 1), train_iter, lr_iter): with jax.profiler.TraceContext('train_step', step_num=step, _r=1): opt_repl, loss_value, rngs_loop, extra_measurements = update_fn( opt_repl, lr_repl, train_batch['image'], train_batch['labels'], rng=rngs_loop) if jax.host_id() == 0: profiler(step) # Checkpoint saving if train_utils.itstime(step, config.get('checkpoint_steps'), total_steps, process=0): write_note('Checkpointing...') chrono.pause() train_utils.checkpointing_timeout( checkpoint_writer, config.get('checkpoint_timeout', 1)) checkpoint_extra['accum_train_time'] = chrono.accum_train_time checkpoint_extra['rngs_loop'] = rngs_loop # We need to transfer the weights over now or else we risk keeping them # alive while they'll be updated in a future step, creating hard to debug # memory errors (see b/160593526). Also, takes device 0's params only. 
opt_cpu = jax.tree_util.tree_map(lambda x: np.array(x[0]), opt_repl) # Check whether we want to keep a copy of the current checkpoint. copy_step = None if train_utils.itstime(step, config.get('keep_checkpoint_steps'), total_steps): write_note('Keeping a checkpoint copy...') copy_step = step # Checkpoint should be a nested dictionary or FLAX datataclasses from # `flax.struct`. Both can be present in a checkpoint. checkpoint = {'opt': opt_cpu, 'extra': checkpoint_extra} checkpoint_writer = pool.apply_async( checkpoint_utils.save_checkpoint, (checkpoint, save_checkpoint_path, copy_step)) chrono.resume() # Report training progress if train_utils.itstime(step, config.log_training_steps, total_steps, process=0): write_note('Reporting training progress...') train_accuracy = extra_measurements['accuracy'] train_accuracy = jnp.mean(train_accuracy) train_loss = loss_value[ 0] # Keep to return for reproducibility tests. timing_measurements, note = chrono.tick(step) write_note(note) train_measurements = {} train_measurements.update({ 'learning_rate': lr_repl[0], 'training_loss': train_loss, 'training_accuracy': train_accuracy, }) train_measurements.update( flax.jax_utils.unreplicate(extra_measurements)) train_measurements.update(timing_measurements) writer.write_scalars(step, train_measurements) # Report validation performance if train_utils.itstime(step, config.log_eval_steps, total_steps): write_note('Evaluating on the validation set...') chrono.pause() for val_name, val_ds in val_ds_splits.items(): # Sets up evaluation metrics. ece_num_bins = config.get('ece_num_bins', 15) auc_num_bins = config.get('auc_num_bins', 1000) ece = rm.metrics.ExpectedCalibrationError( num_bins=ece_num_bins) calib_auc = rm.metrics.CalibrationAUC( correct_pred_as_pos_label=False) oc_auc_0_5 = rm.metrics.OracleCollaborativeAUC( oracle_fraction=0.005, num_bins=auc_num_bins) oc_auc_1 = rm.metrics.OracleCollaborativeAUC( oracle_fraction=0.01, num_bins=auc_num_bins) oc_auc_2 = rm.metrics.OracleCollaborativeAUC( oracle_fraction=0.02, num_bins=auc_num_bins) oc_auc_5 = rm.metrics.OracleCollaborativeAUC( oracle_fraction=0.05, num_bins=auc_num_bins) label_diversity = tf.keras.metrics.Mean() sample_diversity = tf.keras.metrics.Mean() ged = tf.keras.metrics.Mean() # Runs evaluation loop. val_iter = input_utils.start_input_pipeline( val_ds, config.get('prefetch_to_device', 1)) ncorrect, loss, nseen = 0, 0, 0 for batch in val_iter: if val_name == 'cifar_10h': batch_ncorrect, batch_losses, batch_n, batch_metric_args = ( cifar_10h_evaluation_fn(opt_repl.target, batch['image'], batch['labels'], batch['mask'])) else: batch_ncorrect, batch_losses, batch_n, batch_metric_args = ( evaluation_fn(opt_repl.target, batch['image'], batch['labels'], batch['mask'])) # All results are a replicated array shaped as follows: # (local_devices, per_device_batch_size, elem_shape...) # with each local device's entry being identical as they got psum'd. # So let's just take the first one to the host as numpy. ncorrect += np.sum(np.array(batch_ncorrect[0])) loss += np.sum(np.array(batch_losses[0])) nseen += np.sum(np.array(batch_n[0])) if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent': # Here we parse batch_metric_args to compute uncertainty metrics. # (e.g., ECE or Calibration AUC). logits, labels, _, masks = batch_metric_args masks = np.array(masks[0], dtype=np.bool) logits = np.array(logits[0]) probs = jax.nn.softmax(logits) # From one-hot to integer labels, as required by ECE. 
int_labels = np.argmax(np.array(labels[0]), axis=-1) int_preds = np.argmax(logits, axis=-1) confidence = np.max(probs, axis=-1) for p, c, l, d, m, label in zip( probs, confidence, int_labels, int_preds, masks, labels[0]): ece.add_batch(p[m, :], label=l[m]) calib_auc.add_batch(d[m], label=l[m], confidence=c[m]) # TODO(jereliu): Extend to support soft multi-class probabilities. oc_auc_0_5.add_batch(d[m], label=l[m], custom_binning_score=c[m]) oc_auc_1.add_batch(d[m], label=l[m], custom_binning_score=c[m]) oc_auc_2.add_batch(d[m], label=l[m], custom_binning_score=c[m]) oc_auc_5.add_batch(d[m], label=l[m], custom_binning_score=c[m]) if val_name == 'cifar_10h' or val_name == 'imagenet_real': batch_label_diversity, batch_sample_diversity, batch_ged = data_uncertainty_utils.generalized_energy_distance( label[m], p[m, :], config.num_classes) label_diversity.update_state( batch_label_diversity) sample_diversity.update_state( batch_sample_diversity) ged.update_state(batch_ged) val_loss[ val_name] = loss / nseen # Keep for reproducibility tests. val_measurements = { f'{val_name}_prec@1': ncorrect / nseen, f'{val_name}_loss': val_loss[val_name], } if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent': val_measurements[f'{val_name}_ece'] = ece.result()['ece'] val_measurements[ f'{val_name}_calib_auc'] = calib_auc.result( )['calibration_auc'] val_measurements[ f'{val_name}_oc_auc_0.5%'] = oc_auc_0_5.result( )['collaborative_auc'] val_measurements[ f'{val_name}_oc_auc_1%'] = oc_auc_1.result( )['collaborative_auc'] val_measurements[ f'{val_name}_oc_auc_2%'] = oc_auc_2.result( )['collaborative_auc'] val_measurements[ f'{val_name}_oc_auc_5%'] = oc_auc_5.result( )['collaborative_auc'] writer.write_scalars(step, val_measurements) if val_name == 'cifar_10h' or val_name == 'imagenet_real': cifar_10h_measurements = { f'{val_name}_label_diversity': label_diversity.result(), f'{val_name}_sample_diversity': sample_diversity.result(), f'{val_name}_ged': ged.result(), } writer.write_scalars(step, cifar_10h_measurements) # OOD eval # Entries in the ood_ds dict include: # (ind_dataset, ood_dataset1, ood_dataset2, ...). # OOD metrics are computed using ind_dataset paired with each of the # ood_dataset. When Mahalanobis distance method is applied, train_ind_ds # is also included in the ood_ds. if ood_ds and config.ood_methods: ood_measurements = ood_utils.eval_ood_metrics( ood_ds, ood_ds_names, config.ood_methods, evaluation_fn, opt_repl.target, n_prefetch=config.get('prefetch_to_device', 1)) writer.write_scalars(step, ood_measurements) chrono.resume() if 'fewshot' in config and fewshotter is not None: # Compute few-shot on-the-fly evaluation. if train_utils.itstime(step, config.fewshot.log_steps, total_steps): chrono.pause() write_note(f'Few-shot evaluation...\n{chrono.note}') # Keep `results` to return for reproducibility tests. fewshot_results, best_l2 = fewshotter.run_all( opt_repl.target, config.fewshot.datasets) # TODO(dusenberrymw): Remove this once fewshot.py is updated. def make_writer_measure_fn(step): def writer_measure(name, value): writer.write_scalars(step, {name: value}) return writer_measure fewshotter.walk_results(make_writer_measure_fn(step), fewshot_results, best_l2) chrono.resume() # End of step. if config.get('testing_failure_step'): # Break early to simulate infra failures in test cases. 
      if config.testing_failure_step == step:
        break

  write_note(f'Done!\n{chrono.note}')
  pool.close()
  pool.join()
  writer.close()

  # Return final training loss, validation loss, and fewshot results for
  # reproducibility test cases.
  return train_loss, val_loss, fewshot_results
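

# ----------------------------------------------------------------------------
# Illustrative sketch, not from the original code base: the decoupled weight
# decay inside `update_fn` above scales every parameter whose name matches a
# regex rule by (1 - schedule * wd). `apply_decay_rules` is a hypothetical,
# flat-dict stand-in for `train_utils.tree_map_with_regex` so the sketch stays
# self-contained.
# ----------------------------------------------------------------------------
import re

import numpy as np


def apply_decay_rules(params, decay_rules, sched_m):
  """Returns params with (1 - sched_m * wd) applied to matching entries."""
  decayed = dict(params)
  for pattern, wd in decay_rules:
    for name, value in decayed.items():
      if re.fullmatch(pattern, name):
        decayed[name] = (1.0 - sched_m * wd) * value
  return decayed


# Example: a scalar weight decay is expanded to [('.*kernel.*', wd)] as in the
# training loop above, so only kernels are decayed and biases are untouched.
# apply_decay_rules({'dense/kernel': np.ones(3), 'dense/bias': np.ones(3)},
#                   [('.*kernel.*', 0.1)], sched_m=0.5)
# -> kernel scaled by 0.95, bias unchanged.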
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): """Runs training interleaved with evaluation.""" # Setup input pipeline dataset_info = input_pipeline.get_dataset_info(config.dataset, 'train') ds_train, ds_test = input_pipeline.get_datasets(config) batch = next(iter(ds_train)) logging.info(ds_train) logging.info(ds_test) # Build VisionTransformer architecture model_cls = {'ViT': models.VisionTransformer, 'Mixer': models.MlpMixer}[config.get('model_type', 'ViT')] model = model_cls(num_classes=dataset_info['num_classes'], **config.model) def init_model(): return model.init( jax.random.PRNGKey(0), # Discard the "num_local_devices" dimension for initialization. jnp.ones(batch['image'].shape[1:], batch['image'].dtype.name), train=False) # Use JIT to make sure params reside in CPU memory. variables = jax.jit(init_model, backend='cpu')() model_or_filename = config.get('model_or_filename') if model_or_filename: # Loading model from repo published with "How to train your ViT? Data, # Augmentation, and Regularization in Vision Transformers" paper. # https://arxiv.org/abs/2106.10270 if '-' in model_or_filename: filename = model_or_filename else: # Select best checkpoint from i21k pretraining by final upstream # validation accuracy. df = checkpoint.get_augreg_df(directory=config.pretrained_dir) sel = df.filename.apply( lambda filename: filename.split('-')[0] == model_or_filename) best = df.loc[sel].query('ds=="i21k"').sort_values('final_val').iloc[-1] filename = best.filename logging.info('Selected fillename="%s" for "%s" with final_val=%.3f', filename, model_or_filename, best.final_val) pretrained_path = os.path.join(config.pretrained_dir, f'{config.model.model_name}.npz') else: # ViT / Mixer papers filename = config.model.model_name pretrained_path = os.path.join(config.pretrained_dir, f'{filename}.npz') if not tf.io.gfile.exists(pretrained_path): raise ValueError( f'Could not find "{pretrained_path}" - you can download models from ' '"gs://vit_models/imagenet21k" or directly set ' '--config.pretrained_dir="gs://vit_models/imagenet21k".') params = checkpoint.load_pretrained( pretrained_path=pretrained_path, init_params=variables['params'], model_config=config.model) total_steps = config.total_steps lr_fn = utils.create_learning_rate_schedule(total_steps, config.base_lr, config.decay_type, config.warmup_steps) update_fn_repl = make_update_fn( apply_fn=model.apply, accum_steps=config.accum_steps, lr_fn=lr_fn) infer_fn_repl = jax.pmap(functools.partial(model.apply, train=False)) # Create optimizer and replicate it over all TPUs/GPUs opt = momentum_clip.Optimizer( dtype=config.optim_dtype, grad_norm_clip=config.grad_norm_clip).create(params) initial_step = 1 opt, initial_step = flax_checkpoints.restore_checkpoint( workdir, (opt, initial_step)) logging.info('Will start/continue training at initial_step=%d', initial_step) opt_repl = flax.jax_utils.replicate(opt) # Delete references to the objects that are not needed anymore del opt del params # Prepare the learning-rate and pre-fetch it to device to avoid delays. update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0)) # Setup metric writer & hooks. 
  writer = metric_writers.create_default_writer(workdir, asynchronous=False)
  writer.write_hparams(config.to_dict())
  hooks = [
      periodic_actions.Profile(logdir=workdir),
      periodic_actions.ReportProgress(
          num_train_steps=total_steps, writer=writer),
  ]

  # Run training loop
  logging.info('Starting training loop; initial compile can take a while...')
  t0 = lt0 = time.time()
  lstep = initial_step
  for step, batch in zip(
      range(initial_step, total_steps + 1),
      input_pipeline.prefetch(ds_train, config.prefetch)):

    with jax.profiler.StepTraceAnnotation('train', step_num=step):
      opt_repl, loss_repl, update_rng_repl = update_fn_repl(
          opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)

    for hook in hooks:
      hook(step)

    if step == initial_step:
      logging.info('First step took %.1f seconds.', time.time() - t0)
      t0 = time.time()
      lt0, lstep = time.time(), step

    # Report training metrics
    if config.progress_every and step % config.progress_every == 0:
      img_sec_core_train = (config.batch * (step - lstep) /
                            (time.time() - lt0)) / jax.device_count()
      lt0, lstep = time.time(), step
      writer.write_scalars(
          step,
          dict(
              train_loss=float(flax.jax_utils.unreplicate(loss_repl)),
              img_sec_core_train=img_sec_core_train))
      done = step / total_steps
      logging.info(f'Step: {step}/{total_steps} {100*done:.1f}%, '  # pylint: disable=logging-format-interpolation
                   f'img/sec/core: {img_sec_core_train:.1f}, '
                   f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h')

    # Run evaluation
    if ((config.eval_every and step % config.eval_every == 0) or
        (step == total_steps)):
      accuracies = []
      lt0 = time.time()
      for test_batch in input_pipeline.prefetch(ds_test, config.prefetch):
        logits = infer_fn_repl(
            dict(params=opt_repl.target), test_batch['image'])
        accuracies.append(
            (np.argmax(logits, axis=-1) == np.argmax(test_batch['label'],
                                                     axis=-1)).mean())
      accuracy_test = np.mean(accuracies)
      img_sec_core_test = (
          config.batch_eval * ds_test.cardinality().numpy() /
          (time.time() - lt0) / jax.device_count())
      lt0 = time.time()

      lr = float(lr_fn(step))
      logging.info(f'Step: {step} '  # pylint: disable=logging-format-interpolation
                   f'Learning rate: {lr:.7f}, '
                   f'Test accuracy: {accuracy_test:0.5f}, '
                   f'img/sec/core: {img_sec_core_test:.1f}')
      writer.write_scalars(
          step,
          dict(
              accuracy_test=accuracy_test,
              lr=lr,
              img_sec_core_test=img_sec_core_test))

    # Store checkpoint.
    if ((config.checkpoint_every and step % config.checkpoint_every == 0) or
        step == total_steps):
      checkpoint_path = flax_checkpoints.save_checkpoint(
          workdir, (flax.jax_utils.unreplicate(opt_repl), step), step)
      logging.info('Stored checkpoint at step %d to "%s"', step,
                   checkpoint_path)

  return flax.jax_utils.unreplicate(opt_repl)
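

# ----------------------------------------------------------------------------
# Illustrative sketch (an assumption, not the repository's implementation): a
# linear-warmup cosine schedule of the kind
# `utils.create_learning_rate_schedule(total_steps, base_lr, 'cosine',
# warmup_steps)` is expected to produce for the loop above.
# ----------------------------------------------------------------------------
import numpy as np


def cosine_with_warmup(total_steps, base_lr, warmup_steps):
  """Linear warmup to base_lr, then cosine decay to zero."""

  def lr_fn(step):
    warmup = np.minimum(1.0, step / warmup_steps)
    progress = np.clip(
        (step - warmup_steps) / max(1, total_steps - warmup_steps), 0.0, 1.0)
    return base_lr * warmup * 0.5 * (1.0 + np.cos(np.pi * progress))

  return lr_fn


# Example: cosine_with_warmup(100, 0.03, 10) gives 0.0 at step 0, 0.03 at
# step 10 (end of warmup), 0.015 at step 55, and ~0.0 at step 100.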
def main(args):
  logdir = os.path.join(args.logdir, args.name)
  logger = logging.setup_logger(logdir)
  logger.info(args)

  logger.info(f'Available devices: {jax.devices()}')

  # Setup input pipeline
  dataset_info = input_pipeline.get_dataset_info(args.dataset, 'train')

  ds_train = input_pipeline.get_data(
      dataset=args.dataset,
      mode='train',
      repeats=None,
      mixup_alpha=args.mixup_alpha,
      batch_size=args.batch,
      shuffle_buffer=args.shuffle_buffer,
      tfds_data_dir=args.tfds_data_dir,
      tfds_manual_dir=args.tfds_manual_dir)
  batch = next(iter(ds_train))
  logger.info(ds_train)

  ds_test = input_pipeline.get_data(
      dataset=args.dataset,
      mode='test',
      repeats=1,
      batch_size=args.batch_eval,
      tfds_data_dir=args.tfds_data_dir,
      tfds_manual_dir=args.tfds_manual_dir)
  logger.info(ds_test)

  # Build VisionTransformer architecture
  model = models.KNOWN_MODELS[args.model]
  VisionTransformer = model.partial(num_classes=dataset_info['num_classes'])
  _, params = VisionTransformer.init_by_shape(
      jax.random.PRNGKey(0),
      # Discard the "num_local_devices" dimension for initialization.
      [(batch['image'].shape[1:], batch['image'].dtype.name)])

  pretrained_path = os.path.join(args.vit_pretrained_dir, f'{args.model}.npz')
  params = checkpoint.load_pretrained(
      pretrained_path=pretrained_path,
      init_params=params,
      model_config=models.CONFIGS[args.model],
      logger=logger)

  # pmap replicates the models over all TPUs/GPUs
  vit_fn_repl = jax.pmap(VisionTransformer.call)
  update_fn_repl = make_update_fn(VisionTransformer.call, args.accum_steps)

  # Create optimizer and replicate it over all TPUs/GPUs
  opt = momentum_clip.Optimizer(
      dtype=args.optim_dtype,
      grad_norm_clip=args.grad_norm_clip).create(params)
  opt_repl = flax_utils.replicate(opt)

  # Delete references to the objects that are not needed anymore
  del opt
  del params

  def copyfiles(paths):
    """Small helper to copy files to args.copy_to using tf.io.gfile."""
    if not args.copy_to:
      return
    for path in paths:
      to_path = os.path.join(args.copy_to, args.name, os.path.basename(path))
      tf.io.gfile.makedirs(os.path.dirname(to_path))
      tf.io.gfile.copy(path, to_path, overwrite=True)
      logger.info(f'Copied {path} to {to_path}.')

  total_steps = args.total_steps or (
      input_pipeline.DATASET_PRESETS[args.dataset]['total_steps'])

  # Prepare the learning-rate and pre-fetch it to device to avoid delays.
  lr_fn = hyper.create_learning_rate_schedule(total_steps, args.base_lr,
                                              args.decay_type,
                                              args.warmup_steps)
  lr_iter = hyper.lr_prefetch_iter(lr_fn, 0, total_steps)
  update_rngs = jax.random.split(
      jax.random.PRNGKey(0), jax.local_device_count())

  # Run training loop
  writer = metric_writers.create_default_writer(logdir, asynchronous=False)
  writer.write_hparams({k: v for k, v in vars(args).items() if v is not None})
  logger.info('Starting training loop; initial compile can take a while...')
  t0 = time.time()

  for step, batch, lr_repl in zip(
      range(1, total_steps + 1),
      input_pipeline.prefetch(ds_train, args.prefetch), lr_iter):

    opt_repl, loss_repl, update_rngs = update_fn_repl(opt_repl, lr_repl,
                                                      batch, update_rngs)

    if step == 1:
      logger.info(f'First step took {time.time() - t0:.1f} seconds.')
      t0 = time.time()

    if args.progress_every and step % args.progress_every == 0:
      writer.write_scalars(step, dict(train_loss=float(loss_repl[0])))
      done = step / total_steps
      logger.info(f'Step: {step}/{total_steps} {100*done:.1f}%, '
                  f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h')
      copyfiles(glob.glob(f'{logdir}/*'))

    # Run eval step
    if ((args.eval_every and step % args.eval_every == 0) or
        (step == total_steps)):
      accuracy_test = np.mean([
          c for batch in input_pipeline.prefetch(ds_test, args.prefetch)
          for c in (
              np.argmax(vit_fn_repl(opt_repl.target, batch['image']),
                        axis=2) == np.argmax(batch['label'], axis=2)).ravel()
      ])

      lr = float(lr_repl[0])
      logger.info(f'Step: {step} '
                  f'Learning rate: {lr:.7f}, '
                  f'Test accuracy: {accuracy_test:0.5f}')
      writer.write_scalars(step, dict(accuracy_test=accuracy_test, lr=lr))
      copyfiles(glob.glob(f'{logdir}/*'))

  if args.output:
    checkpoint.save(flax_utils.unreplicate(opt_repl.target), args.output)
    logger.info(f'Stored fine tuned checkpoint to {args.output}')
    copyfiles([args.output])
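

# ----------------------------------------------------------------------------
# Illustrative sketch, not part of the original script: the eval step above
# compares pmapped logits of shape (local_devices, per_device_batch,
# num_classes) against one-hot labels with the same leading dimensions, so
# axis=2 is the class axis for both argmax calls.
# ----------------------------------------------------------------------------
import numpy as np


def sharded_top1_accuracy(logits, one_hot_labels):
  """Top-1 accuracy for (devices, batch, classes) logits and labels."""
  correct = np.argmax(logits, axis=2) == np.argmax(one_hot_labels, axis=2)
  return float(correct.ravel().mean())


# Example with random predictions over 1000 classes (accuracy ~0.001):
# logits = np.random.randn(8, 16, 1000)
# labels = np.eye(1000)[np.random.randint(0, 1000, size=(8, 16))]
# sharded_top1_accuracy(logits, labels)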
def train(config, model_def, device_batch_size, eval_ds, num_steps, steps_per_epoch, steps_per_eval, train_ds, image_size, data_source, workdir): """Train model.""" make_lr_fn = schedulers.get_make_lr_fn(config) make_temp_fn = schedulers.get_make_temp_fn(config) make_step_size_fn = schedulers.get_make_step_size_fn(config) if jax.host_count() > 1: raise ValueError('CIFAR10 example should not be run on ' 'more than 1 host due to preconditioner updating.') initial_step = 0 # TODO(basv): load from checkpoint. writer = metric_writers.create_default_writer( workdir, just_logging=jax.host_id() > 0) # Write config to the summary files. This makes the hyperparameters available # in TensorBoard and makes comparison of runs in TensorBoard easier. # with writer.summary_writer.as_default(): writer.write_hparams(dict(config)) rng = random.PRNGKey(config.seed) rng, opt_rng, init_key, sampler_rng = jax.random.split(rng, 4) base_learning_rate = config.learning_rate # Create the model. model, state = create_model(rng, device_batch_size, image_size, model_def) parameter_overview.log_parameter_overview(model.params) state = jax_utils.replicate(state) train_size = data_source.TRAIN_IMAGES with flax.deprecated.nn.stochastic(init_key): optimizer = create_optimizer(config, model, base_learning_rate, train_size, sampler_rng) del model # Don't keep a copy of the initial model. # Learning rate schedule learning_rate_fn = make_lr_fn(base_learning_rate, steps_per_epoch) temperature_fn = make_temp_fn(config.base_temp, steps_per_epoch) step_size_fn = make_step_size_fn(steps_per_epoch) p_eval_step, _, p_train_step, p_update_grad_vars = make_step_functions( config, config.l2_reg, learning_rate_fn, train_size, temperature_fn, step_size_fn) # Create dataset batch iterators. train_iter = iter(train_ds) eval_iter = iter(eval_ds) # Gather metrics. train_metrics = [] epoch = 0 # Ensemble. ensemble = [] ensemble_logits = [] ensemble_labels = [] ensemble_probs = [] def ensemble_add_step(step): if config.lr_schedule == 'cosine': # Add if learning rate jumps up again in the next step. increase = step_size_fn(step) < step_size_fn(step + 1) - 1e-8 _, temp_end = ast.literal_eval(config.temp_ramp) past_burn_in = step >= steps_per_epoch * temp_end return increase and past_burn_in elif config.lr_schedule == 'constant': if (step + 1) % steps_per_epoch == 0: return True return False logging.info('Starting training loop at step %d.', initial_step) for step in range(initial_step, num_steps): if config.optimizer in ['sym_euler'] and (step) % steps_per_epoch == 0: optimizer, rng = update_preconditioner(config, optimizer, p_update_grad_vars, rng, state, train_iter) # Generate a PRNG key that will be rolled into the batch step_key = jax.random.fold_in(rng, step) opt_step_rng = jax.random.fold_in(opt_rng, step) # Load and shard the TF batch batch = next(train_iter) batch = input_pipeline.load_and_shard_tf_batch(config, batch) if not config.debug_run: # Shard the step PRNG key # Don't shard the optimizer rng, as it should be equal among all machines. sharded_keys = common_utils.shard_prng_key(step_key) else: sharded_keys = step_key # Train step optimizer, state, metrics = p_train_step(optimizer, state, batch, sharded_keys, opt_step_rng) train_metrics.append(metrics) # Quick indication that training is happening. 
logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step) if step == initial_step: initial_train_metrics = get_metrics(config, train_metrics) train_summary = jax.tree_map(lambda x: x.mean(), initial_train_metrics) train_summary = {'train_' + k: v for k, v in train_summary.items()} logging.log(logging.INFO, 'initial metrics = %s', str(train_summary.items())) if (step + 1) % steps_per_epoch == 0: # We've finished an epoch # Save model params/state. train_metrics = get_metrics(config, train_metrics) # Get training epoch summary for logging train_summary = jax.tree_map(lambda x: x.mean(), train_metrics) train_summary = {'train_' + k: v for k, v in train_summary.items()} writer.write_scalars(epoch, train_summary) # Reset train metrics train_metrics = [] # Evaluation if config.do_eval: eval_metrics = [] eval_logits = [] eval_labels = [] for _ in range(steps_per_eval): eval_batch = next(eval_iter) # Load and shard the TF batch eval_batch = input_pipeline.load_and_shard_tf_batch( config, eval_batch) # Step logits, labels, metrics = p_eval_step(optimizer.target, state, eval_batch) eval_metrics.append(metrics) eval_logits.append(logits) eval_labels.append(labels) eval_metrics = get_metrics(config, eval_metrics) # Get eval epoch summary for logging eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics) eval_summary = {'eval_' + k: v for k, v in eval_summary.items()} writer.write_scalars(epoch, eval_summary) if config.algorithm == 'sgmcmc' and ensemble_add_step(step): ensemble.append((serialization.to_state_dict(optimizer.target), state)) if config.algorithm == 'sgmcmc' and ensemble_add_step( step) and len(ensemble) >= 1: # Gather predictions for this ensemble sample. eval_logits = jnp.concatenate(eval_logits, axis=0) eval_probs = jax.nn.softmax(eval_logits, axis=-1) eval_labels = jnp.concatenate(eval_labels, axis=0) # Ensure that labels are consistent between predict runs. if ensemble_labels: assert jnp.allclose( eval_labels, ensemble_labels[0]), 'Labels unordered between eval runs.' ensemble_logits.append(eval_logits) ensemble_probs.append(eval_probs) ensemble_labels.append(eval_labels) # Compute ensemble predictions over last config.ensemble_size samples. ensemble_last_probs = jnp.mean( jnp.array(ensemble_probs[-config.ensemble_size:]), axis=0) ensemble_metrics = train_functions.compute_metrics_probs( ensemble_last_probs, ensemble_labels[0]) ensemble_summary = jax.tree_map(lambda x: x.mean(), ensemble_metrics) ensemble_summary = {'ens_' + k: v for k, v in ensemble_summary.items()} ensemble_summary['ensemble_size'] = min(config.ensemble_size, len(ensemble_probs)) writer.write_scalars(epoch, ensemble_summary) epoch += 1 return ensemble, optimizer
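

# ----------------------------------------------------------------------------
# Illustrative sketch (an assumption about the intent of the loop above): the
# SG-MCMC ensemble prediction averages per-example softmax probabilities over
# the most recent `ensemble_size` parameter snapshots before metrics are
# computed with `train_functions.compute_metrics_probs`.
# ----------------------------------------------------------------------------
import numpy as np


def ensemble_mean_probs(prob_history, ensemble_size):
  """Averages the last `ensemble_size` (num_examples, num_classes) arrays."""
  return np.mean(np.stack(prob_history[-ensemble_size:], axis=0), axis=0)


# Example: averaging a uniform prediction with a one-hot prediction still
# yields rows that sum to 1.
# ensemble_mean_probs([np.full((4, 10), 0.1), np.eye(10)[:4]], ensemble_size=2)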
def main(argv): del argv # unused arg config = FLAGS.config # Unpack total and warmup steps # TODO(nband): revert this to separate arguments. total_steps = config.total_and_warmup_steps[0] warmup_steps = config.total_and_warmup_steps[1] del config.total_and_warmup_steps config.total_steps = total_steps config.lr.warmup_steps = warmup_steps # Wandb and Checkpointing Setup output_dir = FLAGS.output_dir wandb_run, output_dir = vit_utils.maybe_setup_wandb(config) tf.io.gfile.makedirs(output_dir) logging.info('Saving checkpoints at %s', output_dir) # Dataset Split Flags dist_shift = config.distribution_shift print(f'Distribution Shift: {dist_shift}.') dataset_names, split_names = vit_utils.get_dataset_and_split_names(dist_shift) # LR / Optimization Flags batch_size = config.batch_size grad_clip_norm = config.grad_clip_norm weight_decay = config.weight_decay print('Standard wandb hyperparameters:') print({ 'batch_size': batch_size, 'grad_clip_norm': grad_clip_norm, 'weight_decay': weight_decay, 'total_steps': config.total_steps, 'lr': config.lr }) print('SNGP Params:', config.gp_layer) # Reweighting loss for class imbalance # class_reweight_mode = config.class_reweight_mode # if class_reweight_mode == 'constant': # class_weights = utils.get_diabetic_retinopathy_class_balance_weights() # else: # class_weights = None # Shows the number of available devices. # In a CPU/GPU runtime this will be a single device. # In a TPU runtime this will be 8 cores. print('Number of Jax local devices:', jax.local_devices()) # TODO(nband): fix sigmoid loss issues. assert config.get('loss', None) == 'softmax_xent' seed = config.seed rng = jax.random.PRNGKey(seed) tf.random.set_seed(seed) if config.get('data_dir'): logging.info('data_dir=%s', config.data_dir) logging.info('Output dir: %s', output_dir) save_checkpoint_path = None if config.get('checkpoint_steps'): tf.io.gfile.makedirs(output_dir) save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz') # Create an asynchronous multi-metric writer. writer = metric_writers.create_default_writer( output_dir, just_logging=jax.process_index() > 0) # The pool is used to perform misc operations such as logging in async way. pool = multiprocessing.pool.ThreadPool() def write_note(note): if jax.process_index() == 0: logging.info('NOTE: %s', note) write_note('Initializing...') # Verify settings to make sure no checkpoints are accidentally missed. if config.get('keep_checkpoint_steps'): assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.' assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, ( f'`keep_checkpoint_steps` ({config.checkpoint_steps}) should be' f'divisible by `checkpoint_steps ({config.checkpoint_steps}).`') batch_size_eval = config.get('batch_size_eval', batch_size) if (batch_size % jax.device_count() != 0 or batch_size_eval % jax.device_count() != 0): raise ValueError(f'Batch sizes ({batch_size} and {batch_size_eval}) must ' f'be divisible by device number ({jax.device_count()})') local_batch_size = batch_size // jax.process_count() local_batch_size_eval = batch_size_eval // jax.process_count() logging.info( 'Global batch size %d on %d hosts results in %d local batch size. 
' 'With %d dev per host (%d dev total), that is a %d per-device batch size.', batch_size, jax.process_count(), local_batch_size, jax.local_device_count(), jax.device_count(), local_batch_size // jax.local_device_count()) write_note('Initializing preprocessing function...') # Same preprocessing function for training and evaluation preproc_fn = preprocess_spec.parse( spec=config.pp_train, available_ops=preprocess_utils.all_ops()) write_note('Initializing train dataset...') rng, train_ds_rng = jax.random.split(rng) train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index()) train_base_dataset = ub.datasets.get( dataset_names['in_domain_dataset'], split=split_names['train_split'], data_dir=config.get('data_dir')) train_dataset_builder = train_base_dataset._dataset_builder # pylint: disable=protected-access train_ds = input_utils.get_data( dataset=train_dataset_builder, split=split_names['train_split'], rng=train_ds_rng, process_batch_size=local_batch_size, preprocess_fn=preproc_fn, shuffle_buffer_size=config.shuffle_buffer_size, prefetch_size=config.get('prefetch_to_host', 2), data_dir=config.get('data_dir')) logging.info('image_size = %s', train_ds.element_spec['image'].shape[1:]) # Start prefetching already. train_iter = input_utils.start_input_pipeline( train_ds, config.get('prefetch_to_device', 1)) write_note('Initializing val dataset(s)...') # Load in-domain and OOD validation and/or test datasets. # Please specify the desired shift (Country Shift or Severity Shift) # in the config. eval_iter_splits = vit_utils.init_evaluation_datasets( use_validation=config.use_validation, use_test=config.use_test, dataset_names=dataset_names, split_names=split_names, config=config, preproc_fn=preproc_fn, batch_size_eval=batch_size_eval, local_batch_size_eval=local_batch_size_eval) ntrain_img = input_utils.get_num_examples( train_dataset_builder, split=split_names['train_split'], process_batch_size=local_batch_size, data_dir=config.get('data_dir')) steps_per_epoch = ntrain_img / batch_size if config.get('num_epochs'): total_steps = int(config.num_epochs * steps_per_epoch) assert not config.get('total_steps'), 'Set either num_epochs or total_steps' else: total_steps = config.total_steps logging.info('Total train data points: %d', ntrain_img) logging.info( 'Running for %d steps, that means %f epochs and %d steps per epoch', total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch) write_note('Initializing model...') logging.info('config.model = %s', config.get('model')) # Specify Gaussian process layer configs. gp_config = config.get('gp_layer', {}) model_dict = vit_utils.initialize_model('sngp', config) model, use_gp_layer = model_dict['model'], model_dict['use_gp_layer'] # We want all parameters to be created in host RAM, not on any device, they'll # be sent there later as needed, otherwise we already encountered two # situations where we allocate them twice. @functools.partial(jax.jit, backend='cpu') def init(rng): image_size = tuple(train_ds.element_spec['image'].shape[2:]) logging.info('image_size = %s', image_size) dummy_input = jnp.zeros((local_batch_size,) + image_size, jnp.float32) variables = model.init(rng, dummy_input, train=False) # Split model parameters into trainable and untrainable collections. states, params = variables.pop('params') del variables # Set bias in the head to a low value, such that loss is small initially. params = flax.core.unfreeze(params) if use_gp_layer: # Modify the head parameter in the GP head. 
params['head']['output_layer']['bias'] = jnp.full_like( params['head']['output_layer']['bias'], config.get('init_head_bias', 0)) else: params['head']['bias'] = jnp.full_like( params['head']['bias'], config.get('init_head_bias', 0)) return params, states rng, rng_init = jax.random.split(rng) params_cpu, states_cpu = init(rng_init) if jax.process_index() == 0: num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0]) parameter_overview.log_parameter_overview(params_cpu) writer.write_scalars(step=0, scalars={'num_params': num_params}) @functools.partial(jax.pmap, axis_name='batch') def evaluation_fn(params, states, images, labels): variable_dict = {'params': flax.core.freeze(params), **states} logits, out = model.apply( variable_dict, images, train=False, mean_field_factor=gp_config.get('mean_field_factor', -1.)) losses = getattr(train_utils, config.get('loss', 'softmax_xent'))( logits=logits, labels=labels, reduction=False) loss = jax.lax.psum(losses, axis_name='batch') top1_idx = jnp.argmax(logits, axis=1) # Extracts the label at the highest logit index for each image. top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0] ncorrect = jax.lax.psum(top1_correct, axis_name='batch') n = batch_size_eval metric_args = jax.lax.all_gather([ logits, labels, out['pre_logits']], axis_name='batch') return ncorrect, loss, n, metric_args # Load the optimizer from flax. opt_name = config.get('optim_name') write_note(f'Initializing {opt_name} optimizer...') opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {})) # We jit this, such that the arrays that are created are created on the same # device as the input is, in this case the CPU. Else they'd be on device[0]. opt_cpu = jax.jit(opt_def.create)(params_cpu) weight_decay_rules = config.get('weight_decay', []) or [] rescale_value = config.lr.base if config.get('weight_decay_decouple') else 1. weight_decay_fn = train_utils.get_weight_decay_fn( weight_decay_rules=weight_decay_rules, rescale_value=rescale_value) @functools.partial(jax.pmap, axis_name='batch', donate_argnums=(0,)) def update_fn(opt, states, lr, reset_covmat, images, labels, rng): """Update step.""" measurements = {} # Get device-specific loss rng. rng, rng_model = jax.random.split(rng, 2) rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index('batch')) def loss_fn(params, states, images, labels): # Specify mutable collection to update untrainable GP parameters. variable_dict = {'params': flax.core.freeze(params), **states} model_results, updated_states = model.apply( variable_dict, images, train=True, rngs={'dropout': rng_model_local}, mutable=list(states.keys()), mean_field_factor=gp_config.get('mean_field_factor', -1.)) logits, _ = model_results loss = getattr(train_utils, config.get('loss', 'sigmoid_xent'))( logits=logits, labels=labels) return loss, updated_states # Performs exact covariance update (i.e., reset precision matrix resetting # at begining of new epoch) if covmat_momentum is a null value. if use_gp_layer and gp_config.get('covmat_momentum', -1.) < 0: # Resets precision matrix to Identity * ridge_penalty if at the begining # of a new epoch. This should be done before accumulate gradient. ridge_penalty = gp_config.get('ridge_penalty', 1.) prec_mat_old = states['laplace_covariance']['head']['covmat_layer'][ 'precision_matrix'] prec_mat_new = ( (1. 
- reset_covmat) * prec_mat_old + reset_covmat * jnp.eye(prec_mat_old.shape[0]) * ridge_penalty) states = flax.core.unfreeze(states) states['laplace_covariance']['head']['covmat_layer'][ 'precision_matrix'] = prec_mat_new states = flax.core.freeze(states) # Implementation considerations compared and summarized at # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en# (l, s), g = vit_utils.accumulate_gradient_with_states( jax.value_and_grad(loss_fn, has_aux=True), opt.target, states, images, labels, config.get('grad_accum_steps')) l, g = jax.lax.pmean((l, g), axis_name='batch') # Log the gradient norm only if we need to compute it anyways (clipping) # or if we don't use grad_accum_steps, as they interact badly. if config.get('grad_accum_steps', 1) == 1 or grad_clip_norm is not None: grads, _ = jax.tree_flatten(g) l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads])) measurements['l2_grads'] = l2_g # Optionally resize the global gradient to a maximum norm. We found this # useful in some cases across optimizers, hence it's in the main loop. if grad_clip_norm is not None: g_factor = jnp.minimum(1.0, grad_clip_norm / l2_g) g = jax.tree_map(lambda p: g_factor * p, g) opt = opt.apply_gradient(g, learning_rate=lr) opt = opt.replace(target=weight_decay_fn(opt.target, lr)) params, _ = jax.tree_flatten(opt.target) measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params])) measurements['reset_covmat'] = reset_covmat return opt, s, l, rng, measurements # Set config checkpoint resume path, if provided in args. if config.resume_checkpoint_path is not None: config.resume = config.resume_checkpoint_path default_reinit_params = ('head/output_layer/kernel', 'head/output_layer/bias', 'head/kernel', 'head/bias') rng, train_loop_rngs = jax.random.split(rng) checkpoint_data = checkpoint_utils.maybe_load_checkpoint( train_loop_rngs=train_loop_rngs, save_checkpoint_path=save_checkpoint_path, init_optimizer=opt_cpu, init_params=params_cpu, init_fixed_model_states=states_cpu, default_reinit_params=default_reinit_params, config=config) train_loop_rngs = checkpoint_data.train_loop_rngs opt_cpu = checkpoint_data.optimizer states_cpu = checkpoint_data.fixed_model_states accumulated_train_time = checkpoint_data.accumulated_train_time write_note('Adapting the checkpoint model...') adapted_params = checkpoint_utils.adapt_upstream_architecture( init_params=params_cpu, loaded_params=opt_cpu.target) opt_cpu = opt_cpu.replace(target=adapted_params) write_note('Kicking off misc stuff...') first_step = int(opt_cpu.state.step) # Might be a DeviceArray type. if first_step == 0 and jax.process_index() == 0: writer.write_hparams(dict(config)) chrono = train_utils.Chrono(first_step, total_steps, batch_size, accumulated_train_time) # Note: switch to ProfileAllHosts() if you need to profile all hosts. # (Xprof data become much larger and take longer to load for analysis) profiler = periodic_actions.Profile( # Create profile after every restart to analyze pre-emption related # problems and assure we get similar performance in every run. logdir=output_dir, first_profile=first_step + 10) # Prepare the learning-rate and pre-fetch it to device to avoid delays. lr_fn = train_utils.create_learning_rate_schedule(total_steps, **config.get('lr', {})) # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be # necessary for TPUs. 
lr_iter = train_utils.prefetch_scalar( map(lr_fn, range(total_steps)), config.get('prefetch_to_device', 1)) # Prepare the precision matrix resetting schedule, and pre-fetch it to device. reset_covmat_fn = lambda step: float(step % steps_per_epoch == 0) reset_covmat_iter = train_utils.prefetch_scalar( map(reset_covmat_fn, range(first_step, total_steps)), nprefetch=config.get('prefetch_to_device', 1)) write_note(f'Replicating...\n{chrono.note}') opt_repl = flax.jax_utils.replicate(opt_cpu) states_repl = flax.jax_utils.replicate(states_cpu) checkpoint_writer = None # Note: we return the train loss, val loss, and fewshot best l2s for use in # reproducibility unit tests. # train_loss = -jnp.inf # val_loss = -jnp.inf # results = {'dummy': {(0, 1): -jnp.inf}} write_note(f'First step compilations...\n{chrono.note}') logging.info('first_step = %s', first_step) # Advance the iterators if we are restarting from an earlier checkpoint. # TODO(dusenberrymw): Look into checkpointing dataset state instead. # Makes sure log_eval_steps is same as steps_per_epoch. This is because # the precision matrix needs to be updated fully (at the end of each epoch) # when eval takes place. log_eval_steps = steps_per_epoch if first_step > 0: write_note('Advancing iterators after resuming from a checkpoint...') lr_iter = itertools.islice(lr_iter, first_step, None) train_iter = itertools.islice(train_iter, first_step, None) # Using a python integer for step here, because opt.state.step is allocated # on TPU during replication. for step, train_batch, lr_repl, reset_covmat_repl in zip( range(first_step + 1, total_steps + 1), train_iter, lr_iter, reset_covmat_iter): with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1): # TODO(jereliu): Expand to allow precision matrix resetting. (opt_repl, states_repl, loss_value, train_loop_rngs, extra_measurements) = update_fn( opt_repl, states_repl, lr_repl, reset_covmat_repl, train_batch['image'], train_batch['labels'], rng=train_loop_rngs) if jax.process_index() == 0: profiler(step) # Checkpoint saving if train_utils.itstime( step, config.get('checkpoint_steps'), total_steps, process=0): write_note('Checkpointing...') chrono.pause() train_utils.checkpointing_timeout(checkpoint_writer, config.get('checkpoint_timeout', 1)) accumulated_train_time = chrono.accum_train_time # We need to transfer the weights over now or else we risk keeping them # alive while they'll be updated in a future step, creating hard to debug # memory errors (see b/160593526). Also, takes device 0's params only. # For GP layer, we will also do the same for untrainable parameters # (`states`). This is ok since `random features` are frozen throughout # pre-training, and `precision matrix` is a finetuning-specific parameters # that will be re-learned in the finetuning task. opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl) states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl) # Check whether we want to keep a copy of the current checkpoint. copy_step = None if train_utils.itstime(step, config.get('keep_checkpoint_steps'), total_steps): write_note('Keeping a checkpoint copy...') copy_step = step # Checkpoint should be a nested dictionary or FLAX datataclasses from # `flax.struct`. Both can be present in a checkpoint. 
checkpoint_data = checkpoint_utils.CheckpointData( optimizer=opt_cpu, fixed_model_states=states_cpu, train_loop_rngs=train_loop_rngs, accumulated_train_time=accumulated_train_time) checkpoint_writer = pool.apply_async( checkpoint_utils.checkpoint_trained_model, (checkpoint_data, save_checkpoint_path, copy_step)) chrono.resume() # Report training progress if train_utils.itstime( step, config.log_training_steps, total_steps, process=0): write_note('Reporting training progress...') train_loss = loss_value[0] # Keep to return for reproducibility tests. timing_measurements, note = chrono.tick(step) write_note(note) train_measurements = {} train_measurements.update({ 'learning_rate': lr_repl[0], 'training_loss': train_loss, }) train_measurements.update(flax.jax_utils.unreplicate(extra_measurements)) train_measurements.update(timing_measurements) writer.write_scalars(step, train_measurements) # Report validation performance if train_utils.itstime(step, log_eval_steps, total_steps): write_note('Evaluating on the validation set...') chrono.pause() all_eval_results = {} for eval_name, (eval_iter, eval_steps) in eval_iter_splits.items(): start_time = time.time() # Runs evaluation loop. results_arrs = { 'y_true': [], 'y_pred': [], 'y_pred_entropy': [] } for _, batch in zip(range(eval_steps), eval_iter): batch_ncorrect, batch_losses, batch_n, batch_metric_args = ( # pylint: disable=unused-variable evaluation_fn( opt_repl.target, states_repl, batch['image'], batch['labels'])) # All results are a replicated array shaped as follows: # (local_devices, per_device_batch_size, elem_shape...) # with each local device's entry being identical as they got psum'd. # So let's just take the first one to the host as numpy. # Here we parse batch_metric_args to compute uncertainty metrics. logits, labels, _ = batch_metric_args logits = np.array(logits[0]) probs = jax.nn.softmax(logits) # From one-hot to integer labels. int_labels = np.argmax(np.array(labels[0]), axis=-1) probs = np.reshape(probs, (probs.shape[0] * probs.shape[1], -1)) int_labels = int_labels.flatten() y_pred = probs[:, 1] results_arrs['y_true'].append(int_labels) results_arrs['y_pred'].append(y_pred) # Entropy is computed at the per-epoch level (see below). results_arrs['y_pred_entropy'].append(probs) results_arrs['y_true'] = np.concatenate(results_arrs['y_true'], axis=0) results_arrs['y_pred'] = np.concatenate( results_arrs['y_pred'], axis=0).astype('float64') results_arrs['y_pred_entropy'] = vit_utils.entropy( np.concatenate(results_arrs['y_pred_entropy'], axis=0), axis=-1) time_elapsed = time.time() - start_time results_arrs['total_ms_elapsed'] = time_elapsed * 1e3 results_arrs['dataset_size'] = eval_steps * batch_size_eval all_eval_results[eval_name] = results_arrs per_pred_results, metrics_results = vit_utils.evaluate_vit_predictions( # pylint: disable=unused-variable dataset_split_to_containers=all_eval_results, is_deterministic=True, num_bins=15, return_per_pred_results=True ) # `metrics_results` is a dict of {str: jnp.ndarray} dicts, one for each # dataset. Flatten this dict so we can pass to the writer and remove empty # entries. 
flattened_metric_results = {} for dic in metrics_results.values(): for key, value in dic.items(): if value is not None: flattened_metric_results[key] = value writer.write_scalars(step, flattened_metric_results) # Optionally log to wandb if config.use_wandb: wandb.log(metrics_results, step=step) # Save per-prediction metrics results_storage_utils.save_per_prediction_results( output_dir, step, per_pred_results, verbose=False) chrono.resume() # End of step. if config.get('testing_failure_step'): # Break early to simulate infra failures in test cases. if config.testing_failure_step == step: break write_note(f'Done!\n{chrono.note}') pool.close() pool.join() writer.close() if wandb_run is not None: wandb_run.finish()
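# Minimal sketch (not part of the original code; names and values here are
# illustrative) of the per-epoch precision-matrix reset performed in update_fn
# above: when `reset_covmat` is 1.0 (the first step of an epoch, per the
# schedule float(step % steps_per_epoch == 0)) the matrix is reset to
# ridge_penalty * I; otherwise it is left unchanged.
import jax.numpy as jnp

def reset_precision_matrix(prec_mat_old, reset_covmat, ridge_penalty=1.0):
  """Blends the old precision matrix with a fresh ridge-regularized identity."""
  eye = jnp.eye(prec_mat_old.shape[0])
  return (1.0 - reset_covmat) * prec_mat_old + reset_covmat * ridge_penalty * eye

prec = jnp.ones((4, 4))
assert jnp.allclose(reset_precision_matrix(prec, reset_covmat=1.0), jnp.eye(4))
assert jnp.allclose(reset_precision_matrix(prec, reset_covmat=0.0), prec)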
def train(*, workdir, compute_phi, compute_psi, params, optimal_subspace, num_epochs, learning_rate, key, method, lissa_kappa, optimizer, covariance_batch_size, main_batch_size, weight_batch_size, d, num_tasks, compute_feature_norm_on_oracle_states, sample_states, eval_states, use_tabular_gradient=True): """Training function. For lissa, the total number of samples is 2 x covariance_batch_size + main_batch_size + 2 x weight_batch_size. Args: workdir: Work directory, where we'll save logs. compute_phi: A function that takes params and states and returns a matrix of phis. compute_psi: A function that takes an array of states and an array of tasks and returns Psi[states, tasks]. params: Parameters used as the first argument for compute_phi. optimal_subspace: Top-d left singular vectors of Psi. num_epochs: How many gradient steps to perform. (Not really epochs) learning_rate: The step size parameter for sgd. key: The jax prng key. method: 'naive', 'lissa', or 'oracle'. lissa_kappa: The parameter of the lissa method, if used. optimizer: Which optimizer to use. Only 'sgd' is supported. covariance_batch_size: the 'J' parameter. For the naive method, this is how many states we sample to construct the inverse. For the lissa method, ditto -- these are also "iterations". main_batch_size: How many states to update at once. weight_batch_size: How many states to construct the weight vector. d: The dimension of the representation. num_tasks: The total number of tasks. compute_feature_norm_on_oracle_states: If True, computes the feature norm using the oracle states (all the states in synthetic experiments). Otherwise, computes the norm using the sampled batch. Only applies to LISSA. sample_states: A function that takes an rng key and a number of states to sample, and returns a tuple containing (a vector of sampled states, an updated rng key). eval_states: An array of states to use to compute metrics on. This will be used to compute Phi = compute_phi(params, eval_states). use_tabular_gradient: If true, the train step will calculate the gradient using the tabular calculation. Otherwise, it will use a jax.vjp to backpropagate the gradient. """ # Create an explicit weight vector (needed for explicit method only). if method == 'explicit': key, weight_key = jax.random.split(key) explicit_weight_matrix = jax.random.normal(weight_key, (d, num_tasks), dtype=jnp.float32) params['explicit_weight_matrix'] = explicit_weight_matrix if optimizer == 'sgd': optimizer = optax.sgd(learning_rate) elif optimizer == 'adam': optimizer = optax.adam(learning_rate) else: raise ValueError(f'Unknown optimizer {optimizer}.') optimizer_state = optimizer.init(params) chkpt_manager = checkpoint.Checkpoint(base_directory=_WORKDIR.value) initial_step, params, optimizer_state = chkpt_manager.restore_or_initialize( (0, params, optimizer_state)) writer = metric_writers.create_default_writer(logdir=str(workdir), ) # Checkpointing and logging too much can use a lot of disk space. # Therefore, we don't want to checkpoint more than 10 times an experiment, # or keep more than 1k Phis per experiment. checkpoint_period = max(num_epochs // 10, 100_000) log_period = max(1_000, num_epochs // 1_000) def _checkpoint_callback(step, t, params, optimizer_state): del t # Unused. 
chkpt_manager.save((step, params, optimizer_state)) hooks = [ periodic_actions.PeriodicCallback(every_steps=checkpoint_period, callback_fn=_checkpoint_callback) ] fixed_train_kwargs = { 'compute_phi': compute_phi, 'compute_psi': compute_psi, 'optimizer': optimizer, 'method': method, # In the tabular case, the eval_states are all the states. 'oracle_states': eval_states, 'lissa_kappa': lissa_kappa, 'main_batch_size': main_batch_size, 'covariance_batch_size': covariance_batch_size, 'weight_batch_size': weight_batch_size, 'd': d, 'num_tasks': num_tasks, 'compute_feature_norm_on_oracle_states': (compute_feature_norm_on_oracle_states), 'sample_states': sample_states, 'use_tabular_gradient': use_tabular_gradient, } variable_kwargs = { 'params': params, 'optimizer_state': optimizer_state, 'key': key, } @jax.jit def _eval_step(phi_params): eval_phi = compute_phi(phi_params, eval_states) eval_psi = compute_psi(eval_states) # pytype: disable=wrong-arg-count metrics = compute_metrics(eval_phi, optimal_subspace) metrics |= {'frob_norm': utils.outer_objective_mc(eval_phi, eval_psi)} return metrics # Perform num_epochs gradient steps. with metric_writers.ensure_flushes(writer): for step in etqdm.tqdm(range(initial_step + 1, num_epochs + 1), initial=initial_step, total=num_epochs): variable_kwargs = _train_step(**fixed_train_kwargs, **variable_kwargs) if step % log_period == 0: metrics = _eval_step(variable_kwargs['params']['phi_params']) writer.write_scalars(step, metrics) for hook in hooks: hook(step, params=variable_kwargs['params'], optimizer_state=variable_kwargs['optimizer_state']) writer.flush()
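# Minimal sketch (toy quadratic loss and parameter pytree, not the real
# _train_step) of the optax update cycle that the training function above
# builds on: compute gradients, transform them with the chosen optimizer
# (optax.sgd or optax.adam), then apply the updates to the parameters.
import jax
import jax.numpy as jnp
import optax

def toy_loss(params):
  return jnp.sum(params['w'] ** 2)

toy_params = {'w': jnp.ones(3)}
toy_optimizer = optax.sgd(learning_rate=0.1)
toy_opt_state = toy_optimizer.init(toy_params)

@jax.jit
def toy_sgd_step(params, opt_state):
  grads = jax.grad(toy_loss)(params)
  updates, opt_state = toy_optimizer.update(grads, opt_state, params)
  return optax.apply_updates(params, updates), opt_state

toy_params, toy_opt_state = toy_sgd_step(toy_params, toy_opt_state)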
def main(unused_argv): logging.info("Loading %s", FLAGS.game_name) game = pyspiel.load_game(FLAGS.game_name, GAME_SETTINGS.get(FLAGS.game_name, {})) uniform_policy = policy.UniformRandomPolicy(game) mfg_dist = distribution.DistributionPolicy(game, uniform_policy) envs = [ rl_environment.Environment(game, mfg_distribution=mfg_dist, mfg_population=p) for p in range(game.num_players()) ] info_state_size = envs[0].observation_spec()["info_state"][0] num_actions = envs[0].action_spec()["num_actions"] hidden_layers_sizes = [int(l) for l in FLAGS.hidden_layers_sizes] kwargs = { "replay_buffer_capacity": FLAGS.replay_buffer_capacity, "min_buffer_size_to_learn": FLAGS.min_buffer_size_to_learn, "batch_size": FLAGS.batch_size, "learn_every": FLAGS.learn_every, "learning_rate": FLAGS.rl_learning_rate, "optimizer_str": FLAGS.optimizer_str, "loss_str": FLAGS.loss_str, "update_target_network_every": FLAGS.update_target_network_every, "discount_factor": FLAGS.discount_factor, "epsilon_decay_duration": FLAGS.epsilon_decay_duration, "epsilon_start": FLAGS.epsilon_start, "epsilon_end": FLAGS.epsilon_end, } # pylint: disable=g-complex-comprehension agents = [ dqn.DQN(idx, info_state_size, num_actions, hidden_layers_sizes, **kwargs) for idx in range(game.num_players()) ] joint_avg_policy = rl_agent_policy.JointRLAgentPolicy( game, {idx: agent for idx, agent in enumerate(agents)}, envs[0].use_observation) if FLAGS.use_checkpoints: for agent in agents: if agent.has_checkpoint(FLAGS.checkpoint_dir): agent.restore(FLAGS.checkpoint_dir) # Metrics writer will also log the metrics to stderr. just_logging = FLAGS.logdir is None or jax.host_id() > 0 writer = metric_writers.create_default_writer(FLAGS.logdir, just_logging=just_logging) # Save the parameters. writer.write_hparams(kwargs) for ep in range(1, FLAGS.num_train_episodes + 1): if ep % FLAGS.eval_every == 0: writer.write_scalars( ep, { f"agent{i}/loss": float(agent.loss) for i, agent in enumerate(agents) }) initial_states = game.new_initial_states() # Exact best response to uniform. nash_conv_obj = nash_conv.NashConv(game, uniform_policy) writer.write_scalars( ep, { f"exact_br/{state}": value for state, value in zip(initial_states, nash_conv_obj.br_values()) }) # DQN best response to uniform. pi_value = policy_value.PolicyValue(game, mfg_dist, joint_avg_policy) writer.write_scalars( ep, { f"dqn_br/{state}": pi_value.eval_state(state) for state in initial_states }) if FLAGS.use_checkpoints: for agent in agents: agent.save(FLAGS.checkpoint_dir) for p in range(game.num_players()): time_step = envs[p].reset() while not time_step.last(): agent_output = agents[p].step(time_step) action_list = [agent_output.action] time_step = envs[p].step(action_list) # Episode is over, step all agents with final info state. agents[p].step(time_step) # Make sure all values were written. writer.flush()
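# Minimal sketch (hypothetical helper, not in the original file) of the
# environment/agent interaction at the end of main() above: step the agent
# until the episode terminates, then give it the final time step so it can
# learn from the terminal transition.
def run_episode(env, agent):
  time_step = env.reset()
  while not time_step.last():
    agent_output = agent.step(time_step)
    time_step = env.step([agent_output.action])
  # Episode is over; step the agent once more with the final info state.
  agent.step(time_step)
  return time_step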
def create_default_writer(): return metric_writers.create_default_writer() # pylint: disable=unreachable
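# Minimal usage sketch for the CLU writer returned above (the scalar values are
# placeholders): scalars are logged against an integer step, and flush() pushes
# any buffered values to the underlying logging backends.
def _example_write_metrics():
  writer = create_default_writer()
  writer.write_hparams({'learning_rate': 1e-3})
  writer.write_scalars(0, {'train/loss': 1.23})
  writer.flush()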
def run_train(self, experiment_dir, work_unit_dir, rng, yield_results=False): """Train a Dream Field and save results to work_unit_dir.""" t_start = time.time() config = self.config logging.info('Local devices: %s', jax.local_devices()) logging.info('All devices: %s', jax.devices()) ## Load CLIP encode_image, encode_text, preprocess_image, tokenize_fn = ( helpers.load_image_text_model(config.loss_model)) ## Pick a prompt template = config.get('query_template', '{query}') query = template.format(query=config.query) z_clip = encode_text(tokenize_fn(query)) ## Encode retrieval set if config.queries_r: if config.retrieve_models[0] == config.loss_model: # Reuse loss model. encode_image_r, preprocess_image_r = encode_image, preprocess_image encode_text_r, tokenize_fn_r = encode_text, tokenize_fn else: # Load new model. encode_image_r, encode_text_r, preprocess_image_r, tokenize_fn_r = ( helpers.load_image_text_model(config.retrieve_models[0])) if config.query not in config.queries_r: config.queries_r.append(config.query) z_clip_r = encode_text_r(tokenize_fn_r(config.queries_r)) true_idx_r = config.queries_r.index(config.query) assert true_idx_r >= 0 # Input query must be set of retrieval queries. del encode_text_r, tokenize_fn_r # Clean up retrieval text encoder. del encode_text, tokenize_fn # Clean up text encoder. ## Scene origin manually tracked scene_origin = scene.EMA(np.zeros(3, dtype=np.float64), decay=0.999) def train_step(state, rays, key, *multistep_constants): """Perform a training iteration, optionally composed of multiple substeps. Using multiple substeps slightly reduces training time, but only one substep per training iteration is used in experiments. Args: state: Optimizer state. rays: Camera rays for rendering, shared across all substeps. key: PRNGKey for random number generation (e.g. for augmentations). *multistep_constants: Training constants that can vary across substeps. 7 arrays of constants of length config.substeps are expected: (1) lrs: learning rates (2) scs: scale factor for integrated positional encoding. Larger scales lead to a blurrier appearance. A constant sc=1 is the standard mip-NeRF IPE, and used by Dream Fields. (3) sns: standard deviation of pre-activation noise for NeRF density. Dream Fields use sn=0. density(x) = softplus(s(x) + eps), eps ~ N(0, sn^2) (4) mrs: norm of radiance mask, defining scene bounds. (5) betas: scale of beta prior loss. Dream Fields use beta=0. (6) acct: transmittance loss hyperparameter, defining the target average opacity. This is 1 - tau (target transmittance). (7) acclam: weight of transmittance loss. Returns: state: Updated optimizer state. last_augs: Augmented views of renderings from the last substep. mean_losses: Dictionary of losses averaged over replicas and substeps. scene_origin: Updated origin of the scene, based on the center of mass. 
""" # NOTE(jainajay): rays are shared across all substeps pmean = functools.partial(jax.lax.pmean, axis_name='batch') psum = functools.partial(jax.lax.psum, axis_name='batch') def loss_fn(params, key, sc, sn, mr, beta, acct, acclam): render_key, aug_key, key = random.split(key, 3) # Render from nerf (rgb_est_flat, _, acc_est_flat), aux = render_rays(rays=rays, variables=params, rng=render_key, config=config, sc=sc, sigma_noise_std=sn, mask_rad=mr, origin=scene_origin.value, train=True) rgb_est = scene.gather_and_reshape(rgb_est_flat, config.render_width, 3) acc_est = scene.gather_and_reshape(acc_est_flat, config.render_width, 1) # Make augmentations process specific aug_key = random.fold_in(aug_key, pid) # Perform augmentations and resize to clip_width augs = augment.augment_rendering(config, rgb_est, acc_est, aug_key) # Run through CLIP z_est = encode_image(preprocess_image(augs)) clip_loss = -(z_est * z_clip).sum(-1).mean() total_loss = clip_loss transparency_loss = config.get('transparency_loss', None) acc_mean = np.mean(acc_est) aux['losses']['acc_mean'] = acc_mean if transparency_loss == 'neg_lam_transmittance_clipped': # Compute the Dream Fields transmittance loss for scene sparsity. trans_mean = 1 - acc_mean trans_mean_clipped = np.minimum(1 - acct, trans_mean) reg = acclam * trans_mean_clipped total_loss -= reg aux['losses']['trans_mean_clipped'] = trans_mean_clipped aux['losses']['acc_reg_additive'] = reg else: assert transparency_loss is None # Compute a sparsity loss by placing a bimodal beta prior on the # per-pixel transmittance. This prior was proposed by Lombardi et al # in "Neural Volumes: Learning Dynamic Renderable Volumes from Images" # and is used only in ablations. beta_loss = np.mean( np.log(np.maximum(1e-6, acc_est_flat)) + np.log(np.maximum(1e-6, 1. - acc_est_flat))) total_loss += beta_loss * beta # Compute a weighted mean of each replica's estimated scene origin, # since replicas get a different subset of rays total_sigma = psum(aux['scene_origin_sigma']) aux['scene_origin'] = psum(aux['scene_origin'] * aux['scene_origin_sigma'] / total_sigma) # Compute loss that pushes scene content to 0 origin. We set the loss # weight zero_origin_lam = 0 in experiments so the loss is just for # logging how far the origin has drifted. origin_loss = np.sum(np.square(aux['scene_origin'])) if config.get('zero_origin_lam', 0.): total_loss += config.zero_origin_lam * origin_loss aux['losses'].update({ 'clip_loss': clip_loss, 'beta_loss': beta_loss, 'origin_loss': origin_loss, 'loss': total_loss, }) aux['augs'] = augs return total_loss, aux grad_fn = jax.value_and_grad(loss_fn, has_aux=True) # Scan over substeps def body_fn(state, step_constants): lr, step_constants = step_constants[0], step_constants[1:] grad_fn_key, _ = random.split(key, 2) (_, aux), grad = grad_fn(state.target, grad_fn_key, *step_constants) grad = pmean(grad) # all-reduce grad aux['losses'] = pmean(aux['losses']) aux['losses']['grad_norm'] = helpers.tree_norm(grad) state = state.apply_gradient(grad, learning_rate=lr) return state, aux assert len(multistep_constants) == 7 multistep_constants = np.array(multistep_constants).T if config.substeps == 1: state, aux = body_fn(state, np.squeeze(multistep_constants)) last_augs = aux['augs'] else: state, aux = jax.lax.scan(body_fn, state, multistep_constants) # Augmentations from last substep. 
# Shape: [n_local_aug, clip_width, clip_width, 3] last_augs = aux['augs'][-1] # Average each type of loss over substeps mean_losses = jax.tree_map(np.mean, aux['losses']) return state, last_augs, mean_losses, aux['scene_origin'] train_pstep = jax.pmap(train_step, axis_name='batch', in_axes=(0, 0, 0, None, None, None, None, None, None, None)) onp.random.seed(config.seed) n_device = jax.local_device_count() pid = jax.process_index() logging.info('n_device %d', n_device) ## Modified NeRF architecture, with swish, softplus, skips. variables, render_rays = helpers.init_nerf_model( rng.advance(1), config) state = flax.optim.Adam(config.lr0, eps=config.adam_eps).create(variables) ## Try to restore a checkpoint. restore_dir = config.get('restore_dir', experiment_dir) restore_dir = os.path.join(restore_dir, os.path.basename(work_unit_dir)) if checkpoints.latest_checkpoint(restore_dir): restored = checkpoints.restore_checkpoint(restore_dir, target={ 'origin': np.zeros(3), 'state': state, 'vars': variables }) scene_origin.value = onp.array(restored['origin']) state = restored['state'] variables = restored['vars'] logging.info('restored checkpoint from step %d', state.state.step) else: logging.info('did not find checkpoint in %s', restore_dir) ## Replicate state. step_init = state.state.step helpers.defragment() state = flax.jax_utils.replicate(state, jax.devices()) helpers.defragment() ## pmap'd rendering for test time evaluation. kwargs_test = dict(rng=None, sigma_noise_std=0.) config_test = ml_collections.ConfigDict(config) config_test.update(config.test) config_test_hq = ml_collections.ConfigDict(config_test) config_test_hq.update(config.test_hq) @functools.partial(jax.pmap, in_axes=(0, None, None, None)) def render_test_p(rays, variables, sc=1., mr=1.): return render_rays(rays=rays, variables=variables, sc=sc, mask_rad=mr, origin=scene_origin.value, config=config_test, **kwargs_test)[0] @functools.partial(jax.pmap, in_axes=(0, None, None, None)) def render_test_hq_p(rays, variables, sc=1., mr=1.): return render_rays(rays=rays, variables=variables, config=config_test_hq, sc=sc, mask_rad=mr, origin=scene_origin.value, **kwargs_test)[0] def render_test(rays, variables, sc=1., mr=1., hq=False): sh = rays[0].shape rays = [ x.reshape((jax.device_count(), -1) + x.shape[1:]) for x in rays ] if hq: out = render_test_hq_p(rays, variables, sc, mr) else: out = render_test_p(rays, variables, sc, mr) out = [x.reshape(sh[:-1] + (-1, )) for x in out] return out def render_loop(rays, variables, sc=1., mr=1., chunk=2**13, hq=False): sh = list(rays[0].shape[:-1]) rays = [x.reshape((-1, ) + x.shape[-1:]) for x in rays] outs = [ render_test([x[i:i + chunk] for x in rays], variables, sc, mr, hq=hq) for i in range(0, rays[0].shape[0], chunk) ] outs = [ np.reshape(np.concatenate([z[i] for z in outs]), sh + [-1]) for i in range(3) ] return outs ## Training loop t_total = 0. logging.info('Experiment dir %s', experiment_dir) logging.info('Work unit dir %s', work_unit_dir) gfile.makedirs(work_unit_dir) # Set up metric writer writer = metric_writers.create_default_writer( work_unit_dir, asynchronous=True, just_logging=jax.process_index() > 0) if jax.process_index() == 0: train_config = config.copy_and_resolve_references() log.write_config_json(train_config, work_unit_dir) # Scale instrinsics to different resolutions. hwf_clip_r = scene.scale_intrinsics(config.retrieve_widths[0]) hwf_base = scene.scale_intrinsics(config.render_width) hwf_video = scene.scale_intrinsics(300.) hwf_video_hq = scene.scale_intrinsics(400.) 
# JIT compile ray generation @jax.jit def camera_ray_batch_base(p, focal_mult): return scene.camera_ray_batch(p, *hwf_base[:2], hwf_base[2] * focal_mult) @jax.jit def sample_pose_focal(key): return scene.sample_camera(key, config.th_range, config.phi_range, config.rad_range, config.focal_mult_range) shard_rays_jit = jax.jit(functools.partial(scene.shard_rays)) def sample_iter_data(key, step): # Sample pose, focal length multiplier. pose, rad, focal_mult = sample_pose_focal(key) # Generate rays, shaped for pmap over devices. rays = camera_ray_batch_base(pose, focal_mult) rays_in = shard_rays_jit(rays) # Select rays for this process rays_in = jax.tree_map(lambda x: x[pid], rays_in) substeps = np.arange(start=step, stop=step + config.substeps, step=1) # mip-NeRF scale annealing. decays = config.mipnerf.decay_start * ( 1 - substeps / config.mipnerf.decay_iters) scs = np.maximum(1., 2**decays) # Sigma noise annealing. sns = schedule.sigma_noise_std_fn(substeps, i_split=config.sn_i_split, sn0=config.sn0, sn1=config.sn1) # Scene bounds annealing. mrs = schedule.mask_rad_fn(substeps, i_split=config.mr_i_split, mr0=config.mr0, mr1=config.mr1) # Anneal target opacity (1 - transmittance). accts = schedule.anneal_exponentially(substeps, config.acc_target_i_split, config.acc_target0, config.acc_target1) # The area of an object on the image plane grows with the focal length # and shrinks with increasing camera radius. Scale target opacity # proportionally with the squared focal multiplier and inversely # proportionally with the squared camera radius. For consistency with # early experiments that did not use this scaling, we also scale by a # constant, 1 / (4^2 * 1.2). acct_scaling = focal_mult**2 / ((rad / 4.)**2) / 1.2 accts = np.minimum(1., acct_scaling * accts) acclams = np.where(substeps < config.acc_lam_after, 0., config.acc_lam) # Beta prior encourages either 0 or 1 opacity for rays betas = np.where(substeps < config.beta_after, .0, config.get('beta_lam', .001)) # Learning rate schedule. # NOTE: vectorized calculation of lrs doesn't work with multiple substeps lrs = schedule.lr_fn(substeps, i_split=config.lr_i_split, i_end=config.iters, lr0=config.lr0, lr1=config.lr1, lr2=config.lr2, cosine_decay=config.lr_cosine_decay) return substeps, rays_in, lrs, scs, sns, mrs, betas, accts, acclams pbar = tqdm.trange(step_init, config.iters + config.substeps, config.substeps, desc='training') for i in pbar: t = time.time() substeps, rays_in, lrs, scs, sns, mrs, betas, accts, acclams = ( sample_iter_data(rng.advance(1), i)) l = substeps[-1] keys_pstep = rng.split(n_device) # NOTE: loss is averaged across substeps. new_state, augs, mean_losses, new_scene_origin = train_pstep( state, rays_in, keys_pstep, lrs, scs, sns, mrs, betas, accts, acclams) # Reduce across devices mean_losses = jax.tree_map(np.mean, mean_losses) # Gradient skipping if nan. if (helpers.all_finite_tree(mean_losses) and helpers.all_finite_tree(new_state)): state = new_state else: logging.warn( 'Skipping update on step %d. non-finite loss or state', i) continue # Update scene origin. if config.get('ema_scene_origin', False): if helpers.all_finite(new_scene_origin): scene_origin.update(new_scene_origin[0]) else: logging.warn( 'Skipping origin update on step %d. ' 'non-finite origin. old: %s skipped update: %s', i, scene_origin.value, new_scene_origin) ## Yield results, for display in colab. 
augs = augs.reshape( -1, *augs.shape[2:]) # devices, n_localaug, HWC->BHWC if yield_results: yield mean_losses, augs, scene_origin.value else: yield None pbar.set_description(f'Loss: {mean_losses["loss"]:.4f}') ## Logging. if i == 0: continue t_total += time.time() - t if i % config.log_scalars_every == 0: scalars = { f'losses/{key}': value for key, value in mean_losses.items() } scalars.update({ 'schedule/mipnerf_scale': scs[-1], 'schedule/lr': lrs[-1], 'schedule/mask_rad': mrs[-1], 'schedule/sigma_noise_std': sns[-1], 'schedule/beta': betas[-1], 'schedule/acc_target': accts[-1], 'schedule/acc_lam': acclams[-1], 'origin/x': scene_origin.value[0], 'origin/y': scene_origin.value[1], 'origin/z': scene_origin.value[2], 'origin/norm': np.linalg.norm(scene_origin.value), }) secs_per_iter = t_total / (l - step_init) iters_per_sec = (l - step_init) / t_total wall = time.time() - t_start scalars.update({ 'system/wall': wall, 'system/secs_per_iter': secs_per_iter, 'system/iters_per_sec': iters_per_sec, }) if i % config.render_every == 0: variables = helpers.state_to_variables(state) cam2world = scene.pose_spherical(30., -45., 4.) rays = scene.camera_ray_batch(cam2world, *hwf_clip_r) # Render with no scale manipulation. outs = render_loop(rays, variables, sc=1., mr=mrs[-1], hq=True) outs = [np.squeeze(x) for x in outs] step_images = { 'render/rgb': outs[0][None], 'render/depth': outs[1][None, Ellipsis, None], 'render/acc': outs[2][None, Ellipsis, None], } # Compute retrieval metric. if config.queries_r: z_est = encode_image_r(preprocess_image_r(outs[0][None])) cosine_sim = (z_est * z_clip_r).sum( -1) # 1d, num retrieval queries log_prob = nn.log_softmax(cosine_sim) prefix = f'val/{config.retrieve_models[0]}/retrieve_' scalars.update({ f'{prefix}cosine_sim': cosine_sim[true_idx_r], f'{prefix}loss': -log_prob[true_idx_r], f'{prefix}acc': (np.argmax(cosine_sim) == true_idx_r).astype(float) }) augs_tiled = log.make_image_grid(augs[:8]) step_images['render/augmentations'] = augs_tiled fig = plt.figure() plt.imshow(1. / np.maximum(config.near, outs[1])) plt.colorbar() plt.title('disparity') disparity = log.plot_to_image(fig) step_images['render/disparity'] = disparity writer.write_images(step=l, images=step_images) if config.render_lq_video and config.video_every and ( i % config.video_every == 0 or i + 1 == config.iters): def rays_theta(th): cam2world = scene.pose_spherical(th, -30., 4.) 
return scene.camera_ray_batch(cam2world, *hwf_video) th_range = np.linspace(0, 360, 60, endpoint=False) frames_all = [ render_loop(rays_theta(th), variables, scs[-1], mrs[-1], hq=False) for th in tqdm.tqdm(th_range, desc='render video') ] videos = [[np.squeeze(f[i]) for f in frames_all] for i in range(3)] for video, label in zip(videos, 'rgb depth acc'.split()): scale = (label == 'depth') log.log_video(None, video, 'frames', label, l, work_unit_dir, scale=scale) if i % config.log_scalars_every == 0: writer.write_scalars(step=l, scalars=scalars) if i % config.flush_every == 0: writer.flush() defrag_every = config.get('defragment_every', default=0) if defrag_every and i % defrag_every == 0: helpers.defragment() if config.get( 'checkpoint_every') and i % config.checkpoint_every == 0: saved_path = checkpoints.save_checkpoint( ckpt_dir=work_unit_dir, target={ 'state': flax.jax_utils.unreplicate(state), 'vars': helpers.state_to_variables(state), 'origin': np.array(scene_origin.value), }, step=l, keep=1, overwrite=True, keep_every_n_steps=config.get('keep_every_n_steps', None)) logging.info('saved checkpoint to %s', saved_path) # Make a higher res, higher frame rate video. if config.render_hq_video and (config.get('hq_video_every', None) and i % config.hq_video_every == 0 or i == config.iters): my_rays = lambda c2w: scene.camera_ray_batch( c2w, *hwf_video_hq) th_range = np.linspace(0, 360, 240, endpoint=False) poses = [scene.pose_spherical(th, -30., 4.) for th in th_range] variables = helpers.state_to_variables(state) frames_all = [ render_loop(my_rays(pose), variables, 1., config.mr1, hq=True) for pose in tqdm.tqdm(poses, 'render hq video') ] videos = [[onp.array(np.squeeze(f[j])) for f in frames_all] for j in range(3)] meta_path = os.path.join(work_unit_dir, 'meta_hq.npy') with gfile.GFile(meta_path, 'wb') as f: logging.info( 'saving metadata for rendered hq frames to %s', meta_path) onp.save( f, dict(poses=onp.array(poses), hwf=onp.array(hwf_video_hq))) for video, label in zip(videos, 'rgb depth acc'.split()): scale = (label == 'depth') log.log_video(None, video, 'frames_hq', label, i, work_unit_dir, scale=scale) writer.flush() writer.close() logging.info('%f sec elapsed total', time.time() - t_start)
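# Minimal sketch (a plain-numpy stand-in approximating the scene.EMA helper
# used above, not its actual implementation) of the exponential moving average
# that tracks the scene origin: each update blends the previous value with the
# new estimate using the decay factor.
import numpy as onp

class SimpleEMA:
  """Exponential moving average: value <- decay * value + (1 - decay) * new."""

  def __init__(self, value, decay):
    self.value = onp.asarray(value, dtype=onp.float64)
    self.decay = decay

  def update(self, new_value):
    self.value = self.decay * self.value + (1.0 - self.decay) * onp.asarray(new_value)

# Example mirroring the usage above:
# scene_origin = SimpleEMA(onp.zeros(3), decay=0.999); scene_origin.update(new_origin)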
def train(config: ml_collections.ConfigDict): """Run training.""" # Establish host information local_device_count = jax.local_device_count() host_count = jax.process_count() host_id = jax.process_index() task = task_registry.get_registered_task(config.task_name) start_step = 0 rng = jax.random.PRNGKey(config.seed) model_config = ml_collections.FrozenConfigDict(config.model_config) model = task.build_model(model_config) # Initialization needs to be pmapped because models use collective ops. # Create dummy input dummy_input = { key: jnp.tile(value, (local_device_count,) + (1,) * value.ndim) for key, value in task.dummy_input(config).items() } rng, init_rng = jax.random.split(rng) init_rng = jax.random.split(init_rng, local_device_count) logging.info('Initializing model.') initial_variables = jax.pmap( model.init, 'batch', static_broadcasted_argnums=2)(init_rng, dummy_input, True) logging.info('Finished initializing model.') initial_variables = initial_variables.unfreeze() if config.load_weights is not None: logging.info('Loading model weights from file') loaded_variables = task.load_weights(config) unexpected, missing = checkpoint_utils.merge_nested_dicts( initial_variables, loaded_variables) logging.info('*** Unexpected features: ***') for feature_name in unexpected: logging.info('\t%s', feature_name) logging.info('*** Missing features: ***') for feature_name in missing: logging.info('\t%s', feature_name) model_vars = { key: value for key, value in initial_variables.items() if key != 'params' } learning_rate_fn = optim_utils.create_learning_rate_scheduler( learning_rate=config.learning_rate, warmup=config.warmup, warmup_steps=config.get('warmup_steps', None), linear_decay=config.linear_decay, max_steps=config.num_train_steps, decay_minimum_factor=config.get('decay_minimum_factor', None), ) if config.weight_decay_exclude is not None: decay_mask = optim_utils.create_dict_mask(initial_variables['params'], config.weight_decay_exclude) else: decay_mask = None tx = optax.adamw( learning_rate=learning_rate_fn, weight_decay=config.weight_decay, b1=0.9, b2=0.999, eps=1e-6, mask=decay_mask) if config.grad_clip is not None: tx = optax.chain(tx, optax.clip_by_global_norm(config.grad_clip)) ignore_k_nans = config.get('ignore_k_nans') if ignore_k_nans is not None: tx = optax.apply_if_finite(tx, max_consecutive_errors=ignore_k_nans) loss_fn = task.make_loss_fn(config) train_state = ts.TrainState.create( apply_fn=loss_fn, params=jax_utils.unreplicate(initial_variables['params']), tx=tx, ) # We access model params only from train state. del initial_variables # Restore unreplicated train state from last checkpoint train_state = checkpoints.restore_checkpoint(config.model_dir, train_state) # Grab last step. 
start_step = int(train_state.step) writer = metric_writers.create_default_writer( config.model_dir, just_logging=jax.process_index() > 0) if start_step == 0: writer.write_hparams(config.to_dict()) dropout_rngs = jax.random.split(rng, local_device_count) del rng # Load datasets logging.info('Loading dataset.') # Make sure we don't re-use same data if we load weights or checkpoint seed = config.seed + start_step if config.load_weights: seed = seed + hash(config.load_weights) name_to_features = task.get_name_to_features(config) preprocess_fn = task.make_preprocess_fn(config) collater_fn = task.make_collater_fn(config) train_data = data_utils.load_multi_dataset( datasets_config=config.train_data, name_to_features=name_to_features, preprocess_fn=preprocess_fn, collater_fn=collater_fn, is_training=True, per_device_batch_size=config.per_device_batch_size, local_device_count=local_device_count, host_count=host_count, host_id=host_id, seed=config.seed, ) train_iter = iter(train_data) pad_eval = config.get('pad_eval', False) if pad_eval: logging.info('Eval data is padded such that none of samples are dropped.') else: logging.warn('Eval data is NOT padded -- some samples might be dropped.') eval_data = data_utils.load_multi_dataset( datasets_config=config.eval_data, name_to_features=name_to_features, preprocess_fn=preprocess_fn, collater_fn=collater_fn, is_training=False, per_device_batch_size=config.per_device_batch_size, local_device_count=local_device_count, host_count=host_count, host_id=host_id, seed=config.seed, pad_eval=pad_eval, ) eval_data = list(eval_data) logging.info('Loaded %d samples for evaluation.', len(eval_data)) # Setup postprocessing_fn for saving samples occasionally. if config.get('save_samples_every_steps') is not None: if config.get('save_samples_every_steps') % config.eval_every_steps != 0: raise ValueError( '`eval_every_steps` must divide `save_samples_every_steps`.') postprocessing_fn = task.make_output_postprocess_fn(config) # Training loop logging.info('Starting training.') # Replicate train state. train_state = jax_utils.replicate(train_state) # compile multidevice versions of train/eval/predict step p_train_step = jax.pmap( functools.partial( train_step, model_config=model_config, ), axis_name='batch', donate_argnums=(0,), ) # pytype: disable=wrong-arg-types p_eval_step = jax.pmap( functools.partial( eval_step, model_config=model_config, ), axis_name='batch') hooks = [] report_progress = periodic_actions.ReportProgress( num_train_steps=config.num_train_steps, writer=writer) if jax.process_index() == 0: hooks += [ report_progress, periodic_actions.Profile(logdir=config.model_dir, num_profile_steps=5) ] train_metrics = [] with metric_writers.ensure_flushes(writer): for step in range(start_step, config.num_train_steps): is_last_step = step == config.num_train_steps - 1 # Shard data to devices and perform a training step. with jax.profiler.StepTraceAnnotation('train', step_num=step): batch = jax.tree_map(jnp.asarray, train_iter.get_next()) train_state, metrics = p_train_step( train_state, model_vars, batch, dropout_rngs, ) train_metrics.append(metrics) # Quick indication that training is happening. logging.log_first_n(logging.INFO, 'Finished training step %d', 5, step) for h in hooks: h(step) # Periodic metric handling. 
if step % config.eval_every_steps == 0 or is_last_step: with report_progress.timed('training_metrics'): logging.info('Gathering training metrics.') train_metrics = common_utils.get_metrics(train_metrics) metrics_sums = jax.tree_map(jnp.sum, train_metrics) summary = metric_utils.process_metrics(metrics_sums, prefix='train') summary['learning_rate'] = learning_rate_fn(step) writer.write_scalars(step, summary) train_metrics = [] with report_progress.timed('eval'): eval_results, eval_auxiliary = evaluate( eval_step_fn=p_eval_step, train_state=train_state, model_vars=model_vars, eval_data=eval_data, ) writer.write_scalars(step, eval_results) if config.get('save_samples_every_steps') is not None: with report_progress.timed('save_samples'): if config.get('save_first_batch_only', 'True'): postprocessing_input = [eval_auxiliary[0]] eval_processed = [ postprocessing_fn(batch, auxiliary_output) for batch, auxiliary_output in eval_auxiliary ] data_utils.save_samples_to_json(eval_processed, config, step) # Save a checkpoint on one host after every checkpoint_freq steps. save_checkpoint = ( step % config.checkpoint_every_steps == 0 or is_last_step) if (config.save_checkpoints and save_checkpoint and jax.process_index() == 0): with report_progress.timed('checkpoint'): logging.info('Saving checkpoint at step %s', step) checkpoints.save_checkpoint( config.model_dir, jax_utils.unreplicate(train_state), step, keep=config.get('keep_checkpoints', 1), keep_every_n_steps=config.get('keep_checkpoint_every_steps'), ) save_model = ( config.save_every_steps and (step % config.save_every_steps == 0 or is_last_step) and step != 0) if (save_model and jax.process_index() == 0): with report_progress.timed('checkpoint'): logging.info('Saving weights at step %s', step) save_path = os.path.join(config.model_dir, 'weights', 'step' + str(step)) # By default, save only encoder weights weights = jax_utils.unreplicate(train_state).params['encoder'] checkpoint_utils.save_weights(save_path, weights)
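# Minimal sketch (hypothetical mask builder; the real mask comes from
# optim_utils.create_dict_mask above) of how a boolean pytree can exclude
# selected parameters, e.g. biases or LayerNorm scales, from AdamW weight
# decay. It assumes a plain nested-dict params pytree.
from flax import traverse_util

def make_decay_mask(params, exclude_keywords=('bias', 'layer_norm', 'scale')):
  """Returns a pytree of booleans: True where weight decay should be applied."""
  flat = traverse_util.flatten_dict(params)
  mask = {
      path: not any(kw in '/'.join(path).lower() for kw in exclude_keywords)
      for path in flat
  }
  return traverse_util.unflatten_dict(mask)

# Hypothetical usage, analogous to the optimizer setup above:
# tx = optax.adamw(learning_rate_fn, weight_decay=0.01, mask=make_decay_mask(params))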
def run_train(self, experiment_dir, work_unit_dir, rng): """Training loop with fixed number of steps and checkpoint every steps.""" del experiment_dir # unused tf.io.gfile.makedirs(work_unit_dir) config = self.config total_bs = config.train.batch_size assert total_bs % jax.device_count() == 0, ( f'num total devices {jax.device_count()} must divide the batch size ' f'{total_bs}') device_bs = total_bs // jax.device_count() logging.info('total_bs=%d device_bs=%d', total_bs, device_bs) # Logging setup writer = metric_writers.create_default_writer( work_unit_dir, just_logging=jax.host_id() > 0) if jax.host_id() == 0: utils.write_config_json(config, os.path.join(work_unit_dir, 'config.json')) # Build input pipeline logging.info('Substeps per training step: %d', config.train.substeps) train_ds = self.dataset.get_tf_dataset( split='train', batch_shape=( jax.local_device_count(), # for pmap config.train.substeps, # for lax.scan over multiple substeps device_bs, # batch size per device ), global_rng=jax.random.PRNGKey(config.seed), repeat=True, shuffle=True, augment=True, shard_id=jax.host_id(), num_shards=jax.host_count()) train_iter = utils.numpy_iter(train_ds) eval_ds = self.dataset.get_tf_dataset( split='eval', batch_shape=(jax.local_device_count(), device_bs), global_rng=jax.random.PRNGKey(config.seed), repeat=True, shuffle=True, augment=False, shard_id=jax.host_id(), num_shards=jax.host_count()) eval_iter = utils.numpy_iter(eval_ds) samples_shape = (device_bs, *self.dataset.data_shape) self.p_gen_samples = utils.dist( functools.partial(self._gen_samples, samples_shape=samples_shape), accumulate='concat', axis_name='batch') # Set up model and training state state = jax.device_get(self.make_init_state()) checkpoint_dir = os.path.join(work_unit_dir, 'checkpoints') state = checkpoints.restore_checkpoint(checkpoint_dir, state) initial_step = int(state.step) state = flax.jax_utils.replicate(state) # Training step train_step = functools.partial(self.step_fn, next(rng), True) train_step = functools.partial(jax.lax.scan, train_step) # for substeps train_step = jax.pmap(train_step, axis_name='batch', donate_argnums=(0,)) # Eval step (does not modify parameters; no substeps) eval_base_rng = next(rng) # Training loop logging.info('Entering training loop at step %i', initial_step) utils.assert_synced(state) last_log_time = last_ckpt_time = time.time() prev_step = initial_step with metric_writers.ensure_flushes(writer): for batch in train_iter: state, metrics = train_step(state, batch) new_step = int(state.step[0]) assert new_step == prev_step + config.train.substeps # Quick indication that training is happening. 
logging.log_first_n(logging.INFO, 'Finished training step %d', 5, new_step) # Log metrics if new_step % config.train.log_loss_every_steps == 0: # Unreplicate metrics, average over substeps, and cast to python float metrics = jax.device_get(flax.jax_utils.unreplicate(metrics)) def avg_over_substeps(x): assert x.shape[0] == config.train.substeps return float(x.mean(axis=0)) metrics = jax.tree_map(avg_over_substeps, metrics) metrics['train/steps_per_sec'] = float( config.train.log_loss_every_steps / (time.time() - last_log_time)) writer.write_scalars(new_step, metrics) last_log_time = time.time() # Eval should_eval = new_step % config.train.eval_every_steps == 0 if prev_step == 0 or should_eval: # Samples samples_to_log = { 'eval/samples': self.get_model_samples( params=state.ema_params, rng=next(rng)) } if samples_to_log: assert all(v.shape == (total_bs, *self.dataset.data_shape) for v in samples_to_log.values()) # tf.summary.image asks for a batch, so insert a new axis writer.write_images( new_step, { k: utils.np_tile_imgs(v.astype('uint8'))[None, :, :, :] for k, v in samples_to_log.items() }) # Eval metrics if config.train.get('calc_eval_metrics', True): eval_metrics = self._calc_eval_metrics( state=state, eval_iter=eval_iter, eval_steps=config.train.get('eval_number_steps', self.dataset.num_eval // total_bs), eval_base_rng=eval_base_rng, total_bs=total_bs) if eval_metrics is not None: writer.write_scalars(new_step, eval_metrics) # Checkpointing: only if checkpoint_every_secs is not None. if config.train.checkpoint_every_secs is not None: should_ckpt = ( time.time() - last_ckpt_time >= config.train.checkpoint_every_secs) should_ckpt = ( prev_step == 0 or new_step == config.train.num_train_steps or should_ckpt) else: should_ckpt = False if should_ckpt and jax.host_id() == 0: checkpoints.save_checkpoint( checkpoint_dir, flax.jax_utils.unreplicate(state), step=new_step, keep=3) last_ckpt_time = time.time() # Keep extra checkpoints without removal. Training does not resume # from these checkpoints. if (('retain_checkpoint_every_steps' in config.train) and ((new_step % config.train.retain_checkpoint_every_steps == 0) or (new_step == config.train.num_train_steps)) and (jax.host_id() == 0)): # Below, overwrite=True because training might resume from a # checkpoint from an earlier step than the latest retained checkpoint, # causing the latest retained checkpoint to be overwritten. checkpoints.save_checkpoint( os.path.join(work_unit_dir, 'retained_checkpoints'), flax.jax_utils.unreplicate(state), step=new_step, keep=int(1e10), overwrite=True) prev_step = new_step if new_step == config.train.num_train_steps: logging.info('Finished training for %d iterations.', new_step) break
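# Minimal sketch (toy scalar parameter and synthetic batches, not the real
# step_fn) of the substep pattern used above: a single-substep update is
# wrapped in jax.lax.scan so several updates run per call, and the scanned
# function is pmapped, so each host feeds batches shaped
# (local_devices, substeps, device_batch, ...).
import functools
import jax
import jax.numpy as jnp

def substep_fn(params, batch):
  """One gradient substep on a scalar parameter with a trivial squared loss."""
  grads = jax.grad(lambda p, b: jnp.mean((p - b) ** 2))(params, batch)
  grads = jax.lax.pmean(grads, axis_name='batch')
  new_params = params - 0.1 * grads
  return new_params, {'loss': jnp.mean((params - batch) ** 2)}

toy_train_step = functools.partial(jax.lax.scan, substep_fn)  # scan over substeps
p_toy_train_step = jax.pmap(toy_train_step, axis_name='batch')

n_dev, substeps, device_bs = jax.local_device_count(), 2, 4
toy_params = jnp.zeros((n_dev,))                      # replicated scalar parameter
toy_batches = jnp.ones((n_dev, substeps, device_bs))  # per-device substep batches
toy_params, toy_metrics = p_toy_train_step(toy_params, toy_batches)
# toy_metrics['loss'] has shape (n_dev, substeps): one loss per device and substep.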
def train(env, agent, loss_func, horizon, config, workdir=None): """Main training loop. config - num_episodes - episodes_per_eval - training_env_batch_size - eval_env_batch_size = 32 - optimizer - learning_rate - seed = 1 """ print(config) if workdir is not None: writer = metric_writers.create_default_writer( logdir=workdir, just_logging=jax.process_index() != 0) writer.write_hparams(dict(config)) key = jax.random.PRNGKey(config.seed) key_train_agent, key_eval_agent, key_train_env, key_eval_env, key_train, key = jax.random.split(key, 6) key_train_envs = jax.random.split(key_train_env, config.training_env_batch_size) key_train_agents = jax.random.split(key_train_agent, config.training_env_batch_size) key_eval_envs = jax.random.split(key_eval_env, config.eval_env_batch_size) key_eval_agents = jax.random.split(key_eval_agent, config.eval_env_batch_size) #TODO(danielsuo): The following vmap code does not work. train_env_start_states, train_env_init_obs = jax.vmap(env.init)(key_train_envs) eval_env_start_states, eval_env_init_obs = jax.vmap(env.init)(key_eval_envs) print(train_env_start_states) print(train_env_init_obs) # qtrain_init_list = list(map(env.init, key_train_envs)) # qtrain_env_start_states = [a for (a,_) in qtrain_init_list] # qtrain_env_init_obs = [b for (_,b) in qtrain_init_list] # print(qtrain_env_start_states) # print(qtrain_env_init_obs) # eval_init_list = list(map(env.init, key_eval_envs)) # eval_env_start_states = [a for (a,_) in eval_init_list] # eval_env_init_obs = [b for (_,b) in eval_init_list] train_agent_start_states = jax.vmap(agent.init)(key_train_agents) eval_agent_start_states = jax.vmap(agent.init)(key_eval_agents) if config.optimizer == "Adam": optim = optax.adam(learning_rate=config.learning_rate) else: # default is SGD optim = optax.sgd(learning_rate=config.learning_rate) optim_state = optim.init(agent) for episode in range(0, config.num_episodes, config.episodes_per_eval): # Eval Step tt = time.time() eval_rollouts = apg_parallel_rollouts(env, eval_env_start_states, eval_env_init_obs, agent, eval_agent_start_states, horizon, loss_func) test_score = eval_rollouts.losses.mean() print(f"TESTING episode {episode} - score:{test_score} - time:{time.time()-tt}") # Training Step tt = time.time() agent, optim_state, losses = train_chunk(config.episodes_per_eval, optim, optim_state, agent, env, train_env_start_states, train_env_init_obs, train_agent_start_states, horizon, loss_func) done_eps = episode + config.episodes_per_eval - 1 print(f"TRAINING: episode {done_eps} - score:{losses[0]} - time {time.time() - tt}") if workdir is not None: for (i, loss) in enumerate(reversed(losses)): writer.write_scalars(episode+i, {"train_score": loss}) writer.write_scalars(episode, {"test_score": test_score}) return optim_state, agent
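# Minimal sketch (toy init function) of the batched initialization used above:
# split one PRNG key into a batch of keys and vmap the per-instance init over
# them, so every environment/agent in the batch gets its own random state.
import jax
import jax.numpy as jnp

def toy_init(key):
  """Stand-in for env.init / agent.init: returns a random 2-vector state."""
  return jax.random.normal(key, (2,))

toy_key = jax.random.PRNGKey(0)
toy_batch_keys = jax.random.split(toy_key, 8)            # one key per environment
toy_batch_states = jax.vmap(toy_init)(toy_batch_keys)    # shape (8, 2): independent states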
def main(_): config = FLAGS.config # Unpack total and warmup steps total_steps = config.total_and_warmup_steps[0] warmup_steps = config.total_and_warmup_steps[1] del config.total_and_warmup_steps config.total_steps = total_steps config.lr.warmup_steps = warmup_steps # Wandb and Checkpointing Setup output_dir = FLAGS.output_dir wandb_run, output_dir = vit_utils.maybe_setup_wandb(config) tf.io.gfile.makedirs(output_dir) logging.info('Saving checkpoints at %s', output_dir) # Dataset Split Flags dist_shift = config.distribution_shift print(f'Distribution Shift: {dist_shift}.') dataset_names, split_names = vit_utils.get_dataset_and_split_names( dist_shift) # LR / Optimization Flags print('wandb hyperparameters:') print({ 'batch_size': config.batch_size, 'grad_clip_norm': config.grad_clip_norm, 'weight_decay': config.weight_decay, 'total_steps': config.total_steps, 'lr': config.lr, 'fast_weight_lr_multiplier': config.fast_weight_lr_multiplier }) # Reweighting loss for class imbalance # class_reweight_mode = config.class_reweight_mode # if class_reweight_mode == 'constant': # class_weights = utils.get_diabetic_retinopathy_class_balance_weights() # else: # class_weights = None # Shows the number of available devices. # In a CPU/GPU runtime this will be a single device. # In a TPU runtime this will be 8 cores. print('Number of Jax local devices:', jax.local_devices()) # TODO(nband): fix sigmoid loss issues. assert config.get('loss', None) == 'softmax_xent' seed = config.get('seed', 0) rng = jax.random.PRNGKey(seed) tf.random.set_seed(seed) if config.get('data_dir'): logging.info('data_dir=%s', config.data_dir) logging.info('Output dir: %s', output_dir) tf.io.gfile.makedirs(output_dir) save_checkpoint_path = None if config.get('checkpoint_steps'): save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz') # Create an asynchronous multi-metric writer. writer = metric_writers.create_default_writer( output_dir, just_logging=jax.process_index() > 0) # The pool is used to perform misc operations such as logging in async way. pool = multiprocessing.pool.ThreadPool() def write_note(note): if jax.process_index() == 0: logging.info('NOTE: %s', note) write_note('Initializing...') # Verify settings to make sure no checkpoints are accidentally missed. if config.get('keep_checkpoint_steps'): assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.' assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, ( f'`keep_checkpoint_steps` ({config.checkpoint_steps}) should be' f'divisible by `checkpoint_steps ({config.checkpoint_steps}).`') batch_size = config.batch_size batch_size_eval = config.get('batch_size_eval', batch_size) if (batch_size % jax.device_count() != 0 or batch_size_eval % jax.device_count() != 0): raise ValueError( f'Batch sizes ({batch_size} and {batch_size_eval}) must ' f'be divisible by device number ({jax.device_count()})') local_batch_size = batch_size // jax.process_count() local_batch_size_eval = batch_size_eval // jax.process_count() logging.info( 'Global batch size %d on %d hosts results in %d local batch size. 
' 'With %d devices per host (%d devices total), that\'s a %d per-device ' 'batch size.', batch_size, jax.process_count(), local_batch_size, jax.local_device_count(), jax.device_count(), local_batch_size // jax.local_device_count()) write_note('Initializing preprocessing function...') # Same preprocessing function for training and evaluation preproc_fn = preprocess_spec.parse( spec=config.pp_train, available_ops=preprocess_utils.all_ops()) write_note('Initializing train dataset...') rng, train_ds_rng = jax.random.split(rng) train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index()) train_base_dataset = ub.datasets.get(dataset_names['in_domain_dataset'], split=split_names['train_split'], data_dir=config.get('data_dir')) train_dataset_builder = train_base_dataset._dataset_builder # pylint:disable=protected-access train_ds = input_utils.get_data( dataset=train_dataset_builder, split=split_names['train_split'], rng=train_ds_rng, process_batch_size=local_batch_size, preprocess_fn=preproc_fn, shuffle_buffer_size=config.shuffle_buffer_size, prefetch_size=config.get('prefetch_to_host', 2), data_dir=config.get('data_dir')) # Start prefetching already. train_iter = input_utils.start_input_pipeline( train_ds, config.get('prefetch_to_device', 1)) write_note('Initializing val dataset(s)...') # Load in-domain and OOD validation and/or test datasets. # Please specify the desired shift (Country Shift or Severity Shift) # in the config. eval_iter_splits = vit_utils.init_evaluation_datasets( use_validation=config.use_validation, use_test=config.use_test, dataset_names=dataset_names, split_names=split_names, config=config, preproc_fn=preproc_fn, batch_size_eval=batch_size_eval, local_batch_size_eval=local_batch_size_eval) ntrain_img = input_utils.get_num_examples( train_dataset_builder, split=split_names['train_split'], process_batch_size=local_batch_size, data_dir=config.get('data_dir')) steps_per_epoch = ntrain_img // batch_size if config.get('num_epochs'): total_steps = int(config.num_epochs * steps_per_epoch) assert not config.get( 'total_steps'), 'Set either num_epochs or total_steps' else: total_steps = config.total_steps logging.info('Total train data points: %d', ntrain_img) logging.info( 'Running for %d steps, that means %f epochs and %d steps per epoch', total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch) write_note('Initializing model...') model_dict = vit_utils.initialize_model('batchensemble', config) model = model_dict['model'] ens_size = model_dict['ens_size'] # We want all parameters to be created in host RAM, not on any device, they'll # be sent there later as needed, otherwise we already encountered two # situations where we allocate them twice. @functools.partial(jax.jit, backend='cpu') def init(rng): image_size = tuple(train_ds.element_spec['image'].shape[2:]) logging.info('image_size = %s', image_size) dummy_input = jnp.zeros((local_batch_size, ) + image_size, jnp.float32) params = flax.core.unfreeze(model.init(rng, dummy_input, train=False))['params'] # Set bias in the head to a low value, such that loss is small initially. 
params['batchensemble_head']['bias'] = jnp.full_like( params['batchensemble_head']['bias'], config.get('init_head_bias', 0)) # init head kernel to all zeros for fine-tuning if config.get('model_init'): params['batchensemble_head']['kernel'] = jnp.full_like( params['batchensemble_head']['kernel'], 0) return params rng, rng_init = jax.random.split(rng) params_cpu = init(rng_init) if jax.process_index() == 0: num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0]) parameter_overview.log_parameter_overview(params_cpu) writer.write_scalars(step=0, scalars={'num_params': num_params}) @functools.partial(jax.pmap, axis_name='batch') def evaluation_fn(params, images, labels): tiled_logits, out = model.apply({'params': flax.core.freeze(params)}, images, train=False) loss_name = config.get('loss', 'sigmoid_xent') # TODO(dusenberrymw,zmariet): Clean up and generalize this. if loss_name == 'sigmoid_xent': ens_logits = batchensemble_utils.log_average_sigmoid_probs( jnp.asarray(jnp.split(tiled_logits, ens_size))) pre_logits = batchensemble_utils.log_average_sigmoid_probs( jnp.asarray(jnp.split(out['pre_logits'], ens_size))) else: # softmax ens_logits = batchensemble_utils.log_average_softmax_probs( jnp.asarray(jnp.split(tiled_logits, ens_size))) pre_logits = batchensemble_utils.log_average_softmax_probs( jnp.asarray(jnp.split(out['pre_logits'], ens_size))) losses = getattr(train_utils, loss_name)(logits=ens_logits, labels=labels[:, :config.num_classes], reduction=False) loss = jax.lax.psum(losses, axis_name='batch') top1_idx = jnp.argmax(ens_logits, axis=1) top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0] ncorrect = jax.lax.psum(top1_correct, axis_name='batch') n = batch_size_eval metric_args = jax.lax.all_gather([ens_logits, labels, pre_logits], axis_name='batch') return ncorrect, loss, n, metric_args # Load the optimizer from flax. opt_name = config.get('optim_name') write_note(f'Initializing {opt_name} optimizer...') opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {})) # We jit this, such that the arrays that are created are created on the same # device as the input is, in this case the CPU. Else they'd be on device[0]. opt_cpu = jax.jit(opt_def.create)(params_cpu) weight_decay_rules = config.get('weight_decay', []) or [] rescale_value = config.lr.base if config.get( 'weight_decay_decouple') else 1. weight_decay_fn = train_utils.get_weight_decay_fn( weight_decay_rules=weight_decay_rules, rescale_value=rescale_value) def batch_loss_fn(params, images, labels, rngs): logits, _ = model.apply({'params': flax.core.freeze(params)}, images, train=True, rngs=rngs) labels = jnp.tile(labels, (ens_size, 1)) loss_fn = getattr(train_utils, config.get('loss', 'sigmoid_xent')) loss = jnp.mean(loss_fn(logits=logits, labels=labels)) return loss, dict() @functools.partial(jax.pmap, axis_name='batch', donate_argnums=(0, 1)) def update_fn(opt, rngs, lr, images, labels): return batchensemble_utils.update_fn_be( opt=opt, rngs=rngs, lr=lr, images=images, labels=labels, batch_loss_fn=batch_loss_fn, weight_decay_fn=weight_decay_fn, max_grad_norm_global=config.get('grad_clip_norm', None), fast_weight_lr_multiplier=config.get('fast_weight_lr_multiplier', None)) # Set config checkpoint resume path, if provided in args. 
if config.resume_checkpoint_path is not None: config.resume = config.resume_checkpoint_path reint_params = ('batchensemble_head/bias', 'batchensemble_head/kernel', 'batchensemble_head/fast_weight_alpha', 'batchensemble_head/fast_weight_gamma') if config.get('only_eval', False) or not config.get('reint_head', True): reint_params = [] checkpoint_data = checkpoint_utils.maybe_load_checkpoint( train_loop_rngs=rng, save_checkpoint_path=save_checkpoint_path, init_optimizer=opt_cpu, init_params=params_cpu, init_fixed_model_states=None, default_reinit_params=reint_params, config=config) train_loop_rngs = {'dropout': checkpoint_data.train_loop_rngs} opt_cpu = checkpoint_data.optimizer accumulated_train_time = checkpoint_data.accumulated_train_time write_note('Adapting the checkpoint model...') adapted_params = checkpoint_utils.adapt_upstream_architecture( init_params=params_cpu, loaded_params=opt_cpu.target) opt_cpu = opt_cpu.replace(target=adapted_params) write_note('Kicking off misc stuff...') first_step = int(opt_cpu.state.step) # Might be a DeviceArray type. if first_step == 0 and jax.process_index() == 0: writer.write_hparams(dict(config)) chrono = train_utils.Chrono(first_step, total_steps, batch_size, accumulated_train_time) # Note: switch to ProfileAllHosts() if you need to profile all hosts. # (Xprof data become much larger and take longer to load for analysis) profiler = periodic_actions.Profile( # Create profile after every restart to analyze pre-emption related # problems and assure we get similar performance in every run. logdir=output_dir, first_profile=first_step + 10) # Prepare the learning-rate and pre-fetch it to device to avoid delays. lr_fn = train_utils.create_learning_rate_schedule(total_steps, **config.get('lr', {})) # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be # necessary for TPUs. lr_iter = train_utils.prefetch_scalar(map(lr_fn, range(total_steps)), config.get('prefetch_to_device', 1)) write_note(f'Replicating...\n{chrono.note}') opt_repl = flax.jax_utils.replicate(opt_cpu) checkpoint_writer = None # Note: we return the train loss, val loss, and fewshot best l2s for use in # reproducibility unit tests. # train_loss = -jnp.inf # eval_loss = { # eval_name: -jnp.inf for eval_name, _ in eval_iter_splits.items()} # fewshot_results = {'dummy': {(0, 1): -jnp.inf}} write_note(f'First step compilations...\n{chrono.note}') logging.info('first_step = %s', first_step) # Advance the iterators if we are restarting from an earlier checkpoint. # TODO(dusenberrymw): Look into checkpointing dataset state instead. if first_step > 0: write_note('Advancing iterators after resuming from a checkpoint...') lr_iter = itertools.islice(lr_iter, first_step, None) # TODO(zmariet): Find better way to cut down iteration advancement cost. if not config.get('disable_preemption_reproducibility', False): train_iter = itertools.islice(train_iter, first_step, None) # Using a python integer for step here, because opt.state.step is allocated # on TPU during replication. 
for step, train_batch, lr_repl in zip( range(first_step + 1, total_steps + 1), train_iter, lr_iter): with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1): if not config.get('only_eval', False): opt_repl, train_loop_rngs, extra_measurements = update_fn( opt_repl, train_loop_rngs, lr_repl, train_batch['image'], train_batch['labels']) if jax.process_index() == 0: profiler(step) # Checkpoint saving if not config.get('only_eval', False) and train_utils.itstime( step, config.get('checkpoint_steps'), total_steps, process=0): write_note('Checkpointing...') chrono.pause() train_utils.checkpointing_timeout( checkpoint_writer, config.get('checkpoint_timeout', 1)) accumulated_train_time = chrono.accum_train_time # We need to transfer the weights over now or else we risk keeping them # alive while they'll be updated in a future step, creating hard to debug # memory errors (see b/160593526). Also, takes device 0's params only. opt_cpu = jax.tree_util.tree_map(lambda x: np.array(x[0]), opt_repl) # Check whether we want to keep a copy of the current checkpoint. copy_step = None if train_utils.itstime(step, config.get('keep_checkpoint_steps'), total_steps): write_note('Keeping a checkpoint copy...') copy_step = step # Checkpoint should be a nested dictionary or FLAX datataclasses from # `flax.struct`. Both can be present in a checkpoint. checkpoint_data = checkpoint_utils.CheckpointData( train_loop_rngs=train_loop_rngs, optimizer=opt_cpu, accumulated_train_time=accumulated_train_time) checkpoint_writer = pool.apply_async( checkpoint_utils.checkpoint_trained_model, (checkpoint_data, save_checkpoint_path, copy_step)) chrono.resume() # Report training progress if not config.get('only_eval', False) and train_utils.itstime( step, config.log_training_steps, total_steps, process=0): write_note('Reporting training progress...') timing_measurements, note = chrono.tick(step) write_note(note) train_measurements = {} train_measurements.update( flax.jax_utils.unreplicate(extra_measurements)) train_measurements.update(timing_measurements) writer.write_scalars(step, train_measurements) # Keep to return for reproducibility tests. # train_loss = train_measurements['training_loss'] # Report validation performance if config.get('only_eval', False) or train_utils.itstime( step, config.log_eval_steps, total_steps): write_note('Evaluating on the validation sets...') chrono.pause() all_eval_results = {} for eval_name, (eval_iter, eval_steps) in eval_iter_splits.items(): start_time = time.time() # Runs evaluation loop. results_arrs = { 'y_true': [], 'y_pred': [], 'y_pred_entropy': [] } for _, batch in zip(range(eval_steps), eval_iter): batch_ncorrect, batch_losses, batch_n, batch_metric_args = ( # pylint: disable=unused-variable evaluation_fn(opt_repl.target, batch['image'], batch['labels'])) # All results are a replicated array shaped as follows: # (local_devices, per_device_batch_size, elem_shape...) # with each local device's entry being identical as they got psum'd. # So let's just take the first one to the host as numpy. # Here we parse batch_metric_args to compute uncertainty metrics. logits, labels, _ = batch_metric_args logits = np.array(logits[0]) probs = jax.nn.softmax(logits) # From one-hot to integer labels. 
int_labels = np.argmax(np.array(labels[0]), axis=-1) probs = np.reshape(probs, (probs.shape[0] * probs.shape[1], -1)) int_labels = int_labels.flatten() y_pred = probs[:, 1] results_arrs['y_true'].append(int_labels) results_arrs['y_pred'].append(y_pred) # Entropy is computed at the per-epoch level (see below). results_arrs['y_pred_entropy'].append(probs) results_arrs['y_true'] = np.concatenate(results_arrs['y_true'], axis=0) results_arrs['y_pred'] = np.concatenate( results_arrs['y_pred'], axis=0).astype('float64') results_arrs['y_pred_entropy'] = vit_utils.entropy( np.concatenate(results_arrs['y_pred_entropy'], axis=0), axis=-1) time_elapsed = time.time() - start_time results_arrs['total_ms_elapsed'] = time_elapsed * 1e3 results_arrs['dataset_size'] = eval_steps * batch_size_eval all_eval_results[eval_name] = results_arrs per_pred_results, metrics_results = vit_utils.evaluate_vit_predictions( # pylint: disable=unused-variable dataset_split_to_containers=all_eval_results, is_deterministic=True, num_bins=15, return_per_pred_results=True) # `metrics_results` is a dict of {str: jnp.ndarray} dicts, one for each # dataset. Flatten this dict so we can pass to the writer and remove empty # entries. flattened_metric_results = {} for dic in metrics_results.values(): for key, value in dic.items(): if value is not None: flattened_metric_results[key] = value writer.write_scalars(step, flattened_metric_results) # Optionally log to wandb if config.use_wandb: wandb.log(metrics_results, step=step) # Save per-prediction metrics results_storage_utils.save_per_prediction_results(output_dir, step, per_pred_results, verbose=False) chrono.resume() # End of step. if config.get('testing_failure_step'): # Break early to simulate infra failures in test cases. if config.testing_failure_step == step: break if config.get('only_eval', False): break write_note(f'Done!\n{chrono.note}') pool.close() pool.join() writer.close()
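#------------------------------------------------------------------------------
# The evaluation function above averages the ensemble members' predictions in
# probability space via `batchensemble_utils.log_average_*_probs`. Below is a
# minimal, self-contained sketch of that kind of log-averaging; it is
# illustrative only and not the actual batchensemble_utils implementation.
#------------------------------------------------------------------------------
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def log_average_softmax_probs_sketch(ens_logits):
    """Log of the ensemble-averaged softmax probabilities.

    Args:
        ens_logits: array of shape [ens_size, batch, num_classes].

    Returns:
        [batch, num_classes] array with log(mean_m softmax(logits_m)).
    """
    log_probs = jax.nn.log_softmax(ens_logits, axis=-1)
    # logsumexp over the member axis, minus log(ens_size) to take the mean.
    return logsumexp(log_probs, axis=0) - jnp.log(ens_logits.shape[0])


# Usage mirroring `evaluation_fn`: tiled logits stacked along the batch axis
# are split back into per-member logits before averaging.
ens_size, batch, num_classes = 3, 4, 5
tiled_logits = jax.random.normal(jax.random.PRNGKey(0),
                                 (ens_size * batch, num_classes))
ens_logits = jnp.asarray(jnp.split(tiled_logits, ens_size))
avg_log_probs = log_average_softmax_probs_sketch(ens_logits)
assert avg_log_probs.shape == (batch, num_classes)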
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): """Runs a training and evaluation loop. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. """ tf.io.gfile.makedirs(workdir) # Number of local devices for this host. n_devices = jax.local_device_count() if config.batch_size % n_devices: raise ValueError( "Batch size must be divisible by the number of devices") vocab_path = config.vocab_path if vocab_path is None: vocab_path = os.path.join(workdir, "sentencepiece_model") tf.io.gfile.makedirs(os.path.split(vocab_path)[0]) # Load Dataset # --------------------------------------------------------------------------- logging.info("Initializing dataset.") train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets( n_devices=n_devices, dataset_name=config.dataset_name, eval_dataset_name=config.eval_dataset_name, shard_idx=jax.host_id(), shard_count=jax.host_count(), vocab_path=vocab_path, target_vocab_size=config.vocab_size, batch_size=config.batch_size, max_corpus_chars=config.max_corpus_chars, max_length=config.max_target_length, max_eval_length=config.max_eval_target_length) train_iter = iter(train_ds) vocab_size = int(encoder.vocab_size()) eos_id = decode.EOS_ID # Default Sentencepiece EOS token. def decode_tokens(toks): valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32) return encoder.detokenize(valid_toks).numpy().decode("utf-8") if config.num_predict_steps > 0: predict_ds = predict_ds.take(config.num_predict_steps) logging.info("Initializing model, optimizer, and step functions.") # Build Model and Optimizer # --------------------------------------------------------------------------- train_config = models.TransformerConfig( vocab_size=vocab_size, output_vocab_size=vocab_size, share_embeddings=config.share_embeddings, logits_via_embedding=config.logits_via_embedding, dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, emb_dim=config.emb_dim, num_heads=config.num_heads, num_layers=config.num_layers, qkv_dim=config.qkv_dim, mlp_dim=config.mlp_dim, max_len=max(config.max_target_length, config.max_eval_target_length), dropout_rate=config.dropout_rate, attention_dropout_rate=config.attention_dropout_rate, deterministic=False, decode=False, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6)) eval_config = train_config.replace(deterministic=True) predict_config = train_config.replace(deterministic=True, decode=True) start_step = 0 rng = random.PRNGKey(config.seed) rng, init_rng = random.split(rng) input_shape = (config.batch_size, config.max_target_length) target_shape = (config.batch_size, config.max_target_length) m = models.Transformer(eval_config) initial_variables = jax.jit(m.init)(init_rng, jnp.ones(input_shape, jnp.float32), jnp.ones(target_shape, jnp.float32)) # apply an optimizer to this tree optimizer_def = optim.Adam(config.learning_rate, beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=config.weight_decay) optimizer = optimizer_def.create(initial_variables["params"]) # We access model params only from optimizer below via optimizer.target. del initial_variables if config.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. optimizer = checkpoints.restore_checkpoint(workdir, optimizer) # Grab last step. 
start_step = int(optimizer.state.step) writer = metric_writers.create_default_writer( workdir, just_logging=jax.host_id() > 0) if start_step == 1: writer.write_hparams(dict(config)) # Replicate optimizer. optimizer = jax_utils.replicate(optimizer) learning_rate_fn = create_learning_rate_scheduler( base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps) # compile multidevice versions of train/eval/predict step and cache init fn. p_train_step = jax.pmap(functools.partial( train_step, config=train_config, learning_rate_fn=learning_rate_fn, label_smoothing=config.label_smoothing), axis_name="batch", donate_argnums=(0, )) # pytype: disable=wrong-arg-types p_eval_step = jax.pmap(functools.partial( eval_step, config=eval_config, label_smoothing=config.label_smoothing), axis_name="batch") p_init_cache = jax.pmap(functools.partial( initialize_cache, max_decode_len=config.max_predict_length, config=predict_config), axis_name="batch") p_pred_step = jax.pmap( functools.partial(predict_step, config=predict_config, beam_size=config.beam_size), axis_name="batch", static_broadcasted_argnums=(3, 4)) # eos token, max_length are constant # Main Train Loop # --------------------------------------------------------------------------- # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap"d training update for performance. dropout_rngs = random.split(rng, n_devices) logging.info("Starting training loop.") hooks = [] report_progress = periodic_actions.ReportProgress( num_train_steps=config.num_train_steps, writer=writer) if jax.host_id() == 0: hooks += [ report_progress, periodic_actions.Profile(num_profile_steps=5) ] metrics_all = [] with metric_writers.ensure_flushes(writer): for step, batch in zip(range(start_step, config.num_train_steps), train_iter): # Shard data to devices and do a training step. batch = common_utils.shard( jax.tree_map(lambda x: x._numpy(), batch)) # pylint: disable=protected-access optimizer, metrics, dropout_rngs = p_train_step( optimizer, batch, dropout_rng=dropout_rngs) metrics_all.append(metrics) # Quick indication that training is happening. logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step) for h in hooks: h(step) # Save a checkpoint on one host after every checkpoint_freq steps. if (config.save_checkpoints and step % config.checkpoint_freq == 0 and step > 0 and jax.host_id() == 0): checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(optimizer), step) # Periodic metric handling. 
if step % config.eval_frequency != 0 and step > 0: continue # Training Metrics logging.info("Gathering training metrics.") metrics_all = common_utils.get_metrics(metrics_all) lr = metrics_all.pop("learning_rate").mean() metrics_sums = jax.tree_map(jnp.sum, metrics_all) denominator = metrics_sums.pop("denominator") summary = jax.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary["learning_rate"] = lr summary = {"train_" + k: v for k, v in summary.items()} writer.write_scalars(step, summary) metrics_all = [] # Eval Metrics logging.info("Gathering evaluation metrics.") eval_metrics = [] eval_iter = iter(eval_ds) for _, eval_batch in zip(range(config.num_eval_steps), eval_iter): eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch) # pylint: disable=protected-access eval_batch = common_utils.shard(eval_batch) metrics = p_eval_step(optimizer.target, eval_batch) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop("denominator") eval_summary = jax.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums) eval_summary = {"eval_" + k: v for k, v in eval_summary.items()} writer.write_scalars(step, eval_summary) # Translation and BLEU Score. logging.info("Translating evaluation dataset.") t_inference_start = time.time() sources, references, predictions = [], [], [] for pred_batch in predict_ds: pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch) # pylint: disable=protected-access # Handle final odd-sized batch by padding instead of dropping it. cur_pred_batch_size = pred_batch["inputs"].shape[0] if cur_pred_batch_size % n_devices: padded_size = int( np.ceil(cur_pred_batch_size / n_devices) * n_devices) pred_batch = jax.tree_map( lambda x: pad_examples(x, padded_size), # pylint: disable=cell-var-from-loop pred_batch) pred_batch = common_utils.shard(pred_batch) cache = p_init_cache(pred_batch["inputs"]) predicted = p_pred_step(pred_batch["inputs"], optimizer.target, cache, eos_id, config.max_predict_length) predicted = tohost(predicted) inputs = tohost(pred_batch["inputs"]) targets = tohost(pred_batch["targets"]) # Iterate through non-padding examples of batch. for i, s in enumerate(predicted[:cur_pred_batch_size]): sources.append(decode_tokens(inputs[i])) references.append(decode_tokens(targets[i])) predictions.append(decode_tokens(s)) logging.info( "Translation: %d predictions %d references %d sources.", len(predictions), len(references), len(sources)) logging.info("Translation time: %.4f s step %d.", time.time() - t_inference_start, step) # Calculate BLEU score for translated eval corpus against reference. bleu_matches = bleu.bleu_partial(references, predictions) all_bleu_matches = per_host_sum_pmap(bleu_matches) bleu_score = bleu.complete_bleu(*all_bleu_matches) # Save translation samples for tensorboard. exemplars = "" for n in np.random.choice(np.arange(len(predictions)), 8): exemplars += f"{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n" writer.write_scalars(step, {"bleu": bleu_score}) writer.write_texts(step, {"samples": exemplars})
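#------------------------------------------------------------------------------
# The decoding loop above pads the final odd-sized predict batch so that it
# divides evenly across devices before `common_utils.shard`. A minimal sketch
# of such a padding helper follows; the real `pad_examples` is defined
# elsewhere in this file, so the version below is illustrative only (padded
# rows are discarded when slicing back to `cur_pred_batch_size` on the host).
#------------------------------------------------------------------------------
import jax
import numpy as np
from flax.training import common_utils


def pad_examples_sketch(x, desired_batch_size):
    """Pads axis 0 up to `desired_batch_size` by repeating the last example."""
    batch_pad = desired_batch_size - x.shape[0]
    pad = np.tile(x[-1], (batch_pad,) + (1,) * (x.ndim - 1))
    return np.concatenate([x, pad], axis=0)


n_devices = jax.local_device_count()
pred_batch = {"inputs": np.ones((5, 7), np.int32)}  # 5 examples of length 7.
cur_pred_batch_size = pred_batch["inputs"].shape[0]
if cur_pred_batch_size % n_devices:
    padded_size = int(np.ceil(cur_pred_batch_size / n_devices) * n_devices)
    pred_batch = jax.tree_map(
        lambda x: pad_examples_sketch(x, padded_size), pred_batch)
# Shape becomes [n_devices, per_device_batch, ...] for the pmapped predict step.
sharded_batch = common_utils.shard(pred_batch)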
def main(argv): del argv config = FLAGS.config output_dir = FLAGS.output_dir seed = config.get('seed', 0) rng = jax.random.PRNGKey(seed) tf.random.set_seed(seed) if config.get('dataset_dir'): logging.info('data_dir=%s', config.dataset_dir) logging.info('Output dir: %s', output_dir) save_checkpoint_path = None if config.get('checkpoint_steps'): gfile.makedirs(output_dir) save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz') # Create an asynchronous multi-metric writer. writer = metric_writers.create_default_writer( output_dir, just_logging=jax.process_index() > 0) # The pool is used to perform misc operations such as logging in async way. pool = multiprocessing.pool.ThreadPool() def write_note(note): if jax.process_index() == 0: logging.info('NOTE: %s', note) write_note('Initializing...') fillin = lambda *_: None # Verify settings to make sure no checkpoints are accidentally missed. if config.get('keep_checkpoint_steps'): assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.' assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, ( f'`keep_checkpoint_steps` ({config.checkpoint_steps}) should be' f'divisible by `checkpoint_steps ({config.checkpoint_steps}).`') batch_size = config.batch_size batch_size_eval = config.get('batch_size_eval', batch_size) if (batch_size % jax.device_count() != 0 or batch_size_eval % jax.device_count() != 0): raise ValueError( f'Batch sizes ({batch_size} and {batch_size_eval}) must ' f'be divisible by device number ({jax.device_count()})') local_batch_size = batch_size // jax.process_count() local_batch_size_eval = batch_size_eval // jax.process_count() logging.info( 'Global batch size %d on %d hosts results in %d local batch size. ' 'With %d dev per host (%d dev total), that is a %d per-device batch size.', batch_size, jax.process_count(), local_batch_size, jax.local_device_count(), jax.device_count(), local_batch_size // jax.local_device_count()) write_note('Initializing train dataset...') rng, train_ds_rng = jax.random.split(rng) train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index()) train_ds = input_utils.get_data( dataset=config.dataset, split=config.train_split, rng=train_ds_rng, process_batch_size=local_batch_size, preprocess_fn=preprocess_spec.parse( spec=config.pp_train, available_ops=preprocess_utils.all_ops()), shuffle_buffer_size=config.shuffle_buffer_size, prefetch_size=config.get('prefetch_to_host', 2), data_dir=fillin(config.get('data_dir'))) logging.info('image_size = %s', train_ds.element_spec['image'].shape[1:]) # Start prefetching already. train_iter = input_utils.start_input_pipeline( train_ds, config.get('prefetch_to_device', 1)) write_note('Initializing val dataset(s)...') def _get_val_split(dataset, split, pp_eval, data_dir=None): # We do ceil rounding such that we include the last incomplete batch. 
nval_img = input_utils.get_num_examples( dataset, split=split, process_batch_size=local_batch_size_eval, drop_remainder=False, data_dir=fillin(data_dir)) val_steps = int(np.ceil(nval_img / batch_size_eval)) logging.info('Running validation for %d steps for %s, %s', val_steps, dataset, split) if isinstance(pp_eval, str): pp_eval = preprocess_spec.parse( spec=pp_eval, available_ops=preprocess_utils.all_ops()) val_ds = input_utils.get_data(dataset=dataset, split=split, rng=None, process_batch_size=local_batch_size_eval, preprocess_fn=pp_eval, cache=config.get('val_cache', 'batched'), num_epochs=1, repeat_after_batching=True, shuffle=False, prefetch_size=config.get( 'prefetch_to_host', 2), drop_remainder=False, data_dir=fillin(data_dir)) return val_ds val_ds_splits = { 'val': _get_val_split(config.dataset, config.val_split, config.pp_eval, config.get('dataset_dir')) } if config.get('test_split'): val_ds_splits.update({ 'test': _get_val_split(config.dataset, split=config.test_split, pp_eval=config.pp_eval, data_dir=config.get('data_dir')) }) if config.get('eval_on_cifar_10h'): cifar10_to_cifar10h_fn = data_uncertainty_utils.create_cifar10_to_cifar10h_fn( config.get('data_dir', None)) preprocess_fn = preprocess_spec.parse( spec=config.pp_eval_cifar_10h, available_ops=preprocess_utils.all_ops()) pp_eval = lambda ex: preprocess_fn(cifar10_to_cifar10h_fn(ex)) val_ds_splits['cifar_10h'] = _get_val_split( 'cifar10', split=config.get('cifar_10h_split') or 'test', pp_eval=pp_eval, data_dir=config.get('data_dir')) elif config.get('eval_on_imagenet_real'): imagenet_to_real_fn = data_uncertainty_utils.create_imagenet_to_real_fn( ) preprocess_fn = preprocess_spec.parse( spec=config.pp_eval_imagenet_real, available_ops=preprocess_utils.all_ops()) pp_eval = lambda ex: preprocess_fn(imagenet_to_real_fn(ex)) val_ds_splits['imagenet_real'] = _get_val_split( 'imagenet2012_real', split=config.get('imagenet_real_split') or 'validation', pp_eval=pp_eval, data_dir=config.get('data_dir')) ood_ds = {} if config.get('ood_datasets') and config.get('ood_methods'): if config.get( 'ood_methods'): # config.ood_methods is not a empty list logging.info('loading OOD dataset = %s', config.get('ood_datasets')) ood_ds, ood_ds_names = ood_utils.load_ood_datasets( config.dataset, config.ood_datasets, config.ood_split, config.pp_eval, config.pp_eval_ood, config.ood_methods, config.train_split, config.get('data_dir'), _get_val_split, ) ntrain_img = input_utils.get_num_examples( config.dataset, split=config.train_split, process_batch_size=local_batch_size, data_dir=fillin(config.get('data_dir'))) steps_per_epoch = ntrain_img / batch_size if config.get('num_epochs'): total_steps = int(config.num_epochs * steps_per_epoch) assert not config.get( 'total_steps'), 'Set either num_epochs or total_steps' else: total_steps = config.total_steps logging.info( 'Running for %d steps, that means %f epochs and %f steps per epoch', total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch) write_note('Initializing model...') logging.info('config.model = %s', config.get('model')) # Specify Gaussian process layer configs. use_gp_layer = True gp_config = config.get('gp_layer', {}) gp_layer_kwargs = get_gp_kwargs(gp_config) # Process ViT backbone model configs. 
vit_kwargs = config.get('model') het_kwargs = config.get('het') model = ub.models.vision_transformer_hetgp( num_classes=config.num_classes, use_gp_layer=use_gp_layer, vit_kwargs=vit_kwargs, gp_layer_kwargs=gp_layer_kwargs, multiclass=het_kwargs.multiclass, temperature=het_kwargs.temperature, mc_samples=het_kwargs.mc_samples, num_factors=het_kwargs.num_factors, param_efficient=het_kwargs.param_efficient) # We want all parameters to be created in host RAM, not on any device, they'll # be sent there later as needed, otherwise we already encountered two # situations where we allocate them twice. @partial(jax.jit, backend='cpu') def init(rng): image_size = tuple(train_ds.element_spec['image'].shape[2:]) logging.info('image_size = %s', image_size) dummy_input = jnp.zeros((local_batch_size, ) + image_size, jnp.float32) rng, diag_noise_rng, standard_noise_rng = jax.random.split(rng, num=3) init_rngs = { 'params': rng, 'diag_noise_samples': diag_noise_rng, 'standard_norm_noise_samples': standard_noise_rng } variables = model.init(init_rngs, dummy_input, train=False) # Split model parameters into trainable and untrainable collections. states, params = variables.pop('params') del variables # Set bias in the head to a low value, such that loss is small initially. params = flax.core.unfreeze(params) if use_gp_layer: # Modify the head parameter in the GP head. params['head']['loc_layer']['output_layer'][ 'bias'] = jnp.full_like( params['head']['loc_layer']['output_layer']['bias'], config.get('init_head_bias', 0)) else: params['vit_backbone']['head']['bias'] = jnp.full_like( params['vit_backbone']['head']['bias'], config.get('init_head_bias', 0)) return params, states (rng, rng_init, rng_dropout, diag_noise_rng, standard_noise_rng) = jax.random.split(rng, num=5) params_cpu, states_cpu = init(rng_init) if jax.process_index() == 0: num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0]) parameter_overview.log_parameter_overview(params_cpu) writer.write_scalars(step=0, scalars={'num_params': num_params}) @partial(jax.pmap, axis_name='batch') def evaluation_fn(params, states, images, labels, mask): # Ignore the entries with all zero labels for evaluation. mask *= labels.max(axis=1) variable_dict = {'params': flax.core.freeze(params), **states} logits, out = model.apply(variable_dict, images, train=False, rngs={ 'dropout': rng_dropout, 'diag_noise_samples': diag_noise_rng, 'standard_norm_noise_samples': standard_noise_rng }) # Note that logits and labels are usually of the shape [batch,num_classes]. # But for OOD data, when num_classes_ood > num_classes_ind, we need to # adjust labels to labels[:, :config.num_classes] to match the shape of # logits. That is just to avoid shape mismatch. The output losses does not # have any meaning for OOD data, because OOD not belong to any IND class. losses = getattr(train_utils, config.get('loss', 'sigmoid_xent'))( logits=logits, labels=labels[:, :config.num_classes], reduction=False) loss = jax.lax.psum(losses * mask, axis_name='batch') top1_idx = jnp.argmax(logits, axis=1) # Extracts the label at the highest logit index for each image. 
top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0] ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch') n = jax.lax.psum(mask, axis_name='batch') metric_args = jax.lax.all_gather( [logits, labels, out['pre_logits'], mask], axis_name='batch') return ncorrect, loss, n, metric_args @partial(jax.pmap, axis_name='batch') def cifar_10h_evaluation_fn(params, states, images, labels, mask): variable_dict = {'params': flax.core.freeze(params), **states} logits, out = model.apply(variable_dict, images, train=False, rngs={ 'dropout': rng_dropout, 'diag_noise_samples': diag_noise_rng, 'standard_norm_noise_samples': standard_noise_rng }) losses = getattr(train_utils, config.get('loss', 'softmax_xent'))(logits=logits, labels=labels, reduction=False) loss = jax.lax.psum(losses, axis_name='batch') top1_idx = jnp.argmax(logits, axis=1) # Extracts the label at the highest logit index for each image. one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)] top1_correct = jnp.take_along_axis(one_hot_labels, top1_idx[:, None], axis=1)[:, 0] ncorrect = jax.lax.psum(top1_correct, axis_name='batch') n = jax.lax.psum(one_hot_labels, axis_name='batch') metric_args = jax.lax.all_gather( [logits, labels, out['pre_logits'], mask], axis_name='batch') return ncorrect, loss, n, metric_args # Setup function for computing representation. @partial(jax.pmap, axis_name='batch') def representation_fn(params, images, labels, mask, states): variable_dict = {'params': flax.core.freeze(params), **states} _, outputs = model.apply(variable_dict, images, train=False, rngs={ 'dropout': rng_dropout, 'diag_noise_samples': diag_noise_rng, 'standard_norm_noise_samples': standard_noise_rng }) representation = outputs[config.fewshot.representation_layer] representation = jax.lax.all_gather(representation, 'batch') labels = jax.lax.all_gather(labels, 'batch') mask = jax.lax.all_gather(mask, 'batch') return representation, labels, mask # Load the optimizer from flax. opt_name = config.get('optim_name') write_note(f'Initializing {opt_name} optimizer...') opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {})) # We jit this, such that the arrays that are created are created on the same # device as the input is, in this case the CPU. Else they'd be on device[0]. opt_cpu = jax.jit(opt_def.create)(params_cpu) weight_decay_rules = config.get('weight_decay', []) or [] rescale_value = config.lr.base if config.get( 'weight_decay_decouple') else 1. weight_decay_fn = train_utils.get_weight_decay_fn( weight_decay_rules=weight_decay_rules, rescale_value=rescale_value) @partial(jax.pmap, axis_name='batch', donate_argnums=(0, )) def update_fn(opt, states, lr, reset_covmat, images, labels, rng): """Update step.""" measurements = {} # Get device-specific loss rng. rng, rng_model = jax.random.split(rng, 2) rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index('batch')) rng_model_local, diag_noise_rng, standard_noise_rng = jax.random.split( rng_model_local, num=3) def loss_fn(params, states, images, labels): # Specify mutable collection to update untrainable GP parameters. 
variable_dict = {'params': flax.core.freeze(params), **states} model_results, updated_states = model.apply( variable_dict, images, train=True, rngs={ 'dropout': rng_model_local, 'diag_noise_samples': diag_noise_rng, 'standard_norm_noise_samples': standard_noise_rng }, mutable=list(states.keys())) logits, _ = model_results loss = getattr(train_utils, config.get('loss', 'sigmoid_xent'))(logits=logits, labels=labels) return loss, updated_states # Performs exact covariance update (i.e., reset precision matrix resetting # at begining of new epoch) if covmat_momentum is a null value. if gp_config.get('covmat_momentum', -1.) < 0: # Resets precision matrix to Identity * ridge_penalty if at the begining # of a new epoch. This should be done before accumulate gradient. ridge_penalty = gp_config.get('ridge_penalty', 1.) prec_mat_old = states['laplace_covariance']['head'][ 'covmat_layer']['precision_matrix'] prec_mat_new = ( (1. - reset_covmat) * prec_mat_old + reset_covmat * jnp.eye(prec_mat_old.shape[0]) * ridge_penalty) states = flax.core.unfreeze(states) states['laplace_covariance']['head']['covmat_layer'][ 'precision_matrix'] = prec_mat_new states = flax.core.freeze(states) # Implementation considerations compared and summarized at # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en# (l, s), g = train_utils.accumulate_gradient_with_states( jax.value_and_grad(loss_fn, has_aux=True), opt.target, states, images, labels, config.get('grad_accum_steps')) l, g = jax.lax.pmean((l, g), axis_name='batch') # Log the gradient norm only if we need to compute it anyways (clipping) # or if we don't use grad_accum_steps, as they interact badly. do_grad_clip = config.get('grad_clip_norm', -1.) > 0. if config.get('grad_accum_steps', 1) == 1 or do_grad_clip: grads, _ = jax.tree_flatten(g) l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads])) measurements['l2_grads'] = l2_g # Optionally resize the global gradient to a maximum norm. We found this # useful in some cases across optimizers, hence it's in the main loop. if do_grad_clip: g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g) g = jax.tree_map(lambda p: g_factor * p, g) opt = opt.apply_gradient(g, learning_rate=lr) opt = opt.replace(target=weight_decay_fn(opt.target, lr)) params, _ = jax.tree_flatten(opt.target) measurements['l2_params'] = jnp.sqrt( sum([jnp.vdot(p, p) for p in params])) return opt, s, l, rng, measurements default_reinit_params = ('head/output_layer/kernel', 'head/output_layer/bias', 'head/kernel', 'head/bias') rng, train_loop_rngs = jax.random.split(rng) checkpoint_data = checkpoint_utils.maybe_load_checkpoint( train_loop_rngs=train_loop_rngs, save_checkpoint_path=save_checkpoint_path, init_optimizer=opt_cpu, init_params=params_cpu, init_fixed_model_states=states_cpu, default_reinit_params=default_reinit_params, config=config) train_loop_rngs = checkpoint_data.train_loop_rngs opt_cpu = checkpoint_data.optimizer states_cpu = checkpoint_data.fixed_model_states accumulated_train_time = checkpoint_data.accumulated_train_time write_note('Adapting the checkpoint model...') adapted_params = checkpoint_utils.adapt_upstream_architecture( init_params=params_cpu, loaded_params=opt_cpu.target) opt_cpu = opt_cpu.replace(target=adapted_params) write_note('Kicking off misc stuff...') first_step = int(opt_cpu.state.step) # Might be a DeviceArray type. 
if first_step == 0 and jax.process_index() == 0: writer.write_hparams(dict(config)) chrono = train_utils.Chrono(first_step, total_steps, batch_size, accumulated_train_time) # Note: switch to ProfileAllHosts() if you need to profile all hosts. # (Xprof data become much larger and take longer to load for analysis) profiler = periodic_actions.Profile( # Create profile after every restart to analyze pre-emption related # problems and assure we get similar performance in every run. logdir=output_dir, first_profile=first_step + 10) # Prepare the learning-rate and pre-fetch it to device to avoid delays. lr_fn = train_utils.create_learning_rate_schedule(total_steps, **config.get('lr', {})) # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be # necessary for TPUs. lr_iter = train_utils.prefetch_scalar(map(lr_fn, range(total_steps)), config.get('prefetch_to_device', 1)) # Prepare the precision matrix resetting schedule, and pre-fetch it to device. reset_covmat_fn = lambda step: float(step % steps_per_epoch == 0) reset_covmat_iter = train_utils.prefetch_scalar( map(reset_covmat_fn, range(first_step, total_steps)), nprefetch=config.get('prefetch_to_device', 1)) write_note(f'Replicating...\n{chrono.note}') opt_repl = flax_utils.replicate(opt_cpu) states_repl = flax_utils.replicate(states_cpu) write_note(f'Initializing few-shotters...\n{chrono.note}') if 'fewshot' in config: fewshotter = fewshot.FewShotEvaluator( representation_fn, config.fewshot, config.fewshot.get('batch_size') or batch_size_eval) checkpoint_writer = None # Note: we return the train loss, val loss, and fewshot best l2s for use in # reproducibility unit tests. train_loss = -jnp.inf val_loss = {val_name: -jnp.inf for val_name, _ in val_ds_splits.items()} fewshot_results = {'dummy': {(0, 1): -jnp.inf}} write_note(f'First step compilations...\n{chrono.note}') logging.info('first_step = %s', first_step) # Makes sure log_eval_steps is same as steps_per_epoch. This is because # the precision matrix needs to be updated fully (at the end of each epoch) # when eval takes place. log_eval_steps = steps_per_epoch # Advance the iterators if we are restarting from an earlier checkpoint. # TODO(dusenberrymw): Look into checkpointing dataset state instead. if first_step > 0: write_note('Advancing iterators after resuming from a checkpoint...') lr_iter = itertools.islice(lr_iter, first_step, None) train_iter = itertools.islice(train_iter, first_step, None) # Using a python integer for step here, because opt.state.step is allocated # on TPU during replication. for step, train_batch, lr_repl, reset_covmat_repl in zip( range(first_step + 1, total_steps + 1), train_iter, lr_iter, reset_covmat_iter): with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1): # TODO(jereliu): Expand to allow precision matrix resetting. (opt_repl, states_repl, loss_value, train_loop_rngs, extra_measurements) = update_fn(opt_repl, states_repl, lr_repl, reset_covmat_repl, train_batch['image'], train_batch['labels'], rng=train_loop_rngs) if jax.process_index() == 0: profiler(step) # Checkpoint saving if train_utils.itstime(step, config.get('checkpoint_steps'), total_steps, process=0): write_note('Checkpointing...') chrono.pause() train_utils.checkpointing_timeout( checkpoint_writer, config.get('checkpoint_timeout', 1)) accumulated_train_time = chrono.accum_train_time # We need to transfer the weights over now or else we risk keeping them # alive while they'll be updated in a future step, creating hard to debug # memory errors (see b/160593526). 
Also, takes device 0's params only. # For GP layer, we will also do the same for untrainable parameters # (`states`). This is ok since `random features` are frozen throughout # pre-training, and `precision matrix` is a finetuning-specific parameters # that will be re-learned in the finetuning task. opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl) states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl) # Check whether we want to keep a copy of the current checkpoint. copy_step = None if train_utils.itstime(step, config.get('keep_checkpoint_steps'), total_steps): write_note('Keeping a checkpoint copy...') copy_step = step # Checkpoint should be a nested dictionary or FLAX datataclasses from # `flax.struct`. Both can be present in a checkpoint. checkpoint_data = checkpoint_utils.CheckpointData( optimizer=opt_cpu, fixed_model_states=states_cpu, train_loop_rngs=train_loop_rngs, accumulated_train_time=accumulated_train_time) checkpoint_writer = pool.apply_async( checkpoint_utils.checkpoint_trained_model, (checkpoint_data, save_checkpoint_path, copy_step)) chrono.resume() # Report training progress if train_utils.itstime(step, config.log_training_steps, total_steps, process=0): write_note('Reporting training progress...') train_loss = loss_value[ 0] # Keep to return for reproducibility tests. timing_measurements, note = chrono.tick(step) write_note(note) train_measurements = {} train_measurements.update({ 'learning_rate': lr_repl[0], 'training_loss': train_loss, }) train_measurements.update( flax.jax_utils.unreplicate(extra_measurements)) train_measurements.update(timing_measurements) writer.write_scalars(step, train_measurements) # Report validation performance if train_utils.itstime(step, log_eval_steps, total_steps): write_note('Evaluating on the validation set...') chrono.pause() for val_name, val_ds in val_ds_splits.items(): # Sets up evaluation metrics. ece_num_bins = config.get('ece_num_bins', 15) auc_num_bins = config.get('auc_num_bins', 1000) ece = rm.metrics.ExpectedCalibrationError( num_bins=ece_num_bins) calib_auc = rm.metrics.CalibrationAUC( correct_pred_as_pos_label=False) # TODO(jereliu): Extend to support soft multi-class probabilities. oc_auc_0_5 = rm.metrics.OracleCollaborativeAUC( oracle_fraction=0.005, num_bins=auc_num_bins) oc_auc_1 = rm.metrics.OracleCollaborativeAUC( oracle_fraction=0.01, num_bins=auc_num_bins) oc_auc_2 = rm.metrics.OracleCollaborativeAUC( oracle_fraction=0.02, num_bins=auc_num_bins) oc_auc_5 = rm.metrics.OracleCollaborativeAUC( oracle_fraction=0.05, num_bins=auc_num_bins) label_diversity = tf.keras.metrics.Mean() sample_diversity = tf.keras.metrics.Mean() ged = tf.keras.metrics.Mean() # Runs evaluation loop. val_iter = input_utils.start_input_pipeline( val_ds, config.get('prefetch_to_device', 1)) ncorrect, loss, nseen = 0, 0, 0 for batch in val_iter: if val_name == 'cifar_10h': batch_ncorrect, batch_losses, batch_n, batch_metric_args = ( cifar_10h_evaluation_fn(opt_repl.target, states_repl, batch['image'], batch['labels'], batch['mask'])) else: batch_ncorrect, batch_losses, batch_n, batch_metric_args = ( evaluation_fn(opt_repl.target, states_repl, batch['image'], batch['labels'], batch['mask'])) # All results are a replicated array shaped as follows: # (local_devices, per_device_batch_size, elem_shape...) # with each local device's entry being identical as they got psum'd. # So let's just take the first one to the host as numpy. 
ncorrect += np.sum(np.array(batch_ncorrect[0])) loss += np.sum(np.array(batch_losses[0])) nseen += np.sum(np.array(batch_n[0])) if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent': # Here we parse batch_metric_args to compute uncertainty metrics. # (e.g., ECE or Calibration AUC). logits, labels, _, masks = batch_metric_args masks = np.array(masks[0], dtype=np.bool) logits = np.array(logits[0]) probs = jax.nn.softmax(logits) # From one-hot to integer labels, as required by ECE. int_labels = np.argmax(np.array(labels[0]), axis=-1) int_preds = np.argmax(logits, axis=-1) confidence = np.max(probs, axis=-1) for p, c, l, d, m, label in zip( probs, confidence, int_labels, int_preds, masks, labels[0]): ece.add_batch(p[m, :], label=l[m]) calib_auc.add_batch(d[m], label=l[m], confidence=c[m]) oc_auc_0_5.add_batch(d[m], label=l[m], custom_binning_score=c[m]) oc_auc_1.add_batch(d[m], label=l[m], custom_binning_score=c[m]) oc_auc_2.add_batch(d[m], label=l[m], custom_binning_score=c[m]) oc_auc_5.add_batch(d[m], label=l[m], custom_binning_score=c[m]) if val_name == 'cifar_10h' or val_name == 'imagenet_real': batch_label_diversity, batch_sample_diversity, batch_ged = data_uncertainty_utils.generalized_energy_distance( label[m], p[m, :], config.num_classes) label_diversity.update_state( batch_label_diversity) sample_diversity.update_state( batch_sample_diversity) ged.update_state(batch_ged) val_loss[ val_name] = loss / nseen # Keep for reproducibility tests. val_measurements = { f'{val_name}_prec@1': ncorrect / nseen, f'{val_name}_loss': val_loss[val_name], } if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent': val_measurements[f'{val_name}_ece'] = ece.result()['ece'] val_measurements[ f'{val_name}_calib_auc'] = calib_auc.result( )['calibration_auc'] val_measurements[ f'{val_name}_oc_auc_0.5%'] = oc_auc_0_5.result( )['collaborative_auc'] val_measurements[ f'{val_name}_oc_auc_1%'] = oc_auc_1.result( )['collaborative_auc'] val_measurements[ f'{val_name}_oc_auc_2%'] = oc_auc_2.result( )['collaborative_auc'] val_measurements[ f'{val_name}_oc_auc_5%'] = oc_auc_5.result( )['collaborative_auc'] writer.write_scalars(step, val_measurements) if val_name == 'cifar_10h' or val_name == 'imagenet_real': cifar_10h_measurements = { f'{val_name}_label_diversity': label_diversity.result(), f'{val_name}_sample_diversity': sample_diversity.result(), f'{val_name}_ged': ged.result(), } writer.write_scalars(step, cifar_10h_measurements) # OOD eval # There are two entries in the ood_ds dict (in-dist, ood), and that this # section computes metrics using both pieces. This is in contrast to # normal validation eval above where we eval metrics separately for each # val split in val_ds. if ood_ds and config.ood_methods: def make_sngp_eval_fn(states): def sngp_eval_fn(params, images, labels, mask): return evaluation_fn(params=params, states=states, images=images, labels=labels, mask=mask) return sngp_eval_fn ood_measurements = ood_utils.eval_ood_metrics( ood_ds, ood_ds_names, config.ood_methods, make_sngp_eval_fn(states_repl), opt_repl.target, n_prefetch=config.get('prefetch_to_device', 1)) writer.write_scalars(step, ood_measurements) chrono.resume() if 'fewshot' in config: # Compute few-shot on-the-fly evaluation. if train_utils.itstime(step, config.fewshot.log_steps, total_steps): chrono.pause() write_note(f'Few-shot evaluation...\n{chrono.note}') # Keep `results` to return for reproducibility tests. 
fewshot_results, best_l2 = fewshotter.run_all( opt_repl.target, datasets=config.fewshot.datasets, states=states_repl) # TODO(dusenberrymw): Remove this once fewshot.py is updated. def make_writer_measure_fn(step): def writer_measure(name, value): writer.write_scalars(step, {name: value}) return writer_measure fewshotter.walk_results(make_writer_measure_fn(step), fewshot_results, best_l2) chrono.resume() # End of step. if config.get('testing_failure_step'): # Break early to simulate infra failures in test cases. if config.testing_failure_step == step: break write_note(f'Done!\n{chrono.note}') pool.close() pool.join() writer.close() # Return final training loss, validation loss, and fewshot results for # reproducibility test cases. return train_loss, val_loss, fewshot_results
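#------------------------------------------------------------------------------
# Inside `update_fn` above, the global gradient is optionally rescaled to a
# maximum l2 norm (`grad_clip_norm`). A minimal self-contained sketch of that
# global-norm clipping is shown below; the helper name and example gradients
# are illustrative.
#------------------------------------------------------------------------------
import jax
import jax.numpy as jnp


def clip_by_global_norm(grads, max_norm):
    """Rescales the gradient pytree if its global l2 norm exceeds max_norm."""
    leaves = jax.tree_util.tree_leaves(grads)
    l2_g = jnp.sqrt(sum(jnp.vdot(g, g) for g in leaves))
    g_factor = jnp.minimum(1.0, max_norm / l2_g)
    return jax.tree_map(lambda g: g_factor * g, grads), l2_g


grads = {"kernel": 10.0 * jnp.ones((3, 3)), "bias": jnp.ones((3,))}
clipped_grads, grad_norm = clip_by_global_norm(grads, max_norm=1.0)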
def train_controller(
    controller,
    sim,
    pip_feed="parallel",  # or "sequential"
    mode="multipip",  # or "singular"
    duration=0.87,
    dt=0.03,
    epochs=100,
    use_noise=False,
    optimizer=optax.adamw,
    optimizer_params={
        "learning_rate": 1e-3,
        "weight_decay": 1e-4
    },
    loss_fn=lambda x, y: (jnp.abs(x - y)).mean(),
    scheduler="Cosine",
    tensorboard_dir=None,
    model_parameters={},  # used for tensorboard
    print_loss=1,
):
    """train controller."""
    peep = 5
    if mode == "multipip":
        pips = [10, 15, 20, 25, 30, 35]
    elif mode == "singular":
        pips = [35]

    # setup optimizer
    optim_params = copy.deepcopy(optimizer_params)
    if scheduler == "Cosine":
        if pip_feed == "parallel":
            steps_per_epoch = 1
        elif pip_feed == "sequential":
            steps_per_epoch = len(pips)
        decay_steps = int(epochs * steps_per_epoch)
        print("steps_per_epoch:" + str(steps_per_epoch))
        print("decay_steps:" + str(decay_steps))
        cosine_scheduler_fn = optax.cosine_decay_schedule(
            init_value=optim_params["learning_rate"], decay_steps=decay_steps)
        optim_params["learning_rate"] = cosine_scheduler_fn
        print("optim_params:" + str(optim_params))
    optim = optimizer(**optim_params)
    optim_state = optim.init(controller)

    # setup Tensorboard writer
    if tensorboard_dir is not None:
        trial_name = str(model_parameters)
        write_path = tensorboard_dir + trial_name
        summary_writer = metric_writers.create_default_writer(
            logdir=write_path, just_logging=jax.process_index() != 0)
        # summary_writer = tensorboard.SummaryWriter(write_path)
        summary_writer.write_hparams(model_parameters)

    tt = jnp.linspace(0, duration, int(duration / dt))
    losses = []
    for epoch in range(epochs):
        if pip_feed == "parallel":
            value, grad = jax.value_and_grad(rollout_parallel)(
                controller, sim, tt, use_noise, peep, jnp.array(pips), loss_fn)
            updates, optim_state = optim.update(grad, optim_state, controller)
            controller = optax.apply_updates(controller, updates)
            per_step_loss = value / len(tt)
            losses.append(per_step_loss)
            if epoch % print_loss == 0:
                # make new controller with trained parameters and normal clamp
                score = test_controller(controller, sim, pips, peep)
                print(f"Epoch: {epoch}\tLoss: {score:.2f}")
                if tensorboard_dir is not None:
                    summary_writer.write_scalars(epoch, {"score": score})
        if pip_feed == "sequential":
            for pip in pips:
                value, grad = jax.value_and_grad(rollout)(
                    controller, sim, tt, use_noise, peep, pip, loss_fn,
                    jnp.array(0.))
                updates, optim_state = optim.update(grad, optim_state,
                                                    controller)
                controller = optax.apply_updates(controller, updates)
                per_step_loss = value / len(tt)
                losses.append(per_step_loss)
                if epoch % print_loss == 0:
                    # make new controller with trained parameters and normal clamp
                    score = test_controller(controller, sim, pips, peep)
                    print(f"Epoch: {epoch}, pip: {pip}\tLoss: {score:.2f}")
                    if tensorboard_dir is not None:
                        summary_writer.write_scalars(epoch,
                                                     {"per_step_loss": score})
    return controller, per_step_loss, score
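#------------------------------------------------------------------------------
# `train_controller` drives its updates with an optax optimizer whose learning
# rate is an `optax.cosine_decay_schedule` (the "Cosine" branch above). A
# minimal self-contained sketch of that setup follows; the quadratic loss and
# parameters are placeholders for illustration only.
#------------------------------------------------------------------------------
import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.ones((3,))}
schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=100)
optim = optax.adamw(learning_rate=schedule, weight_decay=1e-4)
optim_state = optim.init(params)


def toy_loss_fn(p):
    # Placeholder objective standing in for the rollout loss.
    return jnp.sum((p["w"] - 2.0) ** 2)


for _ in range(100):
    value, grad = jax.value_and_grad(toy_loss_fn)(params)
    updates, optim_state = optim.update(grad, optim_state, params)
    params = optax.apply_updates(params, updates)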
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> TrainState: """Execute model training and evaluation loop. Args: config: Hyperparameter configuration for training and evaluation. workdir: Directory where the tensorboard summaries are written to. Returns: Final TrainState. """ writer = metric_writers.create_default_writer( logdir=workdir, just_logging=jax.process_index() != 0) rng = random.PRNGKey(0) image_size = 224 if config.batch_size % jax.device_count() > 0: raise ValueError( 'Batch size must be divisible by the number of devices') local_batch_size = config.batch_size // jax.process_count() platform = jax.local_devices()[0].platform if config.half_precision: if platform == 'tpu': input_dtype = tf.bfloat16 else: input_dtype = tf.float16 else: input_dtype = tf.float32 dataset_builder = tfds.builder(config.dataset) train_iter = create_input_iter(dataset_builder, local_batch_size, image_size, input_dtype, train=True, cache=config.cache) eval_iter = create_input_iter(dataset_builder, local_batch_size, image_size, input_dtype, train=False, cache=config.cache) steps_per_epoch = (dataset_builder.info.splits['train'].num_examples // config.batch_size) if config.num_train_steps == -1: num_steps = int(steps_per_epoch * config.num_epochs) else: num_steps = config.num_train_steps if config.steps_per_eval == -1: num_validation_examples = dataset_builder.info.splits[ 'validation'].num_examples steps_per_eval = num_validation_examples // config.batch_size else: steps_per_eval = config.steps_per_eval steps_per_checkpoint = steps_per_epoch * 10 base_learning_rate = config.learning_rate * config.batch_size / 256. model_cls = getattr(models, config.model) model = create_model(model_cls=model_cls, half_precision=config.half_precision) learning_rate_fn = create_learning_rate_fn(config, base_learning_rate, steps_per_epoch) state = create_train_state(rng, config, model, image_size, learning_rate_fn) state = restore_checkpoint(state, workdir) # step_offset > 0 if restarting from checkpoint step_offset = int(state.step) state = jax_utils.replicate(state) p_train_step = jax.pmap(functools.partial( train_step, learning_rate_fn=learning_rate_fn), axis_name='batch') p_eval_step = jax.pmap(eval_step, axis_name='batch') train_metrics = [] hooks = [] if jax.process_index() == 0: hooks += [ periodic_actions.Profile(num_profile_steps=5, logdir=workdir) ] train_metrics_last_t = time.time() logging.info('Initial compilation, this might take some minutes...') for step, batch in zip(range(step_offset, num_steps), train_iter): state, metrics = p_train_step(state, batch) for h in hooks: h(step) if step == step_offset: logging.info('Initial compilation completed.') if config.get('log_every_steps'): train_metrics.append(metrics) if (step + 1) % config.log_every_steps == 0: train_metrics = common_utils.get_metrics(train_metrics) summary = { f'train_{k}': v for k, v in jax.tree_map(lambda x: x.mean(), train_metrics).items() } summary['steps_per_second'] = config.log_every_steps / ( time.time() - train_metrics_last_t) writer.write_scalars(step + 1, summary) train_metrics = [] train_metrics_last_t = time.time() if (step + 1) % steps_per_epoch == 0: epoch = step // steps_per_epoch eval_metrics = [] # sync batch statistics across replicas state = sync_batch_stats(state) for _ in range(steps_per_eval): eval_batch = next(eval_iter) metrics = p_eval_step(state, eval_batch) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) summary = jax.tree_map(lambda x: x.mean(), eval_metrics) 
logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch, summary['loss'], summary['accuracy'] * 100) writer.write_scalars( step + 1, {f'eval_{key}': val for key, val in summary.items()}) writer.flush() if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps: state = sync_batch_stats(state) save_checkpoint(state, workdir) # Wait until computations are done before exiting jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready() return state
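#------------------------------------------------------------------------------
# Before evaluation and checkpointing, the loop above calls `sync_batch_stats`
# so every replica holds identical BatchNorm statistics. A minimal sketch of
# what such a helper typically does (a pmapped cross-replica mean) is shown
# below; the `batch_stats` pytree is a stand-in for `state.batch_stats`.
#------------------------------------------------------------------------------
import jax
import jax.numpy as jnp
from jax import lax
from flax import jax_utils

# One mean/var pair per local device; during training these drift apart
# because each replica only sees its own shard of the data.
batch_stats = jax_utils.replicate({
    "bn": {"mean": jnp.zeros((8,)), "var": jnp.ones((8,))}
})

cross_replica_mean = jax.pmap(lambda x: lax.pmean(x, "batch"),
                              axis_name="batch")
synced_batch_stats = cross_replica_mean(batch_stats)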
def train_and_evaluate(config, workdir): """Runs a training and evaluation loop. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. """ tf.io.gfile.makedirs(workdir) rng = jax.random.PRNGKey(config.seed) # Build input pipeline. rng, data_rng = jax.random.split(rng) data_rng = jax.random.fold_in(data_rng, jax.host_id()) splits = input_pipeline.create_datasets(config, data_rng) num_classes = splits.info.features["label"].num_classes train_iter = iter(splits.train) # pytype: disable=wrong-arg-types # Learning rate schedule. num_train_steps = config.num_train_steps if num_train_steps == -1: num_train_steps = splits.train.cardinality().numpy() steps_per_epoch = num_train_steps // config.num_epochs logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps, steps_per_epoch) # We treat the learning rate in the config as the learning rate for batch size # 32 but scale it according to our batch size. global_batch_size = config.per_device_batch_size * jax.device_count() base_learning_rate = config.learning_rate * global_batch_size / 32.0 learning_rate_fn = functools.partial(get_learning_rate, base_learning_rate=base_learning_rate, steps_per_epoch=steps_per_epoch, num_epochs=config.num_epochs, warmup_epochs=config.warmup_epochs) # Initialize model. rng, model_rng = jax.random.split(rng) model, state = create_train_state( config, model_rng, input_shape=splits.train.element_spec["input"].shape[1:], num_classes=num_classes) # Set up checkpointing of the model and the input pipeline. checkpoint_dir = os.path.join(workdir, "checkpoints") ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, {"train_iter": train_iter}, max_to_keep=2) state = ckpt.restore_or_initialize(state) initial_step = int(state.step) + 1 # Count number of trainable parameters. This must be done before replicating # the state to avoid double-counting replicated parameters. param_count = sum(p.size for p in jax.tree_leaves(state.optimizer.target)) # Distribute training over local devices. state = flax_utils.replicate(state) p_train_step = jax.pmap(functools.partial( train_step, model=model, learning_rate_fn=learning_rate_fn, weight_decay=config.weight_decay), axis_name=_PMAP_AXIS_NAME) writer = metric_writers.create_default_writer( workdir, just_logging=jax.host_id() > 0) if initial_step == 1: writer.write_hparams(dict(config)) # Log the number of trainable params. writer.write_scalars(initial_step, {"param_count": param_count}) logging.info("Starting training loop at step %d.", initial_step) hooks = [] report_progress = periodic_actions.ReportProgress( num_train_steps=num_train_steps, writer=writer) if jax.host_id() == 0: hooks += [ report_progress, periodic_actions.Profile(num_profile_steps=5, logdir=workdir) ] train_metrics = None with metric_writers.ensure_flushes(writer): for step in range(initial_step, num_train_steps + 1): # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU # devices. is_last_step = step == num_train_steps with jax.profiler.StepTraceContext("train", step_num=step): batch = jax.tree_map(np.asarray, next(train_iter)) state, metrics_update = p_train_step(state=state, batch=batch) metric_update = flax_utils.unreplicate(metrics_update) train_metrics = (metric_update if train_metrics is None else train_metrics.merge(metric_update)) # Quick indication that training is happening. 
logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step) for h in hooks: h(step) if step % config.log_loss_every_steps == 0 or is_last_step: writer.write_scalars(step, train_metrics.compute()) train_metrics = None # When combining train and eval, we do not evaluate while training. if ((step % config.eval_every_steps == 0 or is_last_step) and not config.combine_train_val_and_eval_on_test): with report_progress.timed("eval"): eval_metrics = evaluate(model, state, splits.validation, config.num_eval_steps) writer.write_scalars(step, eval_metrics.compute()) if step % config.checkpoint_every_steps == 0 or is_last_step: with report_progress.timed("checkpoint"): ckpt.save(flax_utils.unreplicate(state)) if is_last_step and config.combine_train_val_and_eval_on_test: # Evaluate a single time on the test set when requested. with report_progress.timed("test"): test_metrics = evaluate(model, state, splits.test, config.num_eval_steps) writer.write_scalars(step, test_metrics.compute()) logging.info("Finishing training at step %d", num_train_steps)
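#------------------------------------------------------------------------------
# The running `train_metrics` accumulator above follows the CLU metrics
# pattern: each step yields a small update that is `merge`d into the
# accumulator and `compute`d only when scalars are written. A minimal sketch,
# assuming `clu.metrics` backs `metrics_update`, with an illustrative metric
# collection:
#------------------------------------------------------------------------------
import flax
import jax.numpy as jnp
from clu import metrics as clu_metrics


@flax.struct.dataclass
class TrainMetricsSketch(clu_metrics.Collection):
    loss: clu_metrics.Average.from_output("loss")


update_1 = TrainMetricsSketch.single_from_model_output(loss=jnp.array(1.0))
update_2 = TrainMetricsSketch.single_from_model_output(loss=jnp.array(3.0))
running = update_1.merge(update_2)
print(running.compute())  # {'loss': 2.0}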
def train_simulator( dataset, model, num_boundary_models, activation_fn_name, R, C, # idx 0 to num_boundary_models-1 are boundary models, # idx num_boundary_models is default_model train_key="train", test_key="test", batch_size=512, epochs=500, optimizer=optax.adamw, optimizer_params={ "learning_rate": 1e-3, "weight_decay": 1e-4 }, patience=10, lr_decay_factor=0.1, scheduler="ReduceLROnPlateau", # or "Cosine" loss_fn=lambda x, y: (jnp.abs(x - y)).mean(), print_loss=10, use_tensorboard=False, mode="train", user_name="alexjyu-brain", tb_dir=None, ): """train simulator.""" # evaluate on these at end of epoch for key in ["train", "test"]: dataset.data[key] = (jnp.array(dataset.data[key][0]), jnp.array(dataset.data[key][1])) X_train, y_train = dataset.data[train_key] X_test, y_test = dataset.data[test_key] # set up optimizer and lr scheduler lr_mult = 1.0 if scheduler == "ReduceLROnPlateau": optim = optimizer(**optimizer_params) patience_cnt = 0 prev_loss = float("inf") elif scheduler == "Cosine": steps_per_epoch = float(X_train.shape[0] / batch_size) decay_steps = int((epochs + 1) * steps_per_epoch) logging.info("steps_per_epoch: %s", str(steps_per_epoch)) logging.info("decay_steps: %s", str(decay_steps)) cosine_scheduler_fn = optax.cosine_decay_schedule( init_value=optimizer_params["learning_rate"], decay_steps=decay_steps) optimizer_params["learning_rate"] = cosine_scheduler_fn logging.info("optimizer_params: %s", str(optimizer_params)) optim = optimizer(**optimizer_params) optim_state = optim.init(model) loop_over_loader_partial = functools.partial( loop_over_loader, optim=optim, rollout_fn=rollout, scheduler=scheduler) # Tensorboard writer if use_tensorboard: config = copy.deepcopy(model.default_model_parameters) del config["activation_fn"] config["activation_fn_name"] = activation_fn_name if mode == "train": file_name = str(config) write_path = tb_dir + file_name summary_writer = metric_writers.create_default_writer( logdir=write_path, just_logging=jax.process_index() != 0) summary_writer = tensorboard.SummaryWriter(write_path) summary_writer.write_hparams(dict(config)) # Main Training Loop prng_key = jax.random.PRNGKey(0) for epoch in range(epochs + 1): if epoch % 10 == 0: logging.info("epoch: %s", str(epoch)) X, y, prng_key = get_shuffled_and_batched_data(dataset, batch_size, train_key, prng_key) if epoch == 0: logging.info("X.shape: %s", str(X.shape)) logging.info("y.shape: %s", str(y.shape)) (model, optim_state, lr_mult, loss), _ = jax.lax.scan(loop_over_loader_partial, (model, optim_state, lr_mult, 0.), (X, y)) """for i in range(X.shape[0]): carry = (model, optim_state, lr_mult, 0.) 
      carry, _ = loop_over_loader_partial(carry, (X[i], y[i]))
    model, optim_state, lr_mult, loss = carry
    """

    if scheduler == "ReduceLROnPlateau":
      if loss > prev_loss:
        patience_cnt = patience_cnt + 1
      else:
        patience_cnt = 0
      if patience_cnt == patience:
        lr_mult = lr_mult * lr_decay_factor
        patience_cnt = 0
      prev_loss = loss

    if epoch % print_loss == 0:
      if scheduler == "ReduceLROnPlateau":
        logging.info("loss: %s", str(loss))
        logging.info("prev_loss: %s", str(prev_loss))
        logging.info("patience_cnt: %s", str(patience_cnt))
        logging.info("lr_mult: %s", str(lr_mult))

    # Expensive end-of-epoch eval, just for intuition.
    train_loss = map_rollout_over_batch(model, (X_train, y_train), rollout)
    # Cross-validation.
    test_loss = map_rollout_over_batch(model, (X_test, y_test), rollout)

    if epoch % print_loss == 0:
      logging.info(
          f"Epoch {epoch:2d}: train={train_loss.item():.5f}, "
          f"test_loss={test_loss.item():.5f}")
      logging.info("-----------------------------------")
      if use_tensorboard:
        summary_writer.write_scalars(epoch, {"train_loss": train_loss})
        summary_writer.write_scalars(epoch, {"test_loss": test_loss})

  if use_tensorboard:
    summary_writer.flush()
  logging.info("finished looping over epochs")
  return model, test_loss
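
# Minimal, self-contained sketch of the "ReduceLROnPlateau" bookkeeping used in
# the epoch loop above: the learning-rate multiplier is decayed once the loss
# has failed to improve for `patience` consecutive epochs. Illustrative only.
def reduce_lr_on_plateau_step(loss, prev_loss, patience_cnt, lr_mult,
                              patience=10, lr_decay_factor=0.1):
  patience_cnt = patience_cnt + 1 if loss > prev_loss else 0
  if patience_cnt == patience:
    lr_mult = lr_mult * lr_decay_factor
    patience_cnt = 0
  return patience_cnt, lr_mult, loss  # `loss` becomes the next prev_loss.
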
def main(config, output_dir): seed = config.get('seed', 0) tf.random.set_seed(seed) if config.get('data_dir'): logging.info('data_dir=%s', config.data_dir) logging.info('Output dir: %s', output_dir) tf.io.gfile.makedirs(output_dir) # Create an asynchronous multi-metric writer. writer = metric_writers.create_default_writer( output_dir, just_logging=jax.process_index() > 0) # The pool is used to perform misc operations such as logging in async way. pool = multiprocessing.pool.ThreadPool() def write_note(note): if jax.process_index() == 0: logging.info('NOTE: %s', note) write_note('Initializing...') batch_size = config.batch_size batch_size_eval = config.get('batch_size_eval', batch_size) if (batch_size % jax.device_count() != 0 or batch_size_eval % jax.device_count() != 0): raise ValueError( f'Batch sizes ({batch_size} and {batch_size_eval}) must ' f'be divisible by device number ({jax.device_count()})') local_batch_size = batch_size // jax.process_count() local_batch_size_eval = batch_size_eval // jax.process_count() logging.info( 'Global batch size %d on %d hosts results in %d local batch size. ' 'With %d devices per host (%d devices total), that\'s a %d per-device ' 'batch size.', batch_size, jax.process_count(), local_batch_size, jax.local_device_count(), jax.device_count(), local_batch_size // jax.local_device_count()) write_note('Initializing val dataset(s)...') def _get_val_split(dataset, split, pp_eval, data_dir=None): # We do ceil rounding such that we include the last incomplete batch. nval_img = input_utils.get_num_examples( dataset, split=split, process_batch_size=local_batch_size_eval, drop_remainder=False, data_dir=data_dir) val_steps = int(np.ceil(nval_img / batch_size_eval)) logging.info('Running validation for %d steps for %s, %s', val_steps, dataset, split) if isinstance(pp_eval, str): pp_eval = preprocess_spec.parse( spec=pp_eval, available_ops=preprocess_utils.all_ops()) val_ds = input_utils.get_data(dataset=dataset, split=split, rng=None, process_batch_size=local_batch_size_eval, preprocess_fn=pp_eval, cache=config.get('val_cache', 'batched'), num_epochs=1, repeat_after_batching=True, shuffle=False, prefetch_size=config.get( 'prefetch_to_host', 2), drop_remainder=False, data_dir=data_dir) return val_ds val_ds_splits = { 'val': _get_val_split(config.dataset, split=config.val_split, pp_eval=config.pp_eval, data_dir=config.get('data_dir')) } if config.get('test_split'): val_ds_splits.update({ 'test': _get_val_split(config.dataset, split=config.test_split, pp_eval=config.pp_eval, data_dir=config.get('data_dir')) }) if config.get('eval_on_cifar_10h'): cifar10_to_cifar10h_fn = data_uncertainty_utils.create_cifar10_to_cifar10h_fn( config.get('data_dir', None)) preprocess_fn = preprocess_spec.parse( spec=config.pp_eval_cifar_10h, available_ops=preprocess_utils.all_ops()) pp_eval = lambda ex: preprocess_fn(cifar10_to_cifar10h_fn(ex)) val_ds_splits['cifar_10h'] = _get_val_split( 'cifar10', split=config.get('cifar_10h_split') or 'test', pp_eval=pp_eval, data_dir=config.get('data_dir')) elif config.get('eval_on_imagenet_real'): imagenet_to_real_fn = data_uncertainty_utils.create_imagenet_to_real_fn( ) preprocess_fn = preprocess_spec.parse( spec=config.pp_eval_imagenet_real, available_ops=preprocess_utils.all_ops()) pp_eval = lambda ex: preprocess_fn(imagenet_to_real_fn(ex)) val_ds_splits['imagenet_real'] = _get_val_split( 'imagenet2012_real', split=config.get('imagenet_real_split') or 'validation', pp_eval=pp_eval, data_dir=config.get('data_dir')) ood_ds = {} if 
config.get('ood_datasets') and config.get('ood_methods'): if config.get( 'ood_methods'): # config.ood_methods is not a empty list logging.info('loading OOD dataset = %s', config.get('ood_datasets')) ood_ds, ood_ds_names = ood_utils.load_ood_datasets( config.dataset, config.ood_datasets, config.ood_split, config.pp_eval, config.pp_eval_ood, config.ood_methods, config.train_split, config.get('data_dir'), _get_val_split, ) write_note('Initializing model...') logging.info('config.model = %s', config.model) model = ub.models.vision_transformer(num_classes=config.num_classes, **config.model) ensemble_pred_fn = functools.partial(ensemble_prediction_fn, model.apply) @functools.partial(jax.pmap, axis_name='batch') def evaluation_fn(params, images, labels, mask): # params is a dict of the form: # {'model_1': params_model_1, 'model_2': params_model_2, ...} # Ignore the entries with all zero labels for evaluation. mask *= labels.max(axis=1) loss_as_str = config.get('loss', 'sigmoid_xent') ens_logits, ens_prelogits = ensemble_pred_fn(params, images, loss_as_str) label_indices = config.get('label_indices') logging.info('!!! mask %s, label_indices %s', mask, label_indices) if label_indices: ens_logits = ens_logits[:, label_indices] # Note that logits and labels are usually of the shape [batch,num_classes]. # But for OOD data, when num_classes_ood > num_classes_ind, we need to # adjust labels to labels[:, :config.num_classes] to match the shape of # logits. That is just to avoid shape mismatch. The output losses does not # have any meaning for OOD data, because OOD not belong to any IND class. losses = getattr(train_utils, loss_as_str)( logits=ens_logits, labels=labels[:, :( len(label_indices) if label_indices else config.num_classes)], reduction=False) loss = jax.lax.psum(losses * mask, axis_name='batch') top1_idx = jnp.argmax(ens_logits, axis=1) # Extracts the label at the highest logit index for each image. top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0] ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch') n = jax.lax.psum(mask, axis_name='batch') metric_args = jax.lax.all_gather( [ens_logits, labels, ens_prelogits, mask], axis_name='batch') return ncorrect, loss, n, metric_args @functools.partial(jax.pmap, axis_name='batch') def cifar_10h_evaluation_fn(params, images, labels, mask): loss_as_str = config.get('loss', 'softmax_xent') ens_logits, ens_prelogits = ensemble_pred_fn(params, images, loss_as_str) label_indices = config.get('label_indices') if label_indices: ens_logits = ens_logits[:, label_indices] losses = getattr(train_utils, loss_as_str)(logits=ens_logits, labels=labels, reduction=False) loss = jax.lax.psum(losses, axis_name='batch') top1_idx = jnp.argmax(ens_logits, axis=1) # Extracts the label at the highest logit index for each image. one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)] top1_correct = jnp.take_along_axis(one_hot_labels, top1_idx[:, None], axis=1)[:, 0] ncorrect = jax.lax.psum(top1_correct, axis_name='batch') n = jax.lax.psum(one_hot_labels, axis_name='batch') metric_args = jax.lax.all_gather( [ens_logits, labels, ens_prelogits, mask], axis_name='batch') return ncorrect, loss, n, metric_args # Setup function for computing representation. @functools.partial(jax.pmap, axis_name='batch') def representation_fn(params, images, labels, mask): # Return shape [batch_size, representation_size * ensemble_size]. During # few-shot eval, a single linear regressor is applied over all dimensions. 
representation = [] for p in params.values(): _, outputs = model.apply({'params': flax.core.freeze(p)}, images, train=False) representation += [outputs[config.fewshot.representation_layer]] representation = jnp.concatenate(representation, axis=1) representation = jax.lax.all_gather(representation, 'batch') labels = jax.lax.all_gather(labels, 'batch') mask = jax.lax.all_gather(mask, 'batch') return representation, labels, mask write_note('Load checkpoints...') ensemble_params = load_checkpoints(config) write_note('Replicating...') ensemble_params = flax.jax_utils.replicate(ensemble_params) if jax.process_index() == 0: writer.write_hparams(dict(config)) write_note('Initializing few-shotters...') fewshotter = None if 'fewshot' in config and fewshot is not None: fewshotter = fewshot.FewShotEvaluator( representation_fn, config.fewshot, config.fewshot.get('batch_size') or batch_size_eval) # Note: we return the train loss, val loss, and fewshot best l2s for use in # reproducibility unit tests. val_loss = {val_name: -jnp.inf for val_name, _ in val_ds_splits.items()} fewshot_results = {'dummy': {(0, 1): -jnp.inf}} step = 1 # Report validation performance. write_note('Evaluating on the validation set...') for val_name, val_ds in val_ds_splits.items(): # Sets up evaluation metrics. ece_num_bins = config.get('ece_num_bins', 15) auc_num_bins = config.get('auc_num_bins', 1000) ece = rm.metrics.ExpectedCalibrationError(num_bins=ece_num_bins) calib_auc = rm.metrics.CalibrationAUC(correct_pred_as_pos_label=False) oc_auc_0_5 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.005, num_bins=auc_num_bins) oc_auc_1 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.01, num_bins=auc_num_bins) oc_auc_2 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.02, num_bins=auc_num_bins) oc_auc_5 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.05, num_bins=auc_num_bins) label_diversity = tf.keras.metrics.Mean() sample_diversity = tf.keras.metrics.Mean() ged = tf.keras.metrics.Mean() # Runs evaluation loop. val_iter = input_utils.start_input_pipeline( val_ds, config.get('prefetch_to_device', 1)) ncorrect, loss, nseen = 0, 0, 0 for batch in val_iter: if val_name == 'cifar_10h': batch_ncorrect, batch_losses, batch_n, batch_metric_args = ( cifar_10h_evaluation_fn(ensemble_params, batch['image'], batch['labels'], batch['mask'])) else: batch_ncorrect, batch_losses, batch_n, batch_metric_args = ( evaluation_fn(ensemble_params, batch['image'], batch['labels'], batch['mask'])) # All results are a replicated array shaped as follows: # (local_devices, per_device_batch_size, elem_shape...) # with each local device's entry being identical as they got psum'd. # So let's just take the first one to the host as numpy. ncorrect += np.sum(np.array(batch_ncorrect[0])) loss += np.sum(np.array(batch_losses[0])) nseen += np.sum(np.array(batch_n[0])) if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent': # Here we parse batch_metric_args to compute uncertainty metrics. # (e.g., ECE or Calibration AUC). logits, labels, _, masks = batch_metric_args masks = np.array(masks[0], dtype=np.bool) logits = np.array(logits[0]) probs = jax.nn.softmax(logits) # From one-hot to integer labels, as required by ECE. 
int_labels = np.argmax(np.array(labels[0]), axis=-1) int_preds = np.argmax(logits, axis=-1) confidence = np.max(probs, axis=-1) for p, c, l, d, m, label in zip(probs, confidence, int_labels, int_preds, masks, labels[0]): ece.add_batch(p[m, :], label=l[m]) calib_auc.add_batch(d[m], label=l[m], confidence=c[m]) # TODO(jereliu): Extend to support soft multi-class probabilities. oc_auc_0_5.add_batch(d[m], label=l[m], custom_binning_score=c[m]) oc_auc_1.add_batch(d[m], label=l[m], custom_binning_score=c[m]) oc_auc_2.add_batch(d[m], label=l[m], custom_binning_score=c[m]) oc_auc_5.add_batch(d[m], label=l[m], custom_binning_score=c[m]) if val_name == 'cifar_10h' or val_name == 'imagenet_real': batch_label_diversity, batch_sample_diversity, batch_ged = data_uncertainty_utils.generalized_energy_distance( label[m], p[m, :], config.num_classes) label_diversity.update_state(batch_label_diversity) sample_diversity.update_state(batch_sample_diversity) ged.update_state(batch_ged) val_loss[val_name] = loss / nseen # Keep for reproducibility tests. val_measurements = { f'{val_name}_prec@1': ncorrect / nseen, f'{val_name}_loss': val_loss[val_name], } if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent': val_measurements[f'{val_name}_ece'] = ece.result()['ece'] val_measurements[f'{val_name}_calib_auc'] = calib_auc.result( )['calibration_auc'] val_measurements[f'{val_name}_oc_auc_0.5%'] = oc_auc_0_5.result( )['collaborative_auc'] val_measurements[f'{val_name}_oc_auc_1%'] = oc_auc_1.result( )['collaborative_auc'] val_measurements[f'{val_name}_oc_auc_2%'] = oc_auc_2.result( )['collaborative_auc'] val_measurements[f'{val_name}_oc_auc_5%'] = oc_auc_5.result( )['collaborative_auc'] writer.write_scalars(step, val_measurements) if val_name == 'cifar_10h' or val_name == 'imagenet_real': cifar_10h_measurements = { f'{val_name}_label_diversity': label_diversity.result(), f'{val_name}_sample_diversity': sample_diversity.result(), f'{val_name}_ged': ged.result(), } writer.write_scalars(step, cifar_10h_measurements) # OOD eval # Entries in the ood_ds dict include: # (ind_dataset, ood_dataset1, ood_dataset2, ...). # OOD metrics are computed using ind_dataset paired with each of the # ood_dataset. When Mahalanobis distance method is applied, train_ind_ds # is also included in the ood_ds. if ood_ds and config.ood_methods: ood_measurements = ood_utils.eval_ood_metrics(ood_ds, ood_ds_names, config.ood_methods, evaluation_fn, ensemble_params, n_prefetch=config.get( 'prefetch_to_device', 1)) writer.write_scalars(step, ood_measurements) if 'fewshot' in config and fewshotter is not None: # Compute few-shot on-the-fly evaluation. write_note('Few-shot evaluation...') # Keep `results` to return for reproducibility tests. fewshot_results, best_l2 = fewshotter.run_all(ensemble_params, config.fewshot.datasets) # TODO(dusenberrymw): Remove this once fewshot.py is updated. def make_writer_measure_fn(step): def writer_measure(name, value): writer.write_scalars(step, {name: value}) return writer_measure fewshotter.walk_results(make_writer_measure_fn(step), fewshot_results, best_l2) write_note('Done!') pool.close() pool.join() writer.close() # Return final training loss, validation loss, and fewshot results for # reproducibility test cases. return val_loss, fewshot_results
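
# Reference sketch of what rm.metrics.ExpectedCalibrationError estimates above:
# the standard binned ECE, sum_b (n_b / N) * |accuracy_b - confidence_b|. This
# is the textbook definition, not the robustness_metrics implementation.
def binned_ece(probs, int_labels, num_bins=15):
  """probs: [N, num_classes] probabilities; int_labels: [N] integer labels."""
  confidences = probs.max(axis=-1)
  predictions = probs.argmax(axis=-1)
  correct = (predictions == int_labels).astype(np.float64)
  bin_edges = np.linspace(0.0, 1.0, num_bins + 1)
  ece = 0.0
  for lo, hi in zip(bin_edges[:-1], bin_edges[1:]):
    in_bin = (confidences > lo) & (confidences <= hi)
    if in_bin.any():
      ece += in_bin.mean() * abs(correct[in_bin].mean() -
                                 confidences[in_bin].mean())
  return ece
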
def main(config, output_dir): # Note: switch to ProfileAllHosts() if you need to profile all hosts. # (Xprof data become much larger and take longer to load for analysis) profiler = periodic_actions.Profile( # Create profile after every restart to analyze pre-emption related # problems and assure we get similar performance in every run. logdir=output_dir, first_profile=10) logging.info(config) acquisition_method = config.get('acquisition_method') # Create an asynchronous multi-metric writer. writer = metric_writers.create_default_writer( output_dir, just_logging=jax.process_index() > 0) writer.write_hparams(dict(config)) # The pool is used to perform misc operations such as logging in async way. pool = multiprocessing.pool.ThreadPool() def write_note(note): if jax.process_index() == 0: logging.info('NOTE: %s', note) write_note(f'Initializing for {acquisition_method}') # Download dataset data_builder = tfds.builder(config.dataset) data_builder.download_and_prepare() seed = config.get('seed', 0) rng = jax.random.PRNGKey(seed) batch_size = config.batch_size batch_size_eval = config.get('batch_size_eval', batch_size) local_batch_size = batch_size // jax.process_count() local_batch_size_eval = batch_size_eval // jax.process_count() val_ds = input_utils.get_data( dataset=config.dataset, split=config.val_split, rng=None, process_batch_size=local_batch_size_eval, preprocess_fn=preprocess_spec.parse( spec=config.pp_eval, available_ops=preprocess_utils.all_ops()), shuffle=False, prefetch_size=config.get('prefetch_to_host', 2), num_epochs=1, # Only repeat once. ) test_ds = input_utils.get_data( dataset=config.dataset, split=config.test_split, rng=None, process_batch_size=local_batch_size_eval, preprocess_fn=preprocess_spec.parse( spec=config.pp_eval, available_ops=preprocess_utils.all_ops()), shuffle=False, prefetch_size=config.get('prefetch_to_host', 2), num_epochs=1, # Only repeat once. ) # Init model if config.model_type == 'deterministic': model_utils = deterministic_utils reinit_params = config.get('model_reinit_params', ('head/kernel', 'head/bias')) model = ub.models.vision_transformer(num_classes=config.num_classes, **config.get('model', {})) elif config.model_type == 'batchensemble': model_utils = batchensemble_utils reinit_params = ('batchensemble_head/bias', 'batchensemble_head/kernel', 'batchensemble_head/fast_weight_alpha', 'batchensemble_head/fast_weight_gamma') model = ub.models.PatchTransformerBE(num_classes=config.num_classes, **config.model) else: raise ValueError('Expect config.model_type to be "deterministic" or' f'"batchensemble", but received {config.model_type}.') init = model_utils.create_init(model, config, test_ds) rng, rng_init = jax.random.split(rng) params_cpu = init(rng_init) if jax.process_index() == 0: num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0]) parameter_overview.log_parameter_overview(params_cpu) writer.write_scalars(step=0, scalars={'num_params': num_params}) # Load the optimizer from flax. opt_name = config.get('optim_name') opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {})) # We jit this, such that the arrays that are created on the same # device as the input is, in this case the CPU. Else they'd be on device[0]. 
opt_cpu = jax.jit(opt_def.create)(params_cpu) loaded_params = checkpoint_utils.load_checkpoint(tree=None, path=config.model_init) loaded = checkpoint_utils.restore_from_pretrained_params( params_cpu, loaded_params, config.model.representation_size, config.model.classifier, reinit_params, ) opt_cpu = opt_cpu.replace(target=loaded) # TODO(joost,andreas): This shouldn't be needed but opt_cpu is being # donated otherwise. Ensure opt_cpu is really on the cpu this way. opt_cpu = jax.device_get(opt_cpu) update_fn = model_utils.create_update_fn(model, config) evaluation_fn = model_utils.create_evaluation_fn(model, config) # NOTE: We need this because we need an Id field of type int. # TODO(andreas): Rename to IdSubsetDatasetBuilder? pool_subset_data_builder = al_utils.SubsetDatasetBuilder(data_builder, subset_ids=None) rng, pool_ds_rng = jax.random.split(rng) # NOTE: below line is necessary on multi host setup # pool_ds_rng = jax.random.fold_in(pool_ds_rng, jax.process_index()) pool_train_ds = input_utils.get_data( dataset=pool_subset_data_builder, split=config.train_split, rng=pool_ds_rng, process_batch_size=local_batch_size, preprocess_fn=preprocess_spec.parse( spec=config.pp_eval, available_ops=preprocess_utils.all_ops()), shuffle=False, drop_remainder=False, prefetch_size=config.get('prefetch_to_host', 2), num_epochs=1, # Don't repeat ) # Potentially acquire an initial training set. initial_training_set_size = config.get('initial_training_set_size', 10) if initial_training_set_size > 0: current_opt_repl = flax_utils.replicate(opt_cpu) pool_ids, _, _, pool_masks = get_ids_logits_masks( model=model, opt_repl=current_opt_repl, ds=pool_train_ds, config=config) rng, initial_uniform_rng = jax.random.split(rng) pool_scores = get_uniform_scores(pool_masks, initial_uniform_rng) initial_training_set_batch_ids, _ = select_acquisition_batch_indices( acquisition_batch_size=initial_training_set_size, scores=pool_scores, ids=pool_ids, ignored_ids=set(), ) else: initial_training_set_batch_ids = [] # NOTE: if we could `enumerate` before `filter` in `create_dataset` of CLU # then this dataset creation could be simplified. # https://github.com/google/CommonLoopUtils/blob/main/clu/deterministic_data.py#L340 # CLU is explicitly not accepting outside contributions at the moment. train_subset_data_builder = al_utils.SubsetDatasetBuilder( data_builder, subset_ids=set(initial_training_set_batch_ids)) test_accuracies = [] training_sizes = [] rng, rng_loop = jax.random.split(rng) rngs_loop = flax_utils.replicate(rng_loop) if config.model_type == 'batchensemble': rngs_loop = {'dropout': rngs_loop} # TODO(joost,andreas): double check if below is still necessary # (train_split is independent of this) # NOTE: train_ds_rng is re-used for all train_ds creations rng, train_ds_rng = jax.random.split(rng) measurements = {} accumulated_steps = 0 while True: current_train_ds_length = len(train_subset_data_builder.subset_ids) if current_train_ds_length >= config.get('max_training_set_size', 150): break write_note(f'Training set size: {current_train_ds_length}') current_opt_repl = flax_utils.replicate(opt_cpu) # Only fine-tune if there is anything to fine-tune with. 
if current_train_ds_length > 0: # Repeat dataset to have oversampled epochs and bootstrap more batches number_of_batches = current_train_ds_length / config.batch_size num_repeats = math.ceil(config.total_steps / number_of_batches) write_note(f'Repeating dataset {num_repeats} times') # We repeat the dataset several times, such that we can obtain batches # of size batch_size, even at start of training. These batches will be # effectively 'bootstrap' sampled, meaning they are sampled with # replacement from the original training set. repeated_train_ds = input_utils.get_data( dataset=train_subset_data_builder, split=config.train_split, rng=train_ds_rng, process_batch_size=local_batch_size, preprocess_fn=preprocess_spec.parse( spec=config.pp_train, available_ops=preprocess_utils.all_ops()), shuffle_buffer_size=config.shuffle_buffer_size, prefetch_size=config.get('prefetch_to_host', 2), # TODO(joost,andreas): double check if below leads to bootstrap # sampling. num_epochs=num_repeats, ) # We use this dataset to evaluate how well we perform on the training set. # We need this to evaluate if we fit well within max_steps budget. train_eval_ds = input_utils.get_data( dataset=train_subset_data_builder, split=config.train_split, rng=train_ds_rng, process_batch_size=local_batch_size, preprocess_fn=preprocess_spec.parse( spec=config.pp_eval, available_ops=preprocess_utils.all_ops()), shuffle=False, drop_remainder=False, prefetch_size=config.get('prefetch_to_host', 2), num_epochs=1, ) # NOTE: warmup and decay are not a good fit for the small training set # lr_fn = train_utils.create_learning_rate_schedule(config.total_steps, # **config.get('lr', {}) # ) lr_fn = lambda x: config.lr.base early_stopping_patience = config.get('early_stopping_patience', 15) current_opt_repl, rngs_loop, measurements = finetune( update_fn=update_fn, opt_repl=current_opt_repl, lr_fn=lr_fn, ds=repeated_train_ds, rngs_loop=rngs_loop, total_steps=config.total_steps, train_eval_ds=train_eval_ds, val_ds=val_ds, evaluation_fn=evaluation_fn, early_stopping_patience=early_stopping_patience, profiler=profiler) train_val_accuracies = measurements.pop('train_val_accuracies') current_steps = 0 for step, train_acc, val_acc in train_val_accuracies: writer.write_scalars(accumulated_steps + step, { 'train_accuracy': train_acc, 'val_accuracy': val_acc }) current_steps = step accumulated_steps += current_steps + 10 test_accuracy = get_accuracy(evaluation_fn=evaluation_fn, opt_repl=current_opt_repl, ds=test_ds) write_note(f'Accuracy at {current_train_ds_length}: {test_accuracy}') test_accuracies.append(test_accuracy) training_sizes.append(current_train_ds_length) pool_ids, pool_outputs, _, pool_masks = get_ids_logits_masks( model=model, opt_repl=current_opt_repl, ds=pool_train_ds, use_pre_logits=acquisition_method == 'density', config=config) if acquisition_method == 'uniform': rng_loop, rng_acq = jax.random.split(rng_loop, 2) pool_scores = get_uniform_scores(pool_masks, rng_acq) elif acquisition_method == 'entropy': pool_scores = get_entropy_scores(pool_outputs, pool_masks) elif acquisition_method == 'margin': pool_scores = get_margin_scores(pool_outputs, pool_masks) elif acquisition_method == 'density': if current_train_ds_length > 0: pool_scores = get_density_scores(model=model, opt_repl=current_opt_repl, train_ds=train_eval_ds, pool_pre_logits=pool_outputs, pool_masks=pool_masks, config=config) else: rng_loop, rng_acq = jax.random.split(rng_loop, 2) pool_scores = get_uniform_scores(pool_masks, rng_acq) else: raise ValueError('Acquisition 
method not found.')

    acquisition_batch_ids, _ = select_acquisition_batch_indices(
        acquisition_batch_size=config.get('acquisition_batch_size', 10),
        scores=pool_scores,
        ids=pool_ids,
        ignored_ids=train_subset_data_builder.subset_ids)

    train_subset_data_builder.subset_ids.update(acquisition_batch_ids)

    measurements.update({'test_accuracy': test_accuracy})
    writer.write_scalars(current_train_ds_length, measurements)

  write_note(f'Final acquired training ids: '
             f'{train_subset_data_builder.subset_ids} '
             f'Accuracies: {test_accuracies}')

  pool.close()
  pool.join()
  writer.close()

  # TODO(joost,andreas): save the final checkpoint
  return (train_subset_data_builder.subset_ids, test_accuracies)
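
# Illustrative sketch of the acquisition scores selected above; the project's
# get_entropy_scores / get_margin_scores helpers are defined elsewhere, so this
# only shows the underlying quantities. Masked-out examples get -inf so they
# are never selected.
def illustrative_entropy_scores(logits, masks):
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  entropy = -jnp.sum(jnp.exp(log_probs) * log_probs, axis=-1)
  return jnp.where(masks > 0, entropy, -jnp.inf)


def illustrative_margin_scores(logits, masks):
  probs = jax.nn.softmax(logits, axis=-1)
  top2, _ = jax.lax.top_k(probs, 2)
  margin = top2[:, 0] - top2[:, 1]  # Small margin -> uncertain -> high score.
  return jnp.where(masks > 0, -margin, -jnp.inf)
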
def train_and_evaluate(config, workdir, strategy): """Runs a training and evaluation loop. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. strategy: Distribution strategy to use for distributing the model. """ tf.io.gfile.makedirs(workdir) tf_rng, data_rng = tf.random.experimental.stateless_split((config.seed, 0), 2) tf.random.set_seed(tf_rng.numpy()[0]) # Input pipeline. ds_info, train_ds, val_ds, test_ds = input_pipeline.create_datasets( config, data_rng, strategy=strategy) train_iter = iter(train_ds) # pytype: disable=wrong-arg-types # Learning rate schedule. num_train_steps = config.num_train_steps if num_train_steps == -1: num_train_steps = (ds_info.splits["train"].num_examples // config.global_batch_size * config.num_epochs) steps_per_epoch = num_train_steps // config.num_epochs logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps, steps_per_epoch) # We treat the learning rate in the config as the learning rate for batch size # 256 but scale it according to our batch size. base_learning_rate = config.learning_rate * config.global_batch_size / 256.0 # Initialize model. num_classes = ds_info.features["label"].num_classes if config.distill_teacher: do_distill = True teacher_file_list = (config.distill_teacher).split(",") teacher_models = load_teacher_models(teacher_file_list, num_classes, config, strategy) distill_params = {} distill_params["alpha"] = config.distill_alpha distill_params["beta"] = config.distill_fd_beta distill_params["teacher_model"] = TeacherModel(teacher_models, name="teacher") else: do_distill = False distill_params = None state = create_state(config, num_classes=num_classes, strategy=strategy) ckpt_manager = tf.train.CheckpointManager(checkpoint=state, directory=workdir, max_to_keep=5) if ckpt_manager.latest_checkpoint: state.restore(ckpt_manager.latest_checkpoint) logging.info("Restored from %s", ckpt_manager.latest_checkpoint) else: logging.info("Initializing from scratch.") initial_step = state.global_step.numpy().item() learning_rate_fn = functools.partial(get_learning_rate, base_learning_rate=base_learning_rate, steps_per_epoch=steps_per_epoch, num_epochs=config.num_epochs, warmup_epochs=config.warmup_epochs) writer = metric_writers.create_default_writer(workdir) writer.write_hparams(dict(config)) logging.info("Starting training loop at step %d.", initial_step) report_progress = periodic_actions.ReportProgress( num_train_steps=num_train_steps, writer=writer) with metric_writers.ensure_flushes(writer): for step in range(initial_step, num_train_steps + 1): state.model.trainable = True # `step` is a Python integer. `global_step` is a TF variable on the # GPU/TPU devices. is_last_step = step == num_train_steps train_step(state, train_iter, config.weight_decay, learning_rate_fn, do_distill, distill_params, strategy) state.train_metrics.update_state_lr( learning_rate_fn(state.global_step.numpy().item())) # Quick indication that training is happening. 
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      report_progress(step)

      if step == initial_step:
        parameter_overview.log_parameter_overview(state.model)

      if step % config.log_loss_every_steps == 0 or is_last_step:
        writer.write_scalars(step, state.train_metrics.result())
        state.train_metrics.reset_states()
        state.train_metrics.reset_lr()

      if step % config.eval_every_steps == 0 or is_last_step:
        state.model.trainable = False
        if config.dataset == "imagenet-lt":
          evaluate(state, val_ds, state.val_metrics, strategy)
          writer.write_scalars(step, state.val_metrics.result())
          logging.info("Num val images %d",
                       state.val_metrics.accuracy.count.numpy())
        evaluate(state, test_ds, state.test_metrics, strategy)
        writer.write_scalars(step, state.test_metrics.result())
        logging.info("Num test images %d",
                     state.test_metrics.accuracy.count.numpy())

      if step % config.checkpoint_every_steps == 0 or is_last_step:
        checkpoint_path = ckpt_manager.save(step)
        logging.info("Saved checkpoint %s", checkpoint_path)

  logging.info("Finishing training at step %d", step)
  logging.info("Saving the final weights")
  file_path = "%s/final_weights" % workdir
  state.model.save_weights(file_path, save_format="tf")
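
# Worked example (illustrative values) of the linear learning-rate scaling rule
# used above: the configured rate is treated as the rate for global batch
# size 256 and scaled in proportion to the actual global batch size.
def scaled_base_learning_rate(config_learning_rate, global_batch_size):
  return config_learning_rate * global_batch_size / 256.0

# e.g. scaled_base_learning_rate(0.1, 1024) == 0.4
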
def training_loop( *, module, rng, train_ds, eval_ds, loss_fn, optimizer, train_metrics_dict, eval_metrics_dict, stats_aggregators, config, workdir, ): """Runs a training and evaluation loop. Args: module: The module that should be trained. rng: A jax pseudo-random number generator key. train_ds: Dataset used for training. eval_ds: Dataset used for evaluation. loss_fn: Loss function to use for training. optimizer: Optax optimizer to use for training. train_metrics_dict: Collection of metrics to be collected during training. eval_metrics_dict: Collection of metrics to be collected during evaluation. stats_aggregators: Dictionary of statistics aggregator functions to be run on the first evaluation batch. These functions ingest the stats returned by the model and output a Dict[str, image/scalar] that will be logged. config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. Raises: RuntimeError: If a training metric is NaN or inf. Returns: Training state. """ rng, model_rng = jax.random.split(rng) input_shape = tuple(train_ds.element_spec["image"].shape[1:]) model, init_params, init_state = create_model(module, input_shape, model_rng) parameter_overview.log_parameter_overview(model.params) # Load a pretrained model parameters and state. Ignore the step and the # optimizer state in the checkpoint. pretrained_path = config.get("pretrained_checkpoint", "") if pretrained_path: logging.info("Load pretrained weights from '%s'", pretrained_path) state_dict = checkpoint.load_state_dict(pretrained_path) flatten_model_params = utils.flatten_dict(state_dict["model_params"], sep="/") model_state = state_dict["model_state"] # A prefix can be used to replace only a subpart of the network (e.g the # encoder). Prepend the prefix (if any) to model parameters and states. prefix = config.get("pretrained_prefix", "") if prefix: flatten_model_params = utils.add_prefix_to_dict_keys( flatten_model_params, f"{prefix}/") model_state = utils.add_prefix_to_dict_keys( model_state, f"/{prefix}") # Merge the params/state from the checkpoint into the initial params/state. flatten_init_params = utils.flatten_dict(init_params, sep="/") flatten_init_params, ignored_params = utils.override_dict( flatten_init_params, flatten_model_params) init_params = utils.unflatten_dict(flatten_init_params, delimiter="/") init_state, _ = utils.override_dict(init_state, model_state) if ignored_params: logging.warning("%d/%d parameters from the pretrained checkpoint " "were ignored: %s", len(ignored_params), len(flatten_init_params), ignored_params) optimizer_state = optimizer.init(init_params) state = TrainState( step=1, model_params=init_params, model_state=init_state, optimizer_state=optimizer_state) # type: ignore # Do not keep a copy of the initial model. del init_params, init_state, optimizer_state train_iter = iter(train_ds) # pytype: disable=wrong-arg-types checkpoint_dir = os.path.join(workdir, "checkpoints") ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir) state = ckpt.restore_or_initialize(state) initial_step = int(state.step) # Replicate our parameters. state = flax.jax_utils.replicate(state) writer = metric_writers.create_default_writer( workdir, just_logging=jax.host_id() > 0) step_timer = utils.StepTimer( batch_size=config.batch_size, initial_step=initial_step) # Write config to the summary files. This makes the hyperparameters available # in TensorBoard and makes comparison of runs with tensorboard/ easier. 
if initial_step == 1: writer.write_hparams(utils.flatten_dict(config.to_dict())) # Generate per-device PRNG keys for the training loop. rng, train_rng = jax.random.split(rng) train_rngs = jax.random.split(train_rng, jax.local_device_count()) # Generate per-device PRNG keys for model evaluation. rng, eval_rng = jax.random.split(rng) eval_rngs = jax.random.split(eval_rng, jax.local_device_count()) logging.info("Starting training loop at step %d.", initial_step) report_progress = periodic_actions.ReportProgress( num_train_steps=config.num_train_steps, writer=writer) train_metrics = utils.Means() do_eval_only = config.get("do_eval_only", False) if do_eval_only: config.num_train_steps = 1 debug_enabled = config.get("debug", False) previous_grads = grads = None previous_updates = updates = None previous_state = None for step in range(initial_step, config.num_train_steps + 1): is_last_step = step == config.num_train_steps if debug_enabled: previous_grads = grads previous_updates = updates previous_state = state # Skip the training if only do the eval. if not do_eval_only: # Use ._numpy() to avoid copy. batch = jax.tree_map(lambda x: x._numpy(), next(train_iter)) # pylint: disable=protected-access state, grads, updates, metrics, training_stats, train_rngs = train_step( state, batch, module, loss_fn, optimizer, train_metrics_dict, train_rngs) train_metrics.append(flax.jax_utils.unreplicate(metrics)) # Update topk temperature with linearly decreasing schedule if enabled. if (config.get("linear_decrease_perturbed_sigma", False) and config.get("selection_method", "") == "perturbed-topk"): model_state = state.model_state.as_dict() if "/PatchNet_0" in model_state: net_str = "/PatchNet_0" else: net_str = "/" progress = step / config.num_train_steps sigma_multiplier = 1. - progress previous_mult = model_state[net_str]["sigma_mutiplier"] sigma_multiplier = sigma_multiplier + jnp.zeros_like(previous_mult) model_state[net_str]["sigma_mutiplier"] = sigma_multiplier state = state.replace(model_state=nn.Collection(model_state)) if debug_enabled: if utils.has_any_inf_or_nan(metrics): # Save checkpoint if previous_state: ckpt.save(flax.jax_utils.unreplicate(previous_state)) ckpt.save(flax.jax_utils.unreplicate(state)) # Log gradients and updates. if previous_grads or previous_updates: write_gradient_histogram(writer, step, grads=previous_grads, updates=previous_updates) write_gradient_histogram(writer, step + 1, grads=grads, updates=updates) raise RuntimeError("A training metric took an invalid value: " f"{metrics}.") logging.log_first_n(logging.INFO, "Finished training step %d.", 3, step) report_progress(step) if step % config.log_loss_every_steps == 0 or is_last_step: results = train_metrics.result() writer.write_scalars(step, results) writer.write_scalars(step, step_timer.get_and_reset(step)) if utils.has_any_inf_or_nan(results): raise ValueError("A training metric took an invalid value.") train_metrics.reset() if (step % config.checkpoint_every_steps == 0 or is_last_step): with step_timer.paused(): ckpt.save(flax.jax_utils.unreplicate(state)) # Evaluation if step % config.eval_every_steps == 0 or is_last_step: with step_timer.paused(): eval_metrics, first_batch_stats, eval_rngs = evaluate( state, module, eval_ds, eval_metrics_dict, eval_rngs) if jax.host_id() == 0: log_histograms = config.get("log_histograms", False) log_images = config.get("log_images", True) # Log the last gradients and updates histograms. 
        if not do_eval_only:
          write_stats_results(writer, step, training_stats,
                              stats_aggregators, prefix="train/",
                              log_images=log_images)
          if log_histograms:
            write_gradient_histogram(writer, step, grads=grads,
                                     updates=updates)
        write_stats_results(writer, step, first_batch_stats,
                            stats_aggregators, prefix="eval/",
                            log_images=log_images)

        # Write patch representation histograms.
        if (log_histograms and first_batch_stats and
            "patch_representations" in first_batch_stats):
          patch_representations = first_batch_stats["patch_representations"]
          writer.write_histograms(step, {
              "patch_representations": patch_representations
          })

      if eval_metrics:
        writer.write_scalars(step, eval_metrics)
      writer.flush()

  return state
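
# Stand-alone sketch of the linearly decreasing sigma multiplier applied in the
# perturbed-topk branch above: the perturbation scale anneals from 1.0 at the
# start of training to 0.0 at the final step (the model_state plumbing is
# omitted here).
def linear_sigma_multiplier(step, num_train_steps):
  progress = step / num_train_steps
  return 1.0 - progress

# e.g. linear_sigma_multiplier(250, 1000) == 0.75
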
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): """Runs a training and evaluation loop. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. """ tf.io.gfile.makedirs(workdir) vocab_path = config.vocab_path if vocab_path is None: vocab_path = os.path.join(workdir, "sentencepiece_model") config.vocab_path = vocab_path tf.io.gfile.makedirs(os.path.split(vocab_path)[0]) # Load Dataset # --------------------------------------------------------------------------- logging.info("Initializing dataset.") train_ds, eval_ds, _, encoder = input_pipeline.get_datasets( n_devices=jax.local_device_count(), config=config, vocab_path=vocab_path) train_iter = iter(train_ds) vocab_size = int(encoder.vocab_size()) eos_id = temperature_sampler.EOS_ID # Default Sentencepiece EOS token. def decode_tokens(toks): valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32) return encoder.detokenize(valid_toks).numpy().decode("utf-8") def encode_strings(strs, max_len): tokenized_batch = np.zeros((len(strs), max_len), np.int32) for i, s in enumerate(strs): toks = encoder.tokenize(s).numpy() # Remove EOS token in prompt. tokenized_batch[i, :toks.shape[0]-1] = toks[:-1] return tokenized_batch tokenized_prompts = encode_strings( [config.prompts], config.max_predict_length) logging.info("Initializing model, optimizer, and step functions.") # Build Model and Optimizer # --------------------------------------------------------------------------- train_config = models.TransformerConfig( vocab_size=vocab_size, output_vocab_size=vocab_size, logits_via_embedding=config.logits_via_embedding, dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, emb_dim=config.emb_dim, num_heads=config.num_heads, num_layers=config.num_layers, qkv_dim=config.qkv_dim, mlp_dim=config.mlp_dim, max_len=max(config.max_target_length, config.max_eval_target_length), dropout_rate=config.dropout_rate, attention_dropout_rate=config.attention_dropout_rate, deterministic=False, decode=False, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6)) eval_config = train_config.replace(deterministic=True) predict_config = train_config.replace(deterministic=True, decode=True) start_step = 0 rng = jax.random.PRNGKey(config.seed) rng, init_rng = jax.random.split(rng) rng, inference_rng = random.split(rng) input_shape = (config.per_device_batch_size, config.max_target_length) m = models.TransformerLM(eval_config) initial_variables = jax.jit(m.init)(init_rng, jnp.ones(input_shape, jnp.float32)) # apply an optimizer to this tree optimizer_def = optim.Adam( config.learning_rate, beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=config.weight_decay) optimizer = optimizer_def.create(initial_variables["params"]) # We access model params only from optimizer below via optimizer.target. del initial_variables if config.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. optimizer = checkpoints.restore_checkpoint(workdir, optimizer) # Grab last step. start_step = int(optimizer.state.step) writer = metric_writers.create_default_writer( workdir, just_logging=jax.host_id() > 0) if start_step == 0: writer.write_hparams(dict(config)) # Replicate optimizer. 
optimizer = jax_utils.replicate(optimizer) learning_rate_fn = create_learning_rate_scheduler( base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps) # compile multidevice versions of train/eval/predict step fn. p_train_step = jax.pmap( functools.partial( train_step, config=train_config, learning_rate_fn=learning_rate_fn), axis_name="batch", donate_argnums=(0,)) # pytype: disable=wrong-arg-types p_eval_step = jax.pmap( functools.partial( eval_step, config=eval_config), axis_name="batch") p_pred_step = jax.pmap( functools.partial( predict_step, config=predict_config, temperature=config.sampling_temperature, top_k=config.sampling_top_k), axis_name="batch", static_broadcasted_argnums=(3, 4)) # eos token, max_length are constant # Main Train Loop # --------------------------------------------------------------------------- # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap"d training update for performance. dropout_rngs = jax.random.split(rng, jax.local_device_count()) del rng logging.info("Starting training loop.") hooks = [] report_progress = periodic_actions.ReportProgress( num_train_steps=config.num_train_steps, writer=writer) if jax.host_id() == 0: hooks += [ report_progress, periodic_actions.Profile(logdir=workdir, num_profile_steps=5) ] train_metrics = [] with metric_writers.ensure_flushes(writer): for step in range(start_step, config.num_train_steps): is_last_step = step == config.num_train_steps - 1 # Shard data to devices and do a training step. with jax.profiler.StepTraceAnnotation("train", step_num=step): batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter))) optimizer, metrics = p_train_step( optimizer, batch, dropout_rng=dropout_rngs) train_metrics.append(metrics) # Quick indication that training is happening. logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step) for h in hooks: h(step) # Periodic metric handling. if step % config.eval_every_steps == 0 or is_last_step: with report_progress.timed("training_metrics"): logging.info("Gathering training metrics.") train_metrics = common_utils.get_metrics(train_metrics) lr = train_metrics.pop("learning_rate").mean() metrics_sums = jax.tree_map(jnp.sum, train_metrics) denominator = metrics_sums.pop("denominator") summary = jax.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary["learning_rate"] = lr summary["perplexity"] = jnp.clip( jnp.exp(summary["loss"]), a_max=1.0e4) summary = {"train_" + k: v for k, v in summary.items()} writer.write_scalars(step, summary) train_metrics = [] with report_progress.timed("eval"): eval_results = evaluate( p_eval_step=p_eval_step, target=optimizer.target, eval_ds=eval_ds, num_eval_steps=config.num_eval_steps) # (clipped) perplexity after averaging log-perplexitie eval_results["perplexity"] = jnp.clip( jnp.exp(eval_results["loss"]), a_max=1.0e4) writer.write_scalars( step, {"eval_" + k: v for k, v in eval_results.items()}) with report_progress.timed("generate_text"): exemplars = generate_prediction( p_pred_step=p_pred_step, target=optimizer.target, tokenized_prompts=tokenized_prompts, eos_id=eos_id, inference_rng=inference_rng, decode_tokens=decode_tokens, max_predict_length=config.max_predict_length) writer.write_texts(step, {"samples": exemplars}) # Save a checkpoint on one host after every checkpoint_freq steps. 
      save_checkpoint = (step % config.checkpoint_every_steps == 0 or
                         is_last_step)
      if config.save_checkpoints and save_checkpoint and jax.host_id() == 0:
        with report_progress.timed("checkpoint"):
          checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(optimizer),
                                      step)
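
# Stand-alone sketch of the perplexity reported above: the exponential of the
# mean per-token loss (in nats), clipped so that very large early-training
# losses do not overflow the logs.
def clipped_perplexity(mean_token_loss, max_value=1.0e4):
  return jnp.clip(jnp.exp(mean_token_loss), a_max=max_value)

# e.g. clipped_perplexity(2.0) ~= 7.389; clipped_perplexity(50.0) == 10000.0
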
def train(*,
          workdir,
          initial_step,
          chkpt_manager,
          Phi,
          Psi,
          optimal_subspace,
          num_epochs,
          learning_rate,
          key,
          method,
          lissa_kappa,
          optimizer,
          covariance_batch_size,
          main_batch_size,
          weight_batch_size,
          estimate_feature_norm=True):
  """Training function.

  For lissa, the total number of samples is
  2 x covariance_batch_size + main_batch_size + 2 x weight_batch_size.

  Args:
    workdir: Work directory, where we'll save logs.
    initial_step: Initial step.
    chkpt_manager: Checkpoint manager.
    Phi: The initial feature matrix.
    Psi: The target matrix whose PCA is to be determined.
    optimal_subspace: Top-d left singular vectors of Psi.
    num_epochs: How many gradient steps to perform. (Not really epochs.)
    learning_rate: The step size parameter for sgd.
    key: The jax prng key.
    method: 'naive', 'lissa', or 'oracle'.
    lissa_kappa: The parameter of the lissa method, if used.
    optimizer: Which optimizer to use. Only 'sgd' is supported.
    covariance_batch_size: The 'J' parameter. For the naive method, this is how
      many states we sample to construct the inverse. For the lissa method,
      ditto -- these are also "iterations".
    main_batch_size: How many states to update at once.
    weight_batch_size: How many states to construct the weight vector.
    estimate_feature_norm: Whether to use a running average of the max feature
      norm rather than the real maximum.

  Returns:
    tuple: representation and gradient arrays.
  """
  # Don't overwrite Phi.
  Phi = np.copy(Phi)
  Phis = [np.copy(Phi)]

  num_states, d = Phi.shape
  _, num_tasks = Psi.shape

  # Keep a running average of the max norm of a feature vector. None means:
  # don't do it.
  if estimate_feature_norm:
    estimated_feature_norm = utils.compute_max_feature_norm(Phi)
  else:
    estimated_feature_norm = None

  # Create an explicit weight vector (needed for explicit method).
  key, weight_key = jax.random.split(key)
  explicit_weight_matrix = np.array(
      jax.random.normal(  # charlinel(why benefit of np?)
          weight_key, (d, num_tasks), dtype=jnp.float64))

  assert optimizer == 'sgd', 'Non-sgd not yet supported.'

  writer = metric_writers.create_default_writer(logdir=str(workdir))

  hooks = [
      periodic_actions.PeriodicCallback(
          every_steps=5_000,
          callback_fn=lambda step, t: chkpt_manager.save((step, Phi)))
  ]
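
# Quick check of the per-update sample budget stated in the docstring above for
# the lissa method (batch sizes here are illustrative only):
#   2 * covariance_batch_size + main_batch_size + 2 * weight_batch_size
def lissa_samples_per_update(covariance_batch_size, main_batch_size,
                             weight_batch_size):
  return 2 * covariance_batch_size + main_batch_size + 2 * weight_batch_size

# e.g. lissa_samples_per_update(32, 64, 16) == 160
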
def train_and_evaluate(config, workdir): """Runs a training and evaluation loop. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. """ is_first_process = jax.process_index() == 0 tf.io.gfile.makedirs(workdir) # Load Dataset # --------------------------------------------------------------------------- logging.info('Initializing dataset.') train_ds, eval_ds, test_ds, encoder = input_pipeline.get_datasets( config) config.seq_length = 250 vocab_size = int(encoder.vocab_size()) config.num_classes = vocab_size config.data_shape = (config.seq_length, 1) logging.info('Training with vocab size %d', vocab_size) def decode_tokens(toks): return encoder.detokenize(toks) start_step = 0 rng = jax.random.PRNGKey(config.seed) rng, init_rng = jax.random.split(rng) config.per_device_batch_size = config.batch_size // jax.process_count() logging.info('Initializing model, optimizer, and step functions.') # Build Model and Optimizer # --------------------------------------------------------------------------- model, initial_variables = model_setup(init_rng, config) # Instead of passing the optimizer fns directly, we use a fn that returns # the optimizer given a learning rate. def tx_fn(lr): return optax.adamw( lr, b1=0.9, b2=0.99, eps=1e-08, eps_root=0.0, weight_decay=config.weight_decay) state = language_train_state.TrainState.create( params=initial_variables['params'], tx_fn=tx_fn) # We access model params only from state below via state.params. del initial_variables if config.restore_checkpoints: # Restore unreplicated model state from last checkpoint. state = checkpoints.restore_checkpoint(workdir, state) # Grab last step. start_step = int(state.step) writer = metric_writers.create_default_writer( workdir, just_logging=not is_first_process) if start_step == 0: config_dict = dict(config) writer.write_hparams(config_dict) if is_first_process and start_step == 0: # Dump config file to work dir for easy model loading. config_path = os.path.join(workdir, 'config') with tf.io.gfile.GFile(config_path, 'wb') as fp: pickle.dump(config, fp) print('Using state', type(state)) # Replicate state. state = jax_utils.replicate(state) learning_rate_fn = create_learning_rate_scheduler( factors=config.lr_factors, base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps) # Compile multidevice versions of train/eval/predict step fn. p_train_step = jax.pmap( functools.partial( train_step, model=model, learning_rate_fn=learning_rate_fn, clip_grad=config.clip_grad, ema_momentum=config.get('ema_momentum', 0.999)), axis_name='batch', in_axes=(0, 0), donate_argnums=(0,)) p_eval_step = jax.pmap( functools.partial( eval_step, model=model), axis_name='batch') # Main Train Loop # --------------------------------------------------------------------------- # We init the first set of train PRNG keys, but update it afterwards inside # the main pmap'd training update for performance. 
rng = jax.random.fold_in(rng, jax.process_index()) rng1, rng2, rng3, extensive_eval_rngs, sample_rng = jax.random.split(rng, 5) train_rngs = jax.random.split(rng1, jax.local_device_count()) eval_rngs = jax.random.split(rng2, jax.local_device_count()) test_rngs = jax.random.split(rng3, jax.local_device_count()) del rng, rng1, rng2, rng3 logging.info('Starting training loop.') hooks = [] report_progress = periodic_actions.ReportProgress( num_train_steps=config.num_train_steps, writer=writer) if is_first_process: hooks += [ report_progress, periodic_actions.Profile(logdir=workdir, num_profile_steps=5) ] train_metrics = [] # Iterator that does epoch-wise indefinite iteration. def iterate_train(train_ds): epoch = 1 while True: msg = f'Starting epoch {epoch}' logging.info(msg) for batch in train_ds: yield batch epoch += 1 train_iter = iterate_train(train_ds) kl_tracker_train = util_fns.KLTracker(num_steps=model.num_steps) kl_history = [] with metric_writers.ensure_flushes(writer): step = start_step for step in range(start_step, config.num_train_steps): is_last_step = step == config.num_train_steps - 1 # Shard data to devices and do a training step. with jax.profiler.StepTraceAnnotation('train', step_num=step): batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter))) state, metrics = p_train_step( state, batch, rng=train_rngs) train_metrics.append(metrics) # Quick indication that training is happening. logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step) for h in hooks: h(step) # Periodic metric handling. if step > 0 and (step % config.eval_every_steps == 0 or is_last_step): with report_progress.timed('training_metrics'): logging.info('Gathering training metrics.') train_metrics = common_utils.get_metrics(train_metrics) # First handle loss terms per step. t_batch = train_metrics.pop('t_batch') nelbo_per_t_batch = train_metrics.pop('nelbo_per_t_batch') kl_tracker_train.update( t_batch.reshape(-1), nelbo_per_t_batch.reshape(-1)) kl_values = kl_tracker_train.get_kl_per_t() kl_history.append(np.array(kl_values)) kl_history = kl_history[-100:] # Keep last 100 items only. # Handle remaining `standard` metrics summary = jax.tree_map(jnp.mean, train_metrics) summary = {'train_' + k: v for k, v in summary.items()} writer.write_scalars(step, summary) train_metrics = [] with report_progress.timed('eval'): eval_results, eval_rngs = evaluate( p_eval_step=p_eval_step, params=state.ema_params, eval_ds=eval_ds, rng=eval_rngs) writer.write_scalars( step, {'eval_' + k: v for k, v in eval_results.items()}) test_results, test_rngs = evaluate( p_eval_step=p_eval_step, params=state.ema_params, eval_ds=test_ds, rng=test_rngs) writer.write_scalars( step, {'test_' + k: v for k, v in test_results.items()}) if step == 1000 or (step > 0 and step % config.detailed_eval_every_steps == 0): if is_first_process: loss_components_path = os.path.join(workdir, 'loss_components') with tf.io.gfile.GFile(loss_components_path, 'wb') as fp: pickle.dump(kl_history[-1], fp) extensive_eval_rngs = extensive_eval( config, extensive_eval_rngs, writer, workdir, model, state, kl_history, test_ds, step, decode_tokens) with report_progress.timed('generate_text'): generate_prediction(sample_rng, config, model, state, writer, decode_tokens, step) # Save a checkpoint on one host after every checkpoint_freq steps. 
      save_checkpoint = (
          step > 0 and
          (step % config.checkpoint_every_steps == 0 or is_last_step))
      if config.save_checkpoints and save_checkpoint and is_first_process:
        with report_progress.timed('checkpoint'):
          checkpoints.save_checkpoint(
              workdir, jax_utils.unreplicate(state), step, overwrite=True)
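
# Minimal NumPy sketch of what util_fns.KLTracker is assumed to do above: keep
# a running per-timestep average of the NELBO terms fed to it via update().
# This is an illustrative stand-in, not the project's implementation.
class SimpleKLTracker:

  def __init__(self, num_steps):
    self._sums = np.zeros(num_steps)
    self._counts = np.zeros(num_steps)

  def update(self, t_batch, values):
    np.add.at(self._sums, t_batch, values)
    np.add.at(self._counts, t_batch, 1.0)

  def get_kl_per_t(self):
    return self._sums / np.maximum(self._counts, 1.0)
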
def train_and_evaluate(config, workdir): """Execute model training and evaluation loop. Args: config: Hyperparameter configuration for training and evaluation. workdir: Directory where the tensorboard summaries are written to. Returns: The train state (which includes the `.params`). """ # Seed for reproducibility. rng = jax.random.PRNGKey(config.rng_seed) # Set up logging. summary_writer = metric_writers.create_default_writer(workdir) summary_writer.write_hparams(dict(config)) # Get datasets. rng, dataset_rng = jax.random.split(rng) dataset = input_pipeline.get_dataset(config, dataset_rng) graph, labels, masks = jax.tree_map(jnp.asarray, dataset) labels = jax.nn.one_hot(labels, config.num_classes) train_mask = masks['train'] train_indices = jnp.where(train_mask)[0] train_labels = labels[train_indices] num_training_nodes = len(train_indices) # Get subgraphs. if config.differentially_private_training: graph = jax.tree_map(np.asarray, graph) subgraphs = get_subgraphs(graph, pad_to=config.pad_subgraphs_to) graph = jax.tree_map(jnp.asarray, graph) # We only need the subgraphs for training nodes. train_subgraphs = subgraphs[train_indices] del subgraphs else: train_subgraphs = None # Initialize privacy accountant. training_privacy_accountant = privacy_accountants.get_training_privacy_accountant( config, num_training_nodes, compute_max_terms_per_node(config)) # Construct and initialize model. rng, init_rng = jax.random.split(rng) estimation_indices = get_estimation_indices(train_indices, config) state = create_train_state(init_rng, config, graph, train_labels, train_subgraphs, estimation_indices) # Set up checkpointing of the model. checkpoint_dir = os.path.join(workdir, 'checkpoints') ckpt = checkpoint.Checkpoint(checkpoint_dir, max_to_keep=2) state = ckpt.restore_or_initialize(state) initial_step = int(state.step) + 1 # Log overview of parameters. parameter_overview.log_parameter_overview(state.params) # Log metrics after initialization. logits = compute_logits(state, graph) metrics_after_init = compute_metrics(logits, labels, masks) metrics_after_init['epsilon'] = 0 log_metrics(0, metrics_after_init, summary_writer, postfix='_after_init') # Train model. rng, train_rng = jax.random.split(rng) max_training_epsilon = get_max_training_epsilon(config) # Hooks called periodically during training. report_progress = periodic_actions.ReportProgress( num_train_steps=config.num_training_steps, writer=summary_writer) profiler = periodic_actions.Profile(num_profile_steps=5, logdir=workdir) hooks = [report_progress, profiler] for step in range(initial_step, config.num_training_steps): # Perform one step of training. with jax.profiler.StepTraceAnnotation('train', step_num=step): # Sample batch. step_rng = jax.random.fold_in(train_rng, step) indices = jax.random.choice(step_rng, num_training_nodes, (config.batch_size, )) # Compute gradients. if config.differentially_private_training: grads = compute_updates_for_dp(state, graph, train_labels, train_subgraphs, indices, config.adjacency_normalization) else: grads = compute_updates(state, graph, train_labels, indices) # Update parameters. state = update_model(state, grads) # Quick indication that training is happening. logging.log_first_n(logging.INFO, 'Finished training step %d.', 10, step) for hook in hooks: hook(step) # Evaluate, if required. is_last_step = (step == config.num_training_steps - 1) if step % config.evaluate_every_steps == 0 or is_last_step: with report_progress.timed('eval'): # Check if privacy budget exhausted. 
        training_epsilon = training_privacy_accountant(step + 1)
        if (max_training_epsilon is not None and
            training_epsilon >= max_training_epsilon):
          break

        # Compute metrics.
        logits = compute_logits(state, graph)
        metrics_during_training = compute_metrics(logits, labels, masks)
        metrics_during_training['epsilon'] = training_epsilon
        log_metrics(step, metrics_during_training, summary_writer)

    # Checkpoint, if required.
    if step % config.checkpoint_every_steps == 0 or is_last_step:
      with report_progress.timed('checkpoint'):
        ckpt.save(state)

  return state
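
# The batch sampling above derives a per-step key with jax.random.fold_in, so
# the node indices drawn at a given step are reproducible across restarts.
# Small illustrative helper (names and values are placeholders):
def sample_node_indices(train_rng, step, num_training_nodes, batch_size):
  step_rng = jax.random.fold_in(train_rng, step)
  return jax.random.choice(step_rng, num_training_nodes, (batch_size,))

# Calling sample_node_indices(key, 42, ...) twice with the same key and step
# returns the same batch of node indices.
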
def train_and_evaluate(config, work_dir, try_checkpoint=True):
  """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    work_dir: Directory where the tensorboard summaries are written to.
    try_checkpoint: Whether to try restoring from an existing checkpoint
      (usually enabled; disabling is practical for debugging).

  Returns:
    The train state (which includes the `.params`).
  """
  # Init rng key.
  msg = f'Running with seed {config.seed}.'
  logging.info(msg)
  rng = jax.random.PRNGKey(config.seed)
  data_rng, rng = jax.random.split(rng)
  is_first_host = jax.process_index() == 0

  train_ds, test_ds, shape, num_classes = datasets.get_dataset(
      config, data_rng)

  # config.mask_shape = mask_shape
  config.data_shape = shape
  config.num_classes = num_classes

  writer = metric_writers.create_default_writer(
      work_dir, just_logging=jax.process_index() > 0)
  rng, init_rng = jax.random.split(rng)

  # Create output directory for saving samples.
  output_path = work_dir
  tf.io.gfile.makedirs(output_path)

  model, variables = model_setup(init_rng, config)

  # From now on we want different rng across hosts:
  rng = jax.random.fold_in(rng, jax.process_index())

  tx = optax.adam(
      config.learning_rate, b1=0.9, b2=config.beta2, eps=1e-08, eps_root=0.0)
  state = custom_train_state.TrainState.create(
      params=variables['params'], tx=tx)

  if try_checkpoint:
    state, start_epoch = checkpoint.restore_from_path(work_dir, state)
    if start_epoch is None:
      start_epoch = 1
  else:
    # For debugging we start at zero, so we immediately do detailed eval.
    start_epoch = 0

  if is_first_host and start_epoch == 1:
    config_dict = dict(config)
    writer.write_hparams(config_dict)

  if is_first_host and start_epoch in (0, 1):
    # Dump config file to work dir for easy model loading.
    config_path = os.path.join(work_dir, 'config')
    with tf.io.gfile.GFile(config_path, 'wb') as fp:
      pickle.dump(config, fp)

  test_rng, train_rng = jax.random.split(rng)

  kl_tracker_train = util_fns.KLTracker(num_steps=model.num_steps)
  kl_history = []

  p_train_step = jax.pmap(
      functools.partial(train_step, model=model, config=config),
      axis_name='batch',
      in_axes=(None, 0, 0),
      out_axes=(0, 0, None),
      donate_argnums=(2,))

  # The only axes that are broadcasted are the in- and output rng key ones. The
  # rng is the first arg, and the last return value.
  p_eval_step = jax.pmap(
      functools.partial(eval_step, model=model),
      axis_name='batch',
      in_axes=(None, 0, 0),
      out_axes=(0, None))

  # Replicate state.
  state = flax.jax_utils.replicate(state)

  with metric_writers.ensure_flushes(writer):
    for epoch in range(start_epoch, config.num_epochs + 1):
      # Train part.
      state, train_metrics, train_rng = train_epoch(
          p_train_step, state, train_ds, config.batch_size, epoch, train_rng,
          kl_tracker_train)

      # Val part.
      eval_metrics, test_rng = eval_model(
          p_eval_step, test_rng, state, test_ds, epoch)

      # Metric logging.
      if is_first_host:
        log_standard_metrics(writer, train_metrics, eval_metrics, epoch)

      kl_values = kl_tracker_train.get_kl_per_t()
      kl_history.append(np.array(kl_values))
      # Prune to avoid too much memory consumption.
      kl_history = kl_history[-50:]

      if epoch == 15 or epoch % config.detailed_eval_every == 0:
        if is_first_host:
          loss_components_path = os.path.join(work_dir, 'loss_components')
          with tf.io.gfile.GFile(loss_components_path, 'wb') as fp:
            pickle.dump(kl_history[-1], fp)

        test_rng = extensive_eval(config, test_rng, writer, output_path,
                                  model, state, kl_history, test_ds, epoch)

      # Save to checkpoint.
      if is_first_host and epoch % config.save_every == 0:
        # Save to epoch + 1 since current epoch has just been completed.
        logging.info('saving checkpoint')
        checkpoint.save_checkpoint(
            work_dir,
            state=flax.jax_utils.unreplicate(state),
            step=epoch + 1,
            keep=2)
        logging.info('finished saving checkpoint')

  return state
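

# NOTE (added illustration, not part of the original training code):
# `custom_train_state` and `model_setup` are not shown in this excerpt. As a
# rough, hedged sketch, the optimizer/state wiring above corresponds to the
# stock Flax TrainState pattern below, using a tiny stand-in model. All names
# prefixed with an underscore here are hypothetical.
import flax.linen as nn
from flax.training import train_state


class _TinyModel(nn.Module):
  """Stand-in model; the real model comes from `model_setup`."""

  @nn.compact
  def __call__(self, x):
    return nn.Dense(features=4)(x)


def _make_train_state_sketch(learning_rate=1e-3, beta2=0.999):
  model = _TinyModel()
  variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
  tx = optax.adam(learning_rate, b1=0.9, b2=beta2, eps=1e-08, eps_root=0.0)
  # The stock TrainState also stores `apply_fn`; the custom variant used above
  # appears to store only the parameters and the optimizer.
  return train_state.TrainState.create(
      apply_fn=model.apply, params=variables['params'], tx=tx)
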
def main(_):
    if FLAGS.config.precrop_iters > 0 and FLAGS.config.batching:
        raise ValueError(
            "'precrop_iters' has no effect when 'batching' the dataset.")
    assert FLAGS.config.down_factor > 0 and FLAGS.config.render_factor > 0

    logging.info("JAX host: %d / %d", jax.process_index(), jax.host_count())
    logging.info("JAX local devices: %r", jax.local_devices())

    platform.work_unit().set_task_status(
        f"host_id: {jax.process_index()}, host_count: {jax.host_count()}")
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.model_dir, "model_dir")

    os.makedirs(FLAGS.model_dir, exist_ok=True)
    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, rng_coarse, rng_fine, data_rng, step_rng = jax.random.split(rng, 5)
    rngs = common_utils.shard_prng_key(step_rng)

    ### Load dataset and data values
    datasets, counts, optics, render_datasets = get_dataset(
        FLAGS.data_dir,
        FLAGS.config,
        rng=data_rng,
        num_poses=FLAGS.config.num_poses)
    train_ds, val_ds, test_ds = datasets
    *_, test_items = counts
    hwf, r_hwf, near, far = optics
    render_ds, render_vdirs_ds, num_poses = render_datasets
    iter_render_ds = zip(range(num_poses), render_ds)
    iter_vdirs_ds = zip(range(num_poses), render_vdirs_ds)
    iter_test_ds = zip(range(test_items), test_ds)
    img_h, img_w, _ = hwf

    logging.info("Num poses: %d", num_poses)
    logging.info("Splits: train - %d, val - %d, test - %d", *counts)
    logging.info("Images: height %d, width %d, focal %.5f", *hwf)
    logging.info("Render: height %d, width %d, focal %.5f", *r_hwf)

    ### Init model parameters and optimizer
    initialized_ = functools.partial(initialized,
                                     model_config=FLAGS.config.model)
    pts_shape = (FLAGS.config.num_rand, FLAGS.config.num_samples, 3)
    views_shape = (FLAGS.config.num_rand, 3)
    model_coarse, params_coarse = initialized_(rng_coarse, pts_shape,
                                               views_shape)

    schedule_fn = optax.exponential_decay(
        init_value=FLAGS.config.learning_rate,
        transition_steps=FLAGS.config.lr_decay * 1000,
        decay_rate=FLAGS.config.decay_factor,
    )
    tx = optax.adam(learning_rate=schedule_fn)
    state = train_state.TrainState.create(
        apply_fn=(model_coarse.apply, None),
        params={"coarse": params_coarse},
        tx=tx)

    if FLAGS.config.num_importance > 0:
        pts_shape = (
            FLAGS.config.num_rand,
            FLAGS.config.num_importance + FLAGS.config.num_samples,
            3,
        )
        model_fine, params_fine = initialized_(rng_fine, pts_shape,
                                               views_shape)
        state = train_state.TrainState.create(
            apply_fn=(model_coarse.apply, model_fine.apply),
            params={
                "coarse": params_coarse,
                "fine": params_fine
            },
            tx=tx,
        )

    state = checkpoints.restore_checkpoint(FLAGS.model_dir, state)
    start_step = int(state.step)

    # cycle already seen examples if resuming from checkpoint
    # (only useful for ensuring deterministic dataset, slow for large start_step)
    # if start_step != 0:
    #     for _ in range(start_step):
    #         _ = next(train_ds)

    # parameter_overview.log_parameter_overview(state.optimizer_coarse.target)
    # if FLAGS.config.num_importance > 0:
    #     parameter_overview.log_parameter_overview(state.optimizer_fine.target)

    state = jax.device_put_replicated(state, jax.local_devices())

    ### Build "pmapped" functions for distributed training
    train_fn = functools.partial(train_step, near, far, FLAGS.config,
                                 schedule_fn)
    p_train_step = jax.pmap(
        train_fn,
        axis_name="batch",
        in_axes=(0, 0, None, 0),
        # donate_argnums=(0, 1, 2),
    )

    def render_fn(state, rays):
        step_fn = functools.partial(eval_step, FLAGS.config, near, far, state)
        return lax.map(step_fn, rays)

    p_eval_step = jax.pmap(
        render_fn,
        axis_name="batch",
        # in_axes=(0, 0, None),
        # donate_argnums=(0, 1),
    )

    # TODO: add hparams
    writer = metric_writers.create_default_writer(
        FLAGS.model_dir, just_logging=jax.process_index() > 0)

    logging.info("Starting training loop.")
    hooks = []
    profiler = periodic_actions.Profile(num_profile_steps=5,
                                        logdir=FLAGS.model_dir)
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=FLAGS.config.num_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [profiler, report_progress]
    train_metrics = []
    gen_video_ = functools.partial(gen_video, FLAGS.model_dir)

    for step in range(start_step, FLAGS.config.num_steps + 1):
        is_last_step = step == FLAGS.config.num_steps

        batch = next(train_ds)
        coords = None
        if not FLAGS.config.batching:
            coords = jnp.meshgrid(jnp.arange(img_h),
                                  jnp.arange(img_w),
                                  indexing="ij")
            if step < FLAGS.config.precrop_iters:
                dH = int(img_h // 2 * FLAGS.config.precrop_frac)
                dW = int(img_w // 2 * FLAGS.config.precrop_frac)
                coords = jnp.meshgrid(
                    jnp.arange(img_h // 2 - dH, img_h // 2 + dH),
                    jnp.arange(img_w // 2 - dW, img_w // 2 + dW),
                    indexing="ij",
                )
            coords = jnp.stack(coords, axis=-1).reshape([-1, 2])

        with jax.profiler.StepTraceAnnotation("train", step_num=step):
            state, metrics = p_train_step(batch, state, coords, rngs)
        train_metrics.append(metrics)

        logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                            step)
        _ = [h(step) for h in hooks]

        ### Write train summaries to TB
        if step % FLAGS.config.i_print == 0 or is_last_step:
            with report_progress.timed("training_metrics"):
                train_metrics = common_utils.get_metrics(train_metrics)
                train_summary = jax.tree_map(lambda x: x.mean(),
                                             train_metrics)
                summary = {f"train/{k}": v for k, v in train_summary.items()}
                writer.write_scalars(step, summary)
            train_metrics = []

        ### Eval a random validation image and plot it to TB
        if step % FLAGS.config.i_img == 0 and step > 0 or is_last_step:
            with report_progress.timed("validation"):
                inputs = next(val_ds)
                rays, padding = prepare_render_data(inputs["rays"]._numpy())
                outputs = p_eval_step(state, rays)
                preds, preds_c, z_std = jax.tree_map(
                    lambda x: to_np(x, hwf, padding), outputs)
                loss = np.mean((preds["rgb"] - inputs["image"])**2)
                summary = {"val/loss": loss, "val/psnr": psnr_fn(loss)}
                writer.write_scalars(step, summary)

                summary = {
                    "val/rgb": to_rgb(preds["rgb"]),
                    "val/target": to_np(inputs["image"], hwf, padding),
                    "val/disp": disp_post(preds["disp"], FLAGS.config),
                    "val/acc": preds["acc"],
                }
                if FLAGS.config.num_importance > 0:
                    summary["val/rgb_c"] = to_rgb(preds_c["rgb"])
                    summary["val/disp_c"] = disp_post(preds_c["disp"],
                                                      FLAGS.config)
                    summary["val/z_std"] = z_std
                writer.write_images(step, summary)

        ### Render a video with test poses
        if step % FLAGS.config.i_video == 0 and step > 0:
            with report_progress.timed("video_render"):
                logging.info("Rendering video at step %d", step)
                rgb_list = []
                disp_list = []
                for idx, inputs in tqdm(iter_render_ds, desc="Rays render"):
                    rays, padding = prepare_render_data(
                        inputs["rays"].numpy())
                    preds, *_ = p_eval_step(state, rays)
                    preds = jax.tree_map(lambda x: to_np(x, r_hwf, padding),
                                         preds)
                    rgb_list.append(preds["rgb"])
                    disp_list.append(preds["disp"])
                gen_video_(np.stack(rgb_list), "rgb", r_hwf, step)
                disp = np.stack(disp_list)
                gen_video_(disp_post(disp, FLAGS.config), "disp", r_hwf, step,
                           ch=1)

                if FLAGS.config.use_viewdirs:
                    rgb_list = []
                    for idx, inputs in tqdm(iter_vdirs_ds,
                                            desc="Viewdirs render"):
                        rays, padding = prepare_render_data(
                            inputs["rays"].numpy())
                        preds, *_ = p_eval_step(state, rays)
                        rgb_list.append(to_np(preds["rgb"], r_hwf, padding))
                    gen_video_(np.stack(rgb_list), "rgb_still", r_hwf, step)

        ### Save images in the test set
        if step % FLAGS.config.i_testset == 0 and step > 0:
            with report_progress.timed("test_render"):
                logging.info("Rendering test set at step %d", step)
                test_losses = []
                for idx, inputs in tqdm(iter_test_ds, desc="Test render"):
                    rays, padding = prepare_render_data(
                        inputs["rays"].numpy())
                    preds, *_ = p_eval_step(state, rays)
                    save_test_imgs(FLAGS.model_dir, preds["rgb"], r_hwf, step,
                                   idx)

                    if FLAGS.config.render_factor == 0:
                        loss = np.mean((preds["rgb"] - inputs["image"])**2.0)
                        test_losses.append(loss)
                if FLAGS.config.render_factor == 0:
                    loss = np.mean(test_losses)
                    summary = {"test/loss": loss, "test/psnr": psnr_fn(loss)}
                    writer.write_scalars(step, summary)
            writer.flush()

        ### Save ckpt
        if step % FLAGS.config.i_weights == 0 or is_last_step:
            with report_progress.timed("checkpoint"):
                save_checkpoint(state, FLAGS.model_dir)
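

# NOTE (added illustration, not part of the original training code): `psnr_fn`
# is referenced above but not defined in this excerpt. For RGB values scaled to
# [0, 1], the standard conversion from mean squared error to PSNR (in dB) is
# PSNR = -10 * log10(MSE). A minimal sketch under that assumption, reusing the
# module-level `np` (numpy) import already used above:
def psnr_fn_sketch(mse):
    # Assumes a peak signal value of 1.0 (images in [0, 1]).
    return -10.0 * np.log10(mse)
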