  def test_get_weight_decay_fn(self, weight_decay_rules, rescale_value,
                               learning_rate, input_params,
                               expected_decayed_params):
    weight_decay_fn = train_utils.get_weight_decay_fn(weight_decay_rules,
                                                      rescale_value)
    actual_decayed_params = weight_decay_fn(input_params, learning_rate)
    self.assertAllClose(actual_decayed_params, expected_decayed_params)
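# A minimal illustrative sketch (not the actual `train_utils` implementation) of
# what a weight-decay function with the signature tested above could compute.
# It assumes `weight_decay_rules` is a list of `(name_regex, decay_strength)`
# pairs -- that rule format is an assumption made for this example only.
# Decoupled decay shrinks every matching parameter by `lr / rescale_value * decay`.
import re

import flax
from flax import traverse_util


def example_get_weight_decay_fn(weight_decay_rules, rescale_value):

  def weight_decay_fn(params, lr):
    flat = traverse_util.flatten_dict(flax.core.unfreeze(params))
    decayed = {}
    for path, value in flat.items():
      name = '/'.join(path)
      decay = next((d for pattern, d in weight_decay_rules
                    if re.search(pattern, name)), 0.0)
      # Decoupled weight decay, rescaled by the base learning rate when
      # `weight_decay_decouple` is set (mirroring `rescale_value` above).
      decayed[path] = value * (1.0 - (lr / rescale_value) * decay)
    return traverse_util.unflatten_dict(decayed)

  return weight_decay_fn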
  def update_fn(opt, lr, images, labels, rngs):
    return update_fn_be(
        opt=opt,
        rngs=rngs,
        lr=lr,
        images=images,
        labels=labels,
        batch_loss_fn=batch_loss_fn,
        weight_decay_fn=train_utils.get_weight_decay_fn(
            weight_decay_rules=config.get('weight_decay', []) or [],
            rescale_value=config.lr.base
            if config.get('weight_decay_decouple') else 1.),
        max_grad_norm_global=config.get('grad_clip_norm', None),
        fast_weight_lr_multiplier=config.get('fast_weight_lr_multiplier', None))
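# Hypothetical values for the config fields consumed above (illustrative only;
# the rule format for `weight_decay` is an assumption, and the real experiment
# configs live in the repo's config files):
#
#   config.weight_decay = [('.*kernel.*', 0.1)]
#   config.weight_decay_decouple = True     # decay rescaled by config.lr.base
#   config.lr.base = 1e-3
#   config.grad_clip_norm = 1.0
#   config.fast_weight_lr_multiplier = 2.0  # LR boost for BatchEnsemble fast weights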
def main(config, output_dir):

  seed = config.get('seed', 0)
  rng = jax.random.PRNGKey(seed)
  tf.random.set_seed(seed)

  if config.get('data_dir'):
    logging.info('data_dir=%s', config.data_dir)
  logging.info('Output dir: %s', output_dir)

  save_checkpoint_path = None
  if config.get('checkpoint_steps'):
    gfile.makedirs(output_dir)
    save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

  # Create an asynchronous multi-metric writer.
  writer = metric_writers.create_default_writer(
      output_dir, just_logging=jax.process_index() > 0)

  # The pool is used to perform misc operations such as logging in an async way.
  pool = multiprocessing.pool.ThreadPool()

  def write_note(note):
    if jax.process_index() == 0:
      logging.info('NOTE: %s', note)

  write_note('Initializing...')

  # Verify settings to make sure no checkpoints are accidentally missed.
  if config.get('keep_checkpoint_steps'):
    assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
    assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
        f'`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should be '
        f'divisible by `checkpoint_steps` ({config.checkpoint_steps}).')

  batch_size = config.batch_size
  batch_size_eval = config.get('batch_size_eval', batch_size)
  if (batch_size % jax.device_count() != 0 or
      batch_size_eval % jax.device_count() != 0):
    raise ValueError(f'Batch sizes ({batch_size} and {batch_size_eval}) must '
                     f'be divisible by device number ({jax.device_count()})')

  local_batch_size = batch_size // jax.process_count()
  local_batch_size_eval = batch_size_eval // jax.process_count()
  logging.info(
      'Global batch size %d on %d hosts results in %d local batch size. '
      'With %d devices per host (%d devices total), that\'s a %d per-device '
      'batch size.', batch_size, jax.process_count(), local_batch_size,
      jax.local_device_count(), jax.device_count(),
      local_batch_size // jax.local_device_count())

  write_note('Initializing train dataset...')
  rng, train_ds_rng = jax.random.split(rng)
  train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index())
  train_ds = input_utils.get_data(
      dataset=config.dataset,
      split=config.train_split,
      rng=train_ds_rng,
      process_batch_size=local_batch_size,
      preprocess_fn=preprocess_spec.parse(
          spec=config.pp_train, available_ops=preprocess_utils.all_ops()),
      shuffle_buffer_size=config.shuffle_buffer_size,
      prefetch_size=config.get('prefetch_to_host', 2),
      data_dir=config.get('data_dir'))

  # Start prefetching already.
  train_iter = input_utils.start_input_pipeline(
      train_ds, config.get('prefetch_to_device', 1))

  write_note('Initializing val dataset(s)...')

  def _get_val_split(dataset, split, pp_eval, data_dir=None):
    # We do ceil rounding such that we include the last incomplete batch.
    nval_img = input_utils.get_num_examples(
        dataset,
        split=split,
        process_batch_size=local_batch_size_eval,
        drop_remainder=False,
        data_dir=data_dir)
    val_steps = int(np.ceil(nval_img / batch_size_eval))
    logging.info('Running validation for %d steps for %s, %s', val_steps,
                 dataset, split)

    if isinstance(pp_eval, str):
      pp_eval = preprocess_spec.parse(
          spec=pp_eval, available_ops=preprocess_utils.all_ops())

    val_ds = input_utils.get_data(
        dataset=dataset,
        split=split,
        rng=None,
        process_batch_size=local_batch_size_eval,
        preprocess_fn=pp_eval,
        cache=config.get('val_cache', 'batched'),
        repeat_after_batching=True,
        shuffle=False,
        prefetch_size=config.get('prefetch_to_host', 2),
        drop_remainder=False,
        data_dir=data_dir)
    val_iter = input_utils.start_input_pipeline(
        val_ds, config.get('prefetch_to_device', 1))

    return (val_iter, val_steps)

  val_iter_splits = {
      'val':
          _get_val_split(
              config.dataset,
              split=config.val_split,
              pp_eval=config.pp_eval,
              data_dir=config.get('data_dir'))
  }

  if config.get('eval_on_cifar_10h'):
    cifar10_to_cifar10h_fn = data_uncertainty_utils.create_cifar10_to_cifar10h_fn(
        config.get('data_dir', None))
    preprocess_fn = preprocess_spec.parse(
        spec=config.pp_eval_cifar_10h,
        available_ops=preprocess_utils.all_ops())
    pp_eval = lambda ex: preprocess_fn(cifar10_to_cifar10h_fn(ex))
    val_iter_splits['cifar_10h'] = _get_val_split(
        'cifar10',
        split=config.get('cifar_10h_split') or 'test',
        pp_eval=pp_eval,
        data_dir=config.get('data_dir'))
  elif config.get('eval_on_imagenet_real'):
    imagenet_to_real_fn = data_uncertainty_utils.create_imagenet_to_real_fn()
    preprocess_fn = preprocess_spec.parse(
        spec=config.pp_eval_imagenet_real,
        available_ops=preprocess_utils.all_ops())
    pp_eval = lambda ex: preprocess_fn(imagenet_to_real_fn(ex))
    val_iter_imagenet_real, val_steps = _get_val_split(
        'imagenet2012_real',
        split=config.get('imagenet_real_split') or 'validation',
        pp_eval=pp_eval,
        data_dir=config.get('data_dir'))
    val_iter_splits['imagenet_real'] = (val_iter_imagenet_real, val_steps)

  ood_ds = {}
  if config.get('ood_datasets') and config.get('ood_methods'):
    if config.get('ood_methods'):  # config.ood_methods is not an empty list
      logging.info('loading OOD datasets = %s', config.get('ood_datasets'))
      ood_ds, ood_ds_names = ood_utils.load_ood_datasets(
          config.dataset,
          config.ood_datasets,
          config.ood_split,
          config.pp_eval,
          config.pp_eval_ood,
          config.ood_methods,
          config.train_split,
          config.get('data_dir'),
          _get_val_split,
      )

  ntrain_img = input_utils.get_num_examples(
      config.dataset,
      split=config.train_split,
      process_batch_size=local_batch_size,
      data_dir=config.get('data_dir'))
  steps_per_epoch = int(ntrain_img / batch_size)

  if config.get('num_epochs'):
    total_steps = int(config.num_epochs * steps_per_epoch)
    assert not config.get('total_steps'), 'Set either num_epochs or total_steps'
  else:
    total_steps = config.total_steps

  logging.info('Total train data points: %d', ntrain_img)
  logging.info(
      'Running for %d steps, that means %f epochs and %d steps per epoch',
      total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)

  write_note('Initializing model...')
  logging.info('config.model = %s', config.get('model'))
  model = ub.models.vision_transformer(
      num_classes=config.num_classes, **config.get('model', {}))

  # We want all parameters to be created in host RAM, not on any device; they'll
  # be sent there later as needed. Otherwise we have already encountered two
  # situations where we allocate them twice.
@partial(jax.jit, backend='cpu') def init(rng): image_size = tuple(train_ds.element_spec['image'].shape[2:]) logging.info('image_size = %s', image_size) dummy_input = jnp.zeros((local_batch_size, ) + image_size, jnp.float32) params = flax.core.unfreeze(model.init(rng, dummy_input, train=False))['params'] # Set bias in the head to a low value, such that loss is small initially. params['head']['bias'] = jnp.full_like(params['head']['bias'], config.get('init_head_bias', 0)) # init head kernel to all zeros for fine-tuning if config.get('model_init'): params['head']['kernel'] = jnp.full_like(params['head']['kernel'], 0) return params rng, rng_init = jax.random.split(rng) params_cpu = init(rng_init) if jax.process_index() == 0: num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0]) parameter_overview.log_parameter_overview(params_cpu) writer.write_scalars(step=0, scalars={'num_params': num_params}) @partial(jax.pmap, axis_name='batch') def evaluation_fn(params, images, labels, mask): # Ignore the entries with all zero labels for evaluation. mask *= labels.max(axis=1) logits, out = model.apply({'params': flax.core.freeze(params)}, images, train=False) label_indices = config.get('label_indices') logging.info('!!! mask %s, label_indices %s', mask, label_indices) if label_indices: logits = logits[:, label_indices] # Note that logits and labels are usually of the shape [batch,num_classes]. # But for OOD data, when num_classes_ood > num_classes_ind, we need to # adjust labels to labels[:, :config.num_classes] to match the shape of # logits. That is just to avoid shape mismatch. The output losses does not # have any meaning for OOD data, because OOD not belong to any IND class. losses = getattr(train_utils, config.get('loss', 'sigmoid_xent'))( logits=logits, labels=labels[:, :( len(label_indices) if label_indices else config.num_classes)], reduction=False) loss = jax.lax.psum(losses * mask, axis_name='batch') top1_idx = jnp.argmax(logits, axis=1) # Extracts the label at the highest logit index for each image. top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0] ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch') n = jax.lax.psum(mask, axis_name='batch') metric_args = jax.lax.all_gather( [logits, labels, out['pre_logits'], mask], axis_name='batch') return ncorrect, loss, n, metric_args @partial(jax.pmap, axis_name='batch') def cifar_10h_evaluation_fn(params, images, labels, mask): logits, out = model.apply({'params': flax.core.freeze(params)}, images, train=False) label_indices = config.get('label_indices') if label_indices: logits = logits[:, label_indices] losses = getattr(train_utils, config.get('loss', 'softmax_xent'))(logits=logits, labels=labels, reduction=False) loss = jax.lax.psum(losses, axis_name='batch') top1_idx = jnp.argmax(logits, axis=1) # Extracts the label at the highest logit index for each image. one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)] top1_correct = jnp.take_along_axis(one_hot_labels, top1_idx[:, None], axis=1)[:, 0] ncorrect = jax.lax.psum(top1_correct, axis_name='batch') n = jax.lax.psum(one_hot_labels, axis_name='batch') metric_args = jax.lax.all_gather( [logits, labels, out['pre_logits'], mask], axis_name='batch') return ncorrect, loss, n, metric_args # Setup function for computing representation. 
@partial(jax.pmap, axis_name='batch') def representation_fn(params, images, labels, mask): _, outputs = model.apply({'params': flax.core.freeze(params)}, images, train=False) representation = outputs[config.fewshot.representation_layer] representation = jax.lax.all_gather(representation, 'batch') labels = jax.lax.all_gather(labels, 'batch') mask = jax.lax.all_gather(mask, 'batch') return representation, labels, mask # Load the optimizer from flax. opt_name = config.get('optim_name') write_note(f'Initializing {opt_name} optimizer...') opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {})) # We jit this, such that the arrays that are created are created on the same # device as the input is, in this case the CPU. Else they'd be on device[0]. opt_cpu = jax.jit(opt_def.create)(params_cpu) weight_decay_rules = config.get('weight_decay', []) or [] rescale_value = config.lr.base if config.get( 'weight_decay_decouple') else 1. weight_decay_fn = train_utils.get_weight_decay_fn( weight_decay_rules=weight_decay_rules, rescale_value=rescale_value) @partial(jax.pmap, axis_name='batch', donate_argnums=(0, )) def update_fn(opt, lr, images, labels, rng): """Update step.""" measurements = {} # Get device-specific loss rng. rng, rng_model = jax.random.split(rng, 2) rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index('batch')) def loss_fn(params, images, labels): logits, _ = model.apply({'params': flax.core.freeze(params)}, images, train=True, rngs={'dropout': rng_model_local}) label_indices = config.get('label_indices') if label_indices: logits = logits[:, label_indices] return getattr(train_utils, config.get('loss', 'sigmoid_xent'))(logits=logits, labels=labels) # Implementation considerations compared and summarized at # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en# l, g = train_utils.accumulate_gradient(jax.value_and_grad(loss_fn), opt.target, images, labels, config.get('grad_accum_steps')) l, g = jax.lax.pmean((l, g), axis_name='batch') # Log the gradient norm only if we need to compute it anyways (clipping) # or if we don't use grad_accum_steps, as they interact badly. if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'): grads, _ = jax.tree_flatten(g) l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads])) measurements['l2_grads'] = l2_g # Optionally resize the global gradient to a maximum norm. We found this # useful in some cases across optimizers, hence it's in the main loop. 
if config.get('grad_clip_norm'): g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g) g = jax.tree_util.tree_map(lambda p: g_factor * p, g) opt = opt.apply_gradient(g, learning_rate=lr) opt = opt.replace(target=weight_decay_fn(opt.target, lr)) params, _ = jax.tree_flatten(opt.target) measurements['l2_params'] = jnp.sqrt( sum([jnp.vdot(p, p) for p in params])) return opt, l, rng, measurements rng, train_loop_rngs = jax.random.split(rng) reint_params = ('head/kernel', 'head/bias') if config.get('only_eval', False) or not config.get('reint_head', True): reint_params = [] checkpoint_data = checkpoint_utils.maybe_load_checkpoint( train_loop_rngs=train_loop_rngs, save_checkpoint_path=save_checkpoint_path, init_optimizer=opt_cpu, init_params=params_cpu, init_fixed_model_states=None, default_reinit_params=reint_params, config=config, ) train_loop_rngs = checkpoint_data.train_loop_rngs opt_cpu = checkpoint_data.optimizer accumulated_train_time = checkpoint_data.accumulated_train_time write_note('Adapting the checkpoint model...') adapted_params = checkpoint_utils.adapt_upstream_architecture( init_params=params_cpu, loaded_params=opt_cpu.target) opt_cpu = opt_cpu.replace(target=adapted_params) write_note('Kicking off misc stuff...') first_step = int(opt_cpu.state.step) # Might be a DeviceArray type. if first_step == 0 and jax.process_index() == 0: writer.write_hparams(dict(config)) chrono = train_utils.Chrono(first_step, total_steps, batch_size, accumulated_train_time) # Note: switch to ProfileAllHosts() if you need to profile all hosts. # (Xprof data become much larger and take longer to load for analysis) profiler = periodic_actions.Profile( # Create profile after every restart to analyze pre-emption related # problems and assure we get similar performance in every run. logdir=output_dir, first_profile=first_step + 10) # Prepare the learning-rate and pre-fetch it to device to avoid delays. lr_fn = train_utils.create_learning_rate_schedule(total_steps, **config.get('lr', {})) # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be # necessary for TPUs. lr_iter = train_utils.prefetch_scalar(map(lr_fn, range(total_steps)), config.get('prefetch_to_device', 1)) write_note(f'Replicating...\n{chrono.note}') opt_repl = flax_utils.replicate(opt_cpu) write_note(f'Initializing few-shotters...\n{chrono.note}') fewshotter = None if 'fewshot' in config and fewshot is not None: fewshotter = fewshot.FewShotEvaluator( representation_fn, config.fewshot, config.fewshot.get('batch_size') or batch_size_eval) checkpoint_writer = None # Note: we return the train loss, val loss, and fewshot best l2s for use in # reproducibility unit tests. train_loss = -jnp.inf val_loss = {val_name: -jnp.inf for val_name, _ in val_iter_splits.items()} fewshot_results = {'dummy': {(0, 1): -jnp.inf}} write_note(f'First step compilations...\n{chrono.note}') logging.info('first_step = %s', first_step) # Advance the iterators if we are restarting from an earlier checkpoint. # TODO(dusenberrymw): Look into checkpointing dataset state instead. if first_step > 0: write_note('Advancing iterators after resuming from a checkpoint...') lr_iter = itertools.islice(lr_iter, first_step, None) train_iter = itertools.islice(train_iter, first_step, None) # NOTE: Validation eval is only run on certain steps, so determine how many # times it was run previously. 
num_val_runs = sum( map( lambda i: train_utils.itstime(i, config.log_eval_steps, total_steps), range(1, first_step + 1))) for val_name, (val_iter, val_steps) in val_iter_splits.items(): val_iter = itertools.islice(val_iter, num_val_runs * val_steps, None) val_iter_splits[val_name] = (val_iter, val_steps) # Using a python integer for step here, because opt.state.step is allocated # on TPU during replication. for step, train_batch, lr_repl in zip( range(first_step + 1, total_steps + 1), train_iter, lr_iter): with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1): if not config.get('only_eval', False): opt_repl, loss_value, train_loop_rngs, extra_measurements = update_fn( opt_repl, lr_repl, train_batch['image'], train_batch['labels'], rng=train_loop_rngs) if jax.process_index() == 0: profiler(step) # Checkpoint saving if not config.get('only_eval', False) and train_utils.itstime( step, config.get('checkpoint_steps'), total_steps, process=0): write_note('Checkpointing...') chrono.pause() train_utils.checkpointing_timeout( checkpoint_writer, config.get('checkpoint_timeout', 1)) accumulated_train_time = chrono.accum_train_time # We need to transfer the weights over now or else we risk keeping them # alive while they'll be updated in a future step, creating hard to debug # memory errors (see b/160593526). Also, takes device 0's params only. opt_cpu = jax.tree_util.tree_map(lambda x: np.array(x[0]), opt_repl) # Check whether we want to keep a copy of the current checkpoint. copy_step = None if train_utils.itstime(step, config.get('keep_checkpoint_steps'), total_steps): write_note('Keeping a checkpoint copy...') copy_step = step # Checkpoint should be a nested dictionary or FLAX datataclasses from # `flax.struct`. Both can be present in a checkpoint. checkpoint_data = checkpoint_utils.CheckpointData( train_loop_rngs=train_loop_rngs, optimizer=opt_cpu, accumulated_train_time=accumulated_train_time) checkpoint_writer = pool.apply_async( checkpoint_utils.checkpoint_trained_model, (checkpoint_data, save_checkpoint_path, copy_step)) chrono.resume() # Report training progress if not config.get('only_eval', False) and train_utils.itstime( step, config.log_training_steps, total_steps, process=0): write_note('Reporting training progress...') train_loss = loss_value[ 0] # Keep to return for reproducibility tests. timing_measurements, note = chrono.tick(step) write_note(note) train_measurements = {} train_measurements.update({ 'learning_rate': lr_repl[0], 'training_loss': train_loss, }) train_measurements.update( flax.jax_utils.unreplicate(extra_measurements)) train_measurements.update(timing_measurements) writer.write_scalars(step, train_measurements) # Report validation performance if train_utils.itstime(step, config.log_eval_steps, total_steps): write_note('Evaluating on the validation set...') chrono.pause() for val_name, (val_iter, val_steps) in val_iter_splits.items(): # Sets up evaluation metrics. 
ece_num_bins = config.get('ece_num_bins', 15) auc_num_bins = config.get('auc_num_bins', 1000) ece = rm.metrics.ExpectedCalibrationError( num_bins=ece_num_bins) calib_auc = rm.metrics.CalibrationAUC( correct_pred_as_pos_label=False) oc_auc_0_5 = rm.metrics.OracleCollaborativeAUC( oracle_fraction=0.005, num_bins=auc_num_bins) oc_auc_1 = rm.metrics.OracleCollaborativeAUC( oracle_fraction=0.01, num_bins=auc_num_bins) oc_auc_2 = rm.metrics.OracleCollaborativeAUC( oracle_fraction=0.02, num_bins=auc_num_bins) oc_auc_5 = rm.metrics.OracleCollaborativeAUC( oracle_fraction=0.05, num_bins=auc_num_bins) label_diversity = tf.keras.metrics.Mean() sample_diversity = tf.keras.metrics.Mean() ged = tf.keras.metrics.Mean() # Runs evaluation loop. ncorrect, loss, nseen = 0, 0, 0 for _, batch in zip(range(val_steps), val_iter): if val_name == 'cifar_10h': batch_ncorrect, batch_losses, batch_n, batch_metric_args = ( cifar_10h_evaluation_fn(opt_repl.target, batch['image'], batch['labels'], batch['mask'])) else: batch_ncorrect, batch_losses, batch_n, batch_metric_args = ( evaluation_fn(opt_repl.target, batch['image'], batch['labels'], batch['mask'])) # All results are a replicated array shaped as follows: # (local_devices, per_device_batch_size, elem_shape...) # with each local device's entry being identical as they got psum'd. # So let's just take the first one to the host as numpy. ncorrect += np.sum(np.array(batch_ncorrect[0])) loss += np.sum(np.array(batch_losses[0])) nseen += np.sum(np.array(batch_n[0])) if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent': # Here we parse batch_metric_args to compute uncertainty metrics. # (e.g., ECE or Calibration AUC). logits, labels, _, masks = batch_metric_args masks = np.array(masks[0], dtype=np.bool) logits = np.array(logits[0]) probs = jax.nn.softmax(logits) # From one-hot to integer labels, as required by ECE. int_labels = np.argmax(np.array(labels[0]), axis=-1) int_preds = np.argmax(logits, axis=-1) confidence = np.max(probs, axis=-1) for p, c, l, d, m, label in zip( probs, confidence, int_labels, int_preds, masks, labels[0]): ece.add_batch(p[m, :], label=l[m]) calib_auc.add_batch(d[m], label=l[m], confidence=c[m]) # TODO(jereliu): Extend to support soft multi-class probabilities. oc_auc_0_5.add_batch(d[m], label=l[m], custom_binning_score=c[m]) oc_auc_1.add_batch(d[m], label=l[m], custom_binning_score=c[m]) oc_auc_2.add_batch(d[m], label=l[m], custom_binning_score=c[m]) oc_auc_5.add_batch(d[m], label=l[m], custom_binning_score=c[m]) if val_name == 'cifar_10h' or val_name == 'imagenet_real': batch_label_diversity, batch_sample_diversity, batch_ged = data_uncertainty_utils.generalized_energy_distance( label[m], p[m, :], config.num_classes) label_diversity.update_state( batch_label_diversity) sample_diversity.update_state( batch_sample_diversity) ged.update_state(batch_ged) val_loss[ val_name] = loss / nseen # Keep for reproducibility tests. 
val_measurements = { f'{val_name}_prec@1': ncorrect / nseen, f'{val_name}_loss': val_loss[val_name], } if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent': val_measurements[f'{val_name}_ece'] = ece.result()['ece'] val_measurements[ f'{val_name}_calib_auc'] = calib_auc.result( )['calibration_auc'] val_measurements[ f'{val_name}_oc_auc_0.5%'] = oc_auc_0_5.result( )['collaborative_auc'] val_measurements[ f'{val_name}_oc_auc_1%'] = oc_auc_1.result( )['collaborative_auc'] val_measurements[ f'{val_name}_oc_auc_2%'] = oc_auc_2.result( )['collaborative_auc'] val_measurements[ f'{val_name}_oc_auc_5%'] = oc_auc_5.result( )['collaborative_auc'] writer.write_scalars(step, val_measurements) if val_name == 'cifar_10h' or val_name == 'imagenet_real': cifar_10h_measurements = { f'{val_name}_label_diversity': label_diversity.result(), f'{val_name}_sample_diversity': sample_diversity.result(), f'{val_name}_ged': ged.result(), } writer.write_scalars(step, cifar_10h_measurements) # OOD eval # Entries in the ood_ds dict include: # (ind_dataset, ood_dataset1, ood_dataset2, ...). # OOD metrics are computed using ind_dataset paired with each of the # ood_dataset. When Mahalanobis distance method is applied, train_ind_ds # is also included in the ood_ds. if ood_ds and config.ood_methods: ood_measurements = ood_utils.eval_ood_metrics( ood_ds, ood_ds_names, config.ood_methods, evaluation_fn, opt_repl) writer.write_scalars(step, ood_measurements) chrono.resume() if 'fewshot' in config and fewshotter is not None: # Compute few-shot on-the-fly evaluation. if train_utils.itstime(step, config.fewshot.log_steps, total_steps): chrono.pause() write_note(f'Few-shot evaluation...\n{chrono.note}') # Keep `results` to return for reproducibility tests. fewshot_results, best_l2 = fewshotter.run_all( opt_repl.target, config.fewshot.datasets) # TODO(dusenberrymw): Remove this once fewshot.py is updated. def make_writer_measure_fn(step): def writer_measure(name, value): writer.write_scalars(step, {name: value}) return writer_measure fewshotter.walk_results(make_writer_measure_fn(step), fewshot_results, best_l2) chrono.resume() # End of step. if config.get('testing_failure_step'): # Break early to simulate infra failures in test cases. if config.testing_failure_step == step: break write_note(f'Done!\n{chrono.note}') pool.close() pool.join() writer.close() # Return final training loss, validation loss, and fewshot results for # reproducibility test cases. return train_loss, val_loss, fewshot_results
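# A sketch of how a `main(config, output_dir)` entry point like the one above is
# typically launched with `ml_collections` config flags. The flag names and the
# `_run` wrapper are assumptions for illustration; the actual flag definitions
# live elsewhere in this script.
from absl import app
from absl import flags
from ml_collections import config_flags

config_flags.DEFINE_config_file('config', None, 'Experiment configuration file.')
flags.DEFINE_string('output_dir', '/tmp/vit_output', 'Checkpoint/metric directory.')
FLAGS = flags.FLAGS


def _run(argv):
  del argv  # Unused.
  main(FLAGS.config, FLAGS.output_dir)


if __name__ == '__main__':
  app.run(_run)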
def main(argv):
  del argv  # unused arg

  config = FLAGS.config

  # Unpack total and warmup steps
  # TODO(nband): revert this to separate arguments.
  total_steps = config.total_and_warmup_steps[0]
  warmup_steps = config.total_and_warmup_steps[1]
  del config.total_and_warmup_steps
  config.total_steps = total_steps
  config.lr.warmup_steps = warmup_steps

  # Wandb and Checkpointing Setup
  output_dir = FLAGS.output_dir
  wandb_run, output_dir = vit_utils.maybe_setup_wandb(config)
  tf.io.gfile.makedirs(output_dir)
  logging.info('Saving checkpoints at %s', output_dir)

  # Dataset Split Flags
  dist_shift = config.distribution_shift
  print(f'Distribution Shift: {dist_shift}.')
  dataset_names, split_names = vit_utils.get_dataset_and_split_names(dist_shift)

  # LR / Optimization Flags
  batch_size = config.batch_size
  grad_clip_norm = config.grad_clip_norm
  weight_decay = config.weight_decay
  print('Standard wandb hyperparameters:')
  print({
      'batch_size': batch_size,
      'grad_clip_norm': grad_clip_norm,
      'weight_decay': weight_decay,
      'total_steps': config.total_steps,
      'lr': config.lr
  })
  print('SNGP Params:', config.gp_layer)

  # Reweighting loss for class imbalance
  # class_reweight_mode = config.class_reweight_mode
  # if class_reweight_mode == 'constant':
  #   class_weights = utils.get_diabetic_retinopathy_class_balance_weights()
  # else:
  #   class_weights = None

  # Shows the number of available devices.
  # In a CPU/GPU runtime this will be a single device.
  # In a TPU runtime this will be 8 cores.
  print('Number of Jax local devices:', jax.local_devices())

  # TODO(nband): fix sigmoid loss issues.
  assert config.get('loss', None) == 'softmax_xent'

  seed = config.seed
  rng = jax.random.PRNGKey(seed)
  tf.random.set_seed(seed)

  if config.get('data_dir'):
    logging.info('data_dir=%s', config.data_dir)
  logging.info('Output dir: %s', output_dir)

  save_checkpoint_path = None
  if config.get('checkpoint_steps'):
    tf.io.gfile.makedirs(output_dir)
    save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

  # Create an asynchronous multi-metric writer.
  writer = metric_writers.create_default_writer(
      output_dir, just_logging=jax.process_index() > 0)

  # The pool is used to perform misc operations such as logging in an async way.
  pool = multiprocessing.pool.ThreadPool()

  def write_note(note):
    if jax.process_index() == 0:
      logging.info('NOTE: %s', note)

  write_note('Initializing...')

  # Verify settings to make sure no checkpoints are accidentally missed.
  if config.get('keep_checkpoint_steps'):
    assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
    assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
        f'`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should be '
        f'divisible by `checkpoint_steps` ({config.checkpoint_steps}).')

  batch_size_eval = config.get('batch_size_eval', batch_size)
  if (batch_size % jax.device_count() != 0 or
      batch_size_eval % jax.device_count() != 0):
    raise ValueError(f'Batch sizes ({batch_size} and {batch_size_eval}) must '
                     f'be divisible by device number ({jax.device_count()})')

  local_batch_size = batch_size // jax.process_count()
  local_batch_size_eval = batch_size_eval // jax.process_count()
  logging.info(
      'Global batch size %d on %d hosts results in %d local batch size. 
' 'With %d dev per host (%d dev total), that is a %d per-device batch size.', batch_size, jax.process_count(), local_batch_size, jax.local_device_count(), jax.device_count(), local_batch_size // jax.local_device_count()) write_note('Initializing preprocessing function...') # Same preprocessing function for training and evaluation preproc_fn = preprocess_spec.parse( spec=config.pp_train, available_ops=preprocess_utils.all_ops()) write_note('Initializing train dataset...') rng, train_ds_rng = jax.random.split(rng) train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index()) train_base_dataset = ub.datasets.get( dataset_names['in_domain_dataset'], split=split_names['train_split'], data_dir=config.get('data_dir')) train_dataset_builder = train_base_dataset._dataset_builder # pylint: disable=protected-access train_ds = input_utils.get_data( dataset=train_dataset_builder, split=split_names['train_split'], rng=train_ds_rng, process_batch_size=local_batch_size, preprocess_fn=preproc_fn, shuffle_buffer_size=config.shuffle_buffer_size, prefetch_size=config.get('prefetch_to_host', 2), data_dir=config.get('data_dir')) logging.info('image_size = %s', train_ds.element_spec['image'].shape[1:]) # Start prefetching already. train_iter = input_utils.start_input_pipeline( train_ds, config.get('prefetch_to_device', 1)) write_note('Initializing val dataset(s)...') # Load in-domain and OOD validation and/or test datasets. # Please specify the desired shift (Country Shift or Severity Shift) # in the config. eval_iter_splits = vit_utils.init_evaluation_datasets( use_validation=config.use_validation, use_test=config.use_test, dataset_names=dataset_names, split_names=split_names, config=config, preproc_fn=preproc_fn, batch_size_eval=batch_size_eval, local_batch_size_eval=local_batch_size_eval) ntrain_img = input_utils.get_num_examples( train_dataset_builder, split=split_names['train_split'], process_batch_size=local_batch_size, data_dir=config.get('data_dir')) steps_per_epoch = ntrain_img / batch_size if config.get('num_epochs'): total_steps = int(config.num_epochs * steps_per_epoch) assert not config.get('total_steps'), 'Set either num_epochs or total_steps' else: total_steps = config.total_steps logging.info('Total train data points: %d', ntrain_img) logging.info( 'Running for %d steps, that means %f epochs and %d steps per epoch', total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch) write_note('Initializing model...') logging.info('config.model = %s', config.get('model')) # Specify Gaussian process layer configs. gp_config = config.get('gp_layer', {}) model_dict = vit_utils.initialize_model('sngp', config) model, use_gp_layer = model_dict['model'], model_dict['use_gp_layer'] # We want all parameters to be created in host RAM, not on any device, they'll # be sent there later as needed, otherwise we already encountered two # situations where we allocate them twice. @functools.partial(jax.jit, backend='cpu') def init(rng): image_size = tuple(train_ds.element_spec['image'].shape[2:]) logging.info('image_size = %s', image_size) dummy_input = jnp.zeros((local_batch_size,) + image_size, jnp.float32) variables = model.init(rng, dummy_input, train=False) # Split model parameters into trainable and untrainable collections. states, params = variables.pop('params') del variables # Set bias in the head to a low value, such that loss is small initially. params = flax.core.unfreeze(params) if use_gp_layer: # Modify the head parameter in the GP head. 
params['head']['output_layer']['bias'] = jnp.full_like( params['head']['output_layer']['bias'], config.get('init_head_bias', 0)) else: params['head']['bias'] = jnp.full_like( params['head']['bias'], config.get('init_head_bias', 0)) return params, states rng, rng_init = jax.random.split(rng) params_cpu, states_cpu = init(rng_init) if jax.process_index() == 0: num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0]) parameter_overview.log_parameter_overview(params_cpu) writer.write_scalars(step=0, scalars={'num_params': num_params}) @functools.partial(jax.pmap, axis_name='batch') def evaluation_fn(params, states, images, labels): variable_dict = {'params': flax.core.freeze(params), **states} logits, out = model.apply( variable_dict, images, train=False, mean_field_factor=gp_config.get('mean_field_factor', -1.)) losses = getattr(train_utils, config.get('loss', 'softmax_xent'))( logits=logits, labels=labels, reduction=False) loss = jax.lax.psum(losses, axis_name='batch') top1_idx = jnp.argmax(logits, axis=1) # Extracts the label at the highest logit index for each image. top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0] ncorrect = jax.lax.psum(top1_correct, axis_name='batch') n = batch_size_eval metric_args = jax.lax.all_gather([ logits, labels, out['pre_logits']], axis_name='batch') return ncorrect, loss, n, metric_args # Load the optimizer from flax. opt_name = config.get('optim_name') write_note(f'Initializing {opt_name} optimizer...') opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {})) # We jit this, such that the arrays that are created are created on the same # device as the input is, in this case the CPU. Else they'd be on device[0]. opt_cpu = jax.jit(opt_def.create)(params_cpu) weight_decay_rules = config.get('weight_decay', []) or [] rescale_value = config.lr.base if config.get('weight_decay_decouple') else 1. weight_decay_fn = train_utils.get_weight_decay_fn( weight_decay_rules=weight_decay_rules, rescale_value=rescale_value) @functools.partial(jax.pmap, axis_name='batch', donate_argnums=(0,)) def update_fn(opt, states, lr, reset_covmat, images, labels, rng): """Update step.""" measurements = {} # Get device-specific loss rng. rng, rng_model = jax.random.split(rng, 2) rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index('batch')) def loss_fn(params, states, images, labels): # Specify mutable collection to update untrainable GP parameters. variable_dict = {'params': flax.core.freeze(params), **states} model_results, updated_states = model.apply( variable_dict, images, train=True, rngs={'dropout': rng_model_local}, mutable=list(states.keys()), mean_field_factor=gp_config.get('mean_field_factor', -1.)) logits, _ = model_results loss = getattr(train_utils, config.get('loss', 'sigmoid_xent'))( logits=logits, labels=labels) return loss, updated_states # Performs exact covariance update (i.e., reset precision matrix resetting # at begining of new epoch) if covmat_momentum is a null value. if use_gp_layer and gp_config.get('covmat_momentum', -1.) < 0: # Resets precision matrix to Identity * ridge_penalty if at the begining # of a new epoch. This should be done before accumulate gradient. ridge_penalty = gp_config.get('ridge_penalty', 1.) prec_mat_old = states['laplace_covariance']['head']['covmat_layer'][ 'precision_matrix'] prec_mat_new = ( (1. 
- reset_covmat) * prec_mat_old + reset_covmat * jnp.eye(prec_mat_old.shape[0]) * ridge_penalty) states = flax.core.unfreeze(states) states['laplace_covariance']['head']['covmat_layer'][ 'precision_matrix'] = prec_mat_new states = flax.core.freeze(states) # Implementation considerations compared and summarized at # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en# (l, s), g = vit_utils.accumulate_gradient_with_states( jax.value_and_grad(loss_fn, has_aux=True), opt.target, states, images, labels, config.get('grad_accum_steps')) l, g = jax.lax.pmean((l, g), axis_name='batch') # Log the gradient norm only if we need to compute it anyways (clipping) # or if we don't use grad_accum_steps, as they interact badly. if config.get('grad_accum_steps', 1) == 1 or grad_clip_norm is not None: grads, _ = jax.tree_flatten(g) l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads])) measurements['l2_grads'] = l2_g # Optionally resize the global gradient to a maximum norm. We found this # useful in some cases across optimizers, hence it's in the main loop. if grad_clip_norm is not None: g_factor = jnp.minimum(1.0, grad_clip_norm / l2_g) g = jax.tree_map(lambda p: g_factor * p, g) opt = opt.apply_gradient(g, learning_rate=lr) opt = opt.replace(target=weight_decay_fn(opt.target, lr)) params, _ = jax.tree_flatten(opt.target) measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params])) measurements['reset_covmat'] = reset_covmat return opt, s, l, rng, measurements # Set config checkpoint resume path, if provided in args. if config.resume_checkpoint_path is not None: config.resume = config.resume_checkpoint_path default_reinit_params = ('head/output_layer/kernel', 'head/output_layer/bias', 'head/kernel', 'head/bias') rng, train_loop_rngs = jax.random.split(rng) checkpoint_data = checkpoint_utils.maybe_load_checkpoint( train_loop_rngs=train_loop_rngs, save_checkpoint_path=save_checkpoint_path, init_optimizer=opt_cpu, init_params=params_cpu, init_fixed_model_states=states_cpu, default_reinit_params=default_reinit_params, config=config) train_loop_rngs = checkpoint_data.train_loop_rngs opt_cpu = checkpoint_data.optimizer states_cpu = checkpoint_data.fixed_model_states accumulated_train_time = checkpoint_data.accumulated_train_time write_note('Adapting the checkpoint model...') adapted_params = checkpoint_utils.adapt_upstream_architecture( init_params=params_cpu, loaded_params=opt_cpu.target) opt_cpu = opt_cpu.replace(target=adapted_params) write_note('Kicking off misc stuff...') first_step = int(opt_cpu.state.step) # Might be a DeviceArray type. if first_step == 0 and jax.process_index() == 0: writer.write_hparams(dict(config)) chrono = train_utils.Chrono(first_step, total_steps, batch_size, accumulated_train_time) # Note: switch to ProfileAllHosts() if you need to profile all hosts. # (Xprof data become much larger and take longer to load for analysis) profiler = periodic_actions.Profile( # Create profile after every restart to analyze pre-emption related # problems and assure we get similar performance in every run. logdir=output_dir, first_profile=first_step + 10) # Prepare the learning-rate and pre-fetch it to device to avoid delays. lr_fn = train_utils.create_learning_rate_schedule(total_steps, **config.get('lr', {})) # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be # necessary for TPUs. 
lr_iter = train_utils.prefetch_scalar( map(lr_fn, range(total_steps)), config.get('prefetch_to_device', 1)) # Prepare the precision matrix resetting schedule, and pre-fetch it to device. reset_covmat_fn = lambda step: float(step % steps_per_epoch == 0) reset_covmat_iter = train_utils.prefetch_scalar( map(reset_covmat_fn, range(first_step, total_steps)), nprefetch=config.get('prefetch_to_device', 1)) write_note(f'Replicating...\n{chrono.note}') opt_repl = flax.jax_utils.replicate(opt_cpu) states_repl = flax.jax_utils.replicate(states_cpu) checkpoint_writer = None # Note: we return the train loss, val loss, and fewshot best l2s for use in # reproducibility unit tests. # train_loss = -jnp.inf # val_loss = -jnp.inf # results = {'dummy': {(0, 1): -jnp.inf}} write_note(f'First step compilations...\n{chrono.note}') logging.info('first_step = %s', first_step) # Advance the iterators if we are restarting from an earlier checkpoint. # TODO(dusenberrymw): Look into checkpointing dataset state instead. # Makes sure log_eval_steps is same as steps_per_epoch. This is because # the precision matrix needs to be updated fully (at the end of each epoch) # when eval takes place. log_eval_steps = steps_per_epoch if first_step > 0: write_note('Advancing iterators after resuming from a checkpoint...') lr_iter = itertools.islice(lr_iter, first_step, None) train_iter = itertools.islice(train_iter, first_step, None) # Using a python integer for step here, because opt.state.step is allocated # on TPU during replication. for step, train_batch, lr_repl, reset_covmat_repl in zip( range(first_step + 1, total_steps + 1), train_iter, lr_iter, reset_covmat_iter): with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1): # TODO(jereliu): Expand to allow precision matrix resetting. (opt_repl, states_repl, loss_value, train_loop_rngs, extra_measurements) = update_fn( opt_repl, states_repl, lr_repl, reset_covmat_repl, train_batch['image'], train_batch['labels'], rng=train_loop_rngs) if jax.process_index() == 0: profiler(step) # Checkpoint saving if train_utils.itstime( step, config.get('checkpoint_steps'), total_steps, process=0): write_note('Checkpointing...') chrono.pause() train_utils.checkpointing_timeout(checkpoint_writer, config.get('checkpoint_timeout', 1)) accumulated_train_time = chrono.accum_train_time # We need to transfer the weights over now or else we risk keeping them # alive while they'll be updated in a future step, creating hard to debug # memory errors (see b/160593526). Also, takes device 0's params only. # For GP layer, we will also do the same for untrainable parameters # (`states`). This is ok since `random features` are frozen throughout # pre-training, and `precision matrix` is a finetuning-specific parameters # that will be re-learned in the finetuning task. opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl) states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl) # Check whether we want to keep a copy of the current checkpoint. copy_step = None if train_utils.itstime(step, config.get('keep_checkpoint_steps'), total_steps): write_note('Keeping a checkpoint copy...') copy_step = step # Checkpoint should be a nested dictionary or FLAX datataclasses from # `flax.struct`. Both can be present in a checkpoint. 
checkpoint_data = checkpoint_utils.CheckpointData( optimizer=opt_cpu, fixed_model_states=states_cpu, train_loop_rngs=train_loop_rngs, accumulated_train_time=accumulated_train_time) checkpoint_writer = pool.apply_async( checkpoint_utils.checkpoint_trained_model, (checkpoint_data, save_checkpoint_path, copy_step)) chrono.resume() # Report training progress if train_utils.itstime( step, config.log_training_steps, total_steps, process=0): write_note('Reporting training progress...') train_loss = loss_value[0] # Keep to return for reproducibility tests. timing_measurements, note = chrono.tick(step) write_note(note) train_measurements = {} train_measurements.update({ 'learning_rate': lr_repl[0], 'training_loss': train_loss, }) train_measurements.update(flax.jax_utils.unreplicate(extra_measurements)) train_measurements.update(timing_measurements) writer.write_scalars(step, train_measurements) # Report validation performance if train_utils.itstime(step, log_eval_steps, total_steps): write_note('Evaluating on the validation set...') chrono.pause() all_eval_results = {} for eval_name, (eval_iter, eval_steps) in eval_iter_splits.items(): start_time = time.time() # Runs evaluation loop. results_arrs = { 'y_true': [], 'y_pred': [], 'y_pred_entropy': [] } for _, batch in zip(range(eval_steps), eval_iter): batch_ncorrect, batch_losses, batch_n, batch_metric_args = ( # pylint: disable=unused-variable evaluation_fn( opt_repl.target, states_repl, batch['image'], batch['labels'])) # All results are a replicated array shaped as follows: # (local_devices, per_device_batch_size, elem_shape...) # with each local device's entry being identical as they got psum'd. # So let's just take the first one to the host as numpy. # Here we parse batch_metric_args to compute uncertainty metrics. logits, labels, _ = batch_metric_args logits = np.array(logits[0]) probs = jax.nn.softmax(logits) # From one-hot to integer labels. int_labels = np.argmax(np.array(labels[0]), axis=-1) probs = np.reshape(probs, (probs.shape[0] * probs.shape[1], -1)) int_labels = int_labels.flatten() y_pred = probs[:, 1] results_arrs['y_true'].append(int_labels) results_arrs['y_pred'].append(y_pred) # Entropy is computed at the per-epoch level (see below). results_arrs['y_pred_entropy'].append(probs) results_arrs['y_true'] = np.concatenate(results_arrs['y_true'], axis=0) results_arrs['y_pred'] = np.concatenate( results_arrs['y_pred'], axis=0).astype('float64') results_arrs['y_pred_entropy'] = vit_utils.entropy( np.concatenate(results_arrs['y_pred_entropy'], axis=0), axis=-1) time_elapsed = time.time() - start_time results_arrs['total_ms_elapsed'] = time_elapsed * 1e3 results_arrs['dataset_size'] = eval_steps * batch_size_eval all_eval_results[eval_name] = results_arrs per_pred_results, metrics_results = vit_utils.evaluate_vit_predictions( # pylint: disable=unused-variable dataset_split_to_containers=all_eval_results, is_deterministic=True, num_bins=15, return_per_pred_results=True ) # `metrics_results` is a dict of {str: jnp.ndarray} dicts, one for each # dataset. Flatten this dict so we can pass to the writer and remove empty # entries. 
flattened_metric_results = {} for dic in metrics_results.values(): for key, value in dic.items(): if value is not None: flattened_metric_results[key] = value writer.write_scalars(step, flattened_metric_results) # Optionally log to wandb if config.use_wandb: wandb.log(metrics_results, step=step) # Save per-prediction metrics results_storage_utils.save_per_prediction_results( output_dir, step, per_pred_results, verbose=False) chrono.resume() # End of step. if config.get('testing_failure_step'): # Break early to simulate infra failures in test cases. if config.testing_failure_step == step: break write_note(f'Done!\n{chrono.note}') pool.close() pool.join() writer.close() if wandb_run is not None: wandb_run.finish()
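# Small numeric illustration of the precision-matrix reset used in the SNGP
# `update_fn` above: `reset_covmat` is 1.0 on the first step of an epoch (see
# `reset_covmat_fn`) and 0.0 otherwise, so the interpolation below either resets
# the precision matrix to `ridge_penalty * I` or leaves it untouched. The matrix
# values are made up for the example.
import jax.numpy as jnp

ridge_penalty = 1.
prec_mat_old = jnp.array([[2.0, 0.3], [0.3, 4.0]])
for reset_covmat in (0.0, 1.0):
  prec_mat_new = ((1. - reset_covmat) * prec_mat_old +
                  reset_covmat * jnp.eye(prec_mat_old.shape[0]) * ridge_penalty)
  print(reset_covmat, prec_mat_new)  # 0.0 -> unchanged, 1.0 -> identity * ridge.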
def main(_):
  config = FLAGS.config

  # Unpack total and warmup steps
  total_steps = config.total_and_warmup_steps[0]
  warmup_steps = config.total_and_warmup_steps[1]
  del config.total_and_warmup_steps
  config.total_steps = total_steps
  config.lr.warmup_steps = warmup_steps

  # Wandb and Checkpointing Setup
  output_dir = FLAGS.output_dir
  wandb_run, output_dir = vit_utils.maybe_setup_wandb(config)
  tf.io.gfile.makedirs(output_dir)
  logging.info('Saving checkpoints at %s', output_dir)

  # Dataset Split Flags
  dist_shift = config.distribution_shift
  print(f'Distribution Shift: {dist_shift}.')
  dataset_names, split_names = vit_utils.get_dataset_and_split_names(
      dist_shift)

  # LR / Optimization Flags
  print('wandb hyperparameters:')
  print({
      'batch_size': config.batch_size,
      'grad_clip_norm': config.grad_clip_norm,
      'weight_decay': config.weight_decay,
      'total_steps': config.total_steps,
      'lr': config.lr,
      'fast_weight_lr_multiplier': config.fast_weight_lr_multiplier
  })

  # Reweighting loss for class imbalance
  # class_reweight_mode = config.class_reweight_mode
  # if class_reweight_mode == 'constant':
  #   class_weights = utils.get_diabetic_retinopathy_class_balance_weights()
  # else:
  #   class_weights = None

  # Shows the number of available devices.
  # In a CPU/GPU runtime this will be a single device.
  # In a TPU runtime this will be 8 cores.
  print('Number of Jax local devices:', jax.local_devices())

  # TODO(nband): fix sigmoid loss issues.
  assert config.get('loss', None) == 'softmax_xent'

  seed = config.get('seed', 0)
  rng = jax.random.PRNGKey(seed)
  tf.random.set_seed(seed)

  if config.get('data_dir'):
    logging.info('data_dir=%s', config.data_dir)
  logging.info('Output dir: %s', output_dir)
  tf.io.gfile.makedirs(output_dir)

  save_checkpoint_path = None
  if config.get('checkpoint_steps'):
    save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

  # Create an asynchronous multi-metric writer.
  writer = metric_writers.create_default_writer(
      output_dir, just_logging=jax.process_index() > 0)

  # The pool is used to perform misc operations such as logging in an async way.
  pool = multiprocessing.pool.ThreadPool()

  def write_note(note):
    if jax.process_index() == 0:
      logging.info('NOTE: %s', note)

  write_note('Initializing...')

  # Verify settings to make sure no checkpoints are accidentally missed.
  if config.get('keep_checkpoint_steps'):
    assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
    assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
        f'`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should be '
        f'divisible by `checkpoint_steps` ({config.checkpoint_steps}).')

  batch_size = config.batch_size
  batch_size_eval = config.get('batch_size_eval', batch_size)
  if (batch_size % jax.device_count() != 0 or
      batch_size_eval % jax.device_count() != 0):
    raise ValueError(f'Batch sizes ({batch_size} and {batch_size_eval}) must '
                     f'be divisible by device number ({jax.device_count()})')

  local_batch_size = batch_size // jax.process_count()
  local_batch_size_eval = batch_size_eval // jax.process_count()
  logging.info(
      'Global batch size %d on %d hosts results in %d local batch size. 
' 'With %d devices per host (%d devices total), that\'s a %d per-device ' 'batch size.', batch_size, jax.process_count(), local_batch_size, jax.local_device_count(), jax.device_count(), local_batch_size // jax.local_device_count()) write_note('Initializing preprocessing function...') # Same preprocessing function for training and evaluation preproc_fn = preprocess_spec.parse( spec=config.pp_train, available_ops=preprocess_utils.all_ops()) write_note('Initializing train dataset...') rng, train_ds_rng = jax.random.split(rng) train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index()) train_base_dataset = ub.datasets.get(dataset_names['in_domain_dataset'], split=split_names['train_split'], data_dir=config.get('data_dir')) train_dataset_builder = train_base_dataset._dataset_builder # pylint:disable=protected-access train_ds = input_utils.get_data( dataset=train_dataset_builder, split=split_names['train_split'], rng=train_ds_rng, process_batch_size=local_batch_size, preprocess_fn=preproc_fn, shuffle_buffer_size=config.shuffle_buffer_size, prefetch_size=config.get('prefetch_to_host', 2), data_dir=config.get('data_dir')) # Start prefetching already. train_iter = input_utils.start_input_pipeline( train_ds, config.get('prefetch_to_device', 1)) write_note('Initializing val dataset(s)...') # Load in-domain and OOD validation and/or test datasets. # Please specify the desired shift (Country Shift or Severity Shift) # in the config. eval_iter_splits = vit_utils.init_evaluation_datasets( use_validation=config.use_validation, use_test=config.use_test, dataset_names=dataset_names, split_names=split_names, config=config, preproc_fn=preproc_fn, batch_size_eval=batch_size_eval, local_batch_size_eval=local_batch_size_eval) ntrain_img = input_utils.get_num_examples( train_dataset_builder, split=split_names['train_split'], process_batch_size=local_batch_size, data_dir=config.get('data_dir')) steps_per_epoch = ntrain_img // batch_size if config.get('num_epochs'): total_steps = int(config.num_epochs * steps_per_epoch) assert not config.get( 'total_steps'), 'Set either num_epochs or total_steps' else: total_steps = config.total_steps logging.info('Total train data points: %d', ntrain_img) logging.info( 'Running for %d steps, that means %f epochs and %d steps per epoch', total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch) write_note('Initializing model...') model_dict = vit_utils.initialize_model('batchensemble', config) model = model_dict['model'] ens_size = model_dict['ens_size'] # We want all parameters to be created in host RAM, not on any device, they'll # be sent there later as needed, otherwise we already encountered two # situations where we allocate them twice. @functools.partial(jax.jit, backend='cpu') def init(rng): image_size = tuple(train_ds.element_spec['image'].shape[2:]) logging.info('image_size = %s', image_size) dummy_input = jnp.zeros((local_batch_size, ) + image_size, jnp.float32) params = flax.core.unfreeze(model.init(rng, dummy_input, train=False))['params'] # Set bias in the head to a low value, such that loss is small initially. 
params['batchensemble_head']['bias'] = jnp.full_like( params['batchensemble_head']['bias'], config.get('init_head_bias', 0)) # init head kernel to all zeros for fine-tuning if config.get('model_init'): params['batchensemble_head']['kernel'] = jnp.full_like( params['batchensemble_head']['kernel'], 0) return params rng, rng_init = jax.random.split(rng) params_cpu = init(rng_init) if jax.process_index() == 0: num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0]) parameter_overview.log_parameter_overview(params_cpu) writer.write_scalars(step=0, scalars={'num_params': num_params}) @functools.partial(jax.pmap, axis_name='batch') def evaluation_fn(params, images, labels): tiled_logits, out = model.apply({'params': flax.core.freeze(params)}, images, train=False) loss_name = config.get('loss', 'sigmoid_xent') # TODO(dusenberrymw,zmariet): Clean up and generalize this. if loss_name == 'sigmoid_xent': ens_logits = batchensemble_utils.log_average_sigmoid_probs( jnp.asarray(jnp.split(tiled_logits, ens_size))) pre_logits = batchensemble_utils.log_average_sigmoid_probs( jnp.asarray(jnp.split(out['pre_logits'], ens_size))) else: # softmax ens_logits = batchensemble_utils.log_average_softmax_probs( jnp.asarray(jnp.split(tiled_logits, ens_size))) pre_logits = batchensemble_utils.log_average_softmax_probs( jnp.asarray(jnp.split(out['pre_logits'], ens_size))) losses = getattr(train_utils, loss_name)(logits=ens_logits, labels=labels[:, :config.num_classes], reduction=False) loss = jax.lax.psum(losses, axis_name='batch') top1_idx = jnp.argmax(ens_logits, axis=1) top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0] ncorrect = jax.lax.psum(top1_correct, axis_name='batch') n = batch_size_eval metric_args = jax.lax.all_gather([ens_logits, labels, pre_logits], axis_name='batch') return ncorrect, loss, n, metric_args # Load the optimizer from flax. opt_name = config.get('optim_name') write_note(f'Initializing {opt_name} optimizer...') opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {})) # We jit this, such that the arrays that are created are created on the same # device as the input is, in this case the CPU. Else they'd be on device[0]. opt_cpu = jax.jit(opt_def.create)(params_cpu) weight_decay_rules = config.get('weight_decay', []) or [] rescale_value = config.lr.base if config.get( 'weight_decay_decouple') else 1. weight_decay_fn = train_utils.get_weight_decay_fn( weight_decay_rules=weight_decay_rules, rescale_value=rescale_value) def batch_loss_fn(params, images, labels, rngs): logits, _ = model.apply({'params': flax.core.freeze(params)}, images, train=True, rngs=rngs) labels = jnp.tile(labels, (ens_size, 1)) loss_fn = getattr(train_utils, config.get('loss', 'sigmoid_xent')) loss = jnp.mean(loss_fn(logits=logits, labels=labels)) return loss, dict() @functools.partial(jax.pmap, axis_name='batch', donate_argnums=(0, 1)) def update_fn(opt, rngs, lr, images, labels): return batchensemble_utils.update_fn_be( opt=opt, rngs=rngs, lr=lr, images=images, labels=labels, batch_loss_fn=batch_loss_fn, weight_decay_fn=weight_decay_fn, max_grad_norm_global=config.get('grad_clip_norm', None), fast_weight_lr_multiplier=config.get('fast_weight_lr_multiplier', None)) # Set config checkpoint resume path, if provided in args. 
if config.resume_checkpoint_path is not None: config.resume = config.resume_checkpoint_path reint_params = ('batchensemble_head/bias', 'batchensemble_head/kernel', 'batchensemble_head/fast_weight_alpha', 'batchensemble_head/fast_weight_gamma') if config.get('only_eval', False) or not config.get('reint_head', True): reint_params = [] checkpoint_data = checkpoint_utils.maybe_load_checkpoint( train_loop_rngs=rng, save_checkpoint_path=save_checkpoint_path, init_optimizer=opt_cpu, init_params=params_cpu, init_fixed_model_states=None, default_reinit_params=reint_params, config=config) train_loop_rngs = {'dropout': checkpoint_data.train_loop_rngs} opt_cpu = checkpoint_data.optimizer accumulated_train_time = checkpoint_data.accumulated_train_time write_note('Adapting the checkpoint model...') adapted_params = checkpoint_utils.adapt_upstream_architecture( init_params=params_cpu, loaded_params=opt_cpu.target) opt_cpu = opt_cpu.replace(target=adapted_params) write_note('Kicking off misc stuff...') first_step = int(opt_cpu.state.step) # Might be a DeviceArray type. if first_step == 0 and jax.process_index() == 0: writer.write_hparams(dict(config)) chrono = train_utils.Chrono(first_step, total_steps, batch_size, accumulated_train_time) # Note: switch to ProfileAllHosts() if you need to profile all hosts. # (Xprof data become much larger and take longer to load for analysis) profiler = periodic_actions.Profile( # Create profile after every restart to analyze pre-emption related # problems and assure we get similar performance in every run. logdir=output_dir, first_profile=first_step + 10) # Prepare the learning-rate and pre-fetch it to device to avoid delays. lr_fn = train_utils.create_learning_rate_schedule(total_steps, **config.get('lr', {})) # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be # necessary for TPUs. lr_iter = train_utils.prefetch_scalar(map(lr_fn, range(total_steps)), config.get('prefetch_to_device', 1)) write_note(f'Replicating...\n{chrono.note}') opt_repl = flax.jax_utils.replicate(opt_cpu) checkpoint_writer = None # Note: we return the train loss, val loss, and fewshot best l2s for use in # reproducibility unit tests. # train_loss = -jnp.inf # eval_loss = { # eval_name: -jnp.inf for eval_name, _ in eval_iter_splits.items()} # fewshot_results = {'dummy': {(0, 1): -jnp.inf}} write_note(f'First step compilations...\n{chrono.note}') logging.info('first_step = %s', first_step) # Advance the iterators if we are restarting from an earlier checkpoint. # TODO(dusenberrymw): Look into checkpointing dataset state instead. if first_step > 0: write_note('Advancing iterators after resuming from a checkpoint...') lr_iter = itertools.islice(lr_iter, first_step, None) # TODO(zmariet): Find better way to cut down iteration advancement cost. if not config.get('disable_preemption_reproducibility', False): train_iter = itertools.islice(train_iter, first_step, None) # Using a python integer for step here, because opt.state.step is allocated # on TPU during replication. 
  for step, train_batch, lr_repl in zip(
      range(first_step + 1, total_steps + 1), train_iter, lr_iter):

    with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1):
      if not config.get('only_eval', False):
        opt_repl, train_loop_rngs, extra_measurements = update_fn(
            opt_repl, train_loop_rngs, lr_repl, train_batch['image'],
            train_batch['labels'])

    if jax.process_index() == 0:
      profiler(step)

    # Checkpoint saving.
    if not config.get('only_eval', False) and train_utils.itstime(
        step, config.get('checkpoint_steps'), total_steps, process=0):
      write_note('Checkpointing...')
      chrono.pause()
      train_utils.checkpointing_timeout(checkpoint_writer,
                                        config.get('checkpoint_timeout', 1))
      accumulated_train_time = chrono.accum_train_time
      # We need to transfer the weights over now or else we risk keeping them
      # alive while they'll be updated in a future step, creating hard-to-debug
      # memory errors (see b/160593526). Also, takes device 0's params only.
      opt_cpu = jax.tree_util.tree_map(lambda x: np.array(x[0]), opt_repl)

      # Check whether we want to keep a copy of the current checkpoint.
      copy_step = None
      if train_utils.itstime(step, config.get('keep_checkpoint_steps'),
                             total_steps):
        write_note('Keeping a checkpoint copy...')
        copy_step = step

      # Checkpoint should be a nested dictionary or FLAX dataclasses from
      # `flax.struct`. Both can be present in a checkpoint.
      checkpoint_data = checkpoint_utils.CheckpointData(
          train_loop_rngs=train_loop_rngs,
          optimizer=opt_cpu,
          accumulated_train_time=accumulated_train_time)
      checkpoint_writer = pool.apply_async(
          checkpoint_utils.checkpoint_trained_model,
          (checkpoint_data, save_checkpoint_path, copy_step))
      chrono.resume()

    # Report training progress.
    if not config.get('only_eval', False) and train_utils.itstime(
        step, config.log_training_steps, total_steps, process=0):
      write_note('Reporting training progress...')
      timing_measurements, note = chrono.tick(step)
      write_note(note)
      train_measurements = {}
      train_measurements.update(flax.jax_utils.unreplicate(extra_measurements))
      train_measurements.update(timing_measurements)
      writer.write_scalars(step, train_measurements)
      # Keep to return for reproducibility tests.
      # train_loss = train_measurements['training_loss']

    # Report validation performance.
    if config.get('only_eval', False) or train_utils.itstime(
        step, config.log_eval_steps, total_steps):
      write_note('Evaluating on the validation sets...')
      chrono.pause()
      all_eval_results = {}

      for eval_name, (eval_iter, eval_steps) in eval_iter_splits.items():
        start_time = time.time()

        # Runs evaluation loop.
        results_arrs = {
            'y_true': [],
            'y_pred': [],
            'y_pred_entropy': []
        }

        for _, batch in zip(range(eval_steps), eval_iter):
          batch_ncorrect, batch_losses, batch_n, batch_metric_args = (  # pylint: disable=unused-variable
              evaluation_fn(opt_repl.target, batch['image'], batch['labels']))

          # All results are a replicated array shaped as follows:
          # (local_devices, per_device_batch_size, elem_shape...)
          # with each local device's entry being identical as they got psum'd.
          # So let's just take the first one to the host as numpy.

          # Here we parse batch_metric_args to compute uncertainty metrics.
          logits, labels, _ = batch_metric_args
          logits = np.array(logits[0])
          probs = jax.nn.softmax(logits)

          # From one-hot to integer labels.
          int_labels = np.argmax(np.array(labels[0]), axis=-1)
          probs = np.reshape(probs, (probs.shape[0] * probs.shape[1], -1))
          int_labels = int_labels.flatten()
          y_pred = probs[:, 1]
          results_arrs['y_true'].append(int_labels)
          results_arrs['y_pred'].append(y_pred)

          # Entropy is computed at the per-epoch level (see below).
          results_arrs['y_pred_entropy'].append(probs)

        results_arrs['y_true'] = np.concatenate(results_arrs['y_true'], axis=0)
        results_arrs['y_pred'] = np.concatenate(
            results_arrs['y_pred'], axis=0).astype('float64')
        results_arrs['y_pred_entropy'] = vit_utils.entropy(
            np.concatenate(results_arrs['y_pred_entropy'], axis=0), axis=-1)

        time_elapsed = time.time() - start_time
        results_arrs['total_ms_elapsed'] = time_elapsed * 1e3
        results_arrs['dataset_size'] = eval_steps * batch_size_eval

        all_eval_results[eval_name] = results_arrs

      per_pred_results, metrics_results = vit_utils.evaluate_vit_predictions(  # pylint: disable=unused-variable
          dataset_split_to_containers=all_eval_results,
          is_deterministic=True,
          num_bins=15,
          return_per_pred_results=True)

      # `metrics_results` is a dict of {str: jnp.ndarray} dicts, one for each
      # dataset. Flatten this dict so we can pass to the writer and remove
      # empty entries.
      flattened_metric_results = {}
      for dic in metrics_results.values():
        for key, value in dic.items():
          if value is not None:
            flattened_metric_results[key] = value
      writer.write_scalars(step, flattened_metric_results)

      # Optionally log to wandb.
      if config.use_wandb:
        wandb.log(metrics_results, step=step)

      # Save per-prediction metrics.
      results_storage_utils.save_per_prediction_results(
          output_dir, step, per_pred_results, verbose=False)

      chrono.resume()

    # End of step.
    if config.get('testing_failure_step'):
      # Break early to simulate infra failures in test cases.
      if config.testing_failure_step == step:
        break

    if config.get('only_eval', False):
      break

  write_note(f'Done!\n{chrono.note}')
  pool.close()
  pool.join()
  writer.close()
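
# A minimal sketch of how `main` is typically launched, assuming the usual
# absl.app + ml_collections config_flags pattern used by these training
# scripts; the flag names (`--config`, `--output_dir`) and the wiring below
# are illustrative assumptions, not part of this module:
#
#   from absl import app, flags
#   from ml_collections import config_flags
#
#   config_flags.DEFINE_config_file('config')
#   flags.DEFINE_string('output_dir', None, 'Where to write checkpoints/logs.')
#
#   def _run(argv):
#     del argv  # Unused.
#     main(flags.FLAGS.config, flags.FLAGS.output_dir)
#
#   if __name__ == '__main__':
#     app.run(_run)
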
def create_update_fn(model, config):
  """Creates the update function from the model and config.

  Args:
    model: The model to be used in updates.
    config: The config of the experiment.

  Returns:
    The function that updates the model for one step.
  """
  weight_decay_rules = config.get('weight_decay', []) or []
  rescale_value = config.lr.base if config.get('weight_decay_decouple') else 1.
  weight_decay_fn = train_utils.get_weight_decay_fn(
      weight_decay_rules=weight_decay_rules, rescale_value=rescale_value)
  logging.info('weight_decay_rules = %s', weight_decay_rules)
  logging.info('rescale_value = %s', rescale_value)

  @functools.partial(jax.pmap, axis_name='batch', donate_argnums=(0,))
  def update_fn(opt, lr, images, labels, rng):
    """Updates the model for one step."""
    measurements = {}

    # Split rng and return next_rng for the following step.
    rng, next_rng = jax.random.split(rng, 2)
    rng_local = jax.random.fold_in(rng, jax.lax.axis_index('batch'))

    logging.info(msg=f'images in loss_fn = {jnp.shape(images)}')
    logging.info(msg=f'labels in loss_fn = {jnp.shape(labels)}')

    def loss_fn(params, images, labels):
      logits, _ = model.apply({'params': flax.core.freeze(params)},
                              images, train=True, rngs={'dropout': rng_local})
      logging.info(msg=f'logits={logits}')
      label_indices = config.get('label_indices')
      if label_indices:
        logits = logits[:, label_indices]
      loss = getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
          logits=logits, labels=labels)
      return loss, logits

    # Implementation considerations compared and summarized at
    # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
    (l, logits), g = train_utils.accumulate_gradient(
        jax.value_and_grad(loss_fn, has_aux=True), opt.target, images, labels,
        config.get('grad_accum_steps'))
    l, g = jax.lax.pmean((l, g), axis_name='batch')
    measurements['training_loss'] = l
    logging.info(msg=f'measurements = {measurements}')

    # Log the gradient norm only if we need to compute it anyways (clipping)
    # or if we don't use grad_accum_steps, as they interact badly.
    if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
      grads, _ = jax.tree_flatten(g)
      l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
      measurements['l2_grads'] = l2_g
      logging.info(msg=f'measurements = {measurements}')

    # Optionally resize the global gradient to a maximum norm. We found this
    # useful in some cases across optimizers, hence it's in the main loop.
    if config.get('grad_clip_norm'):
      g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
      g = jax.tree_util.tree_map(lambda p: g_factor * p, g)

    opt = opt.apply_gradient(g, learning_rate=lr)
    opt = opt.replace(target=weight_decay_fn(opt.target, lr))

    params, _ = jax.tree_flatten(opt.target)
    measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

    top1_idx = jnp.argmax(logits, axis=1)
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]
    prec1 = jax.lax.psum(
        jnp.sum(top1_correct), axis_name='batch') / config.batch_size
    measurements['training_prec@1'] = prec1
    measurements['learning_rate'] = lr
    return opt, next_rng, measurements

  return update_fn
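
# A minimal usage sketch for `create_update_fn`, assuming a replicated flax
# optimizer and per-device RNGs; `my_model`, `my_config`, `opt_cpu`,
# `train_iter`, `lr_iter`, and `total_steps` are illustrative placeholders,
# not names defined in this module:
#
#   update_fn = create_update_fn(my_model, my_config)
#   opt_repl = flax.jax_utils.replicate(opt_cpu)
#   # Same key on every device; `update_fn` folds in the device axis index.
#   rngs = flax.jax_utils.replicate(jax.random.PRNGKey(0))
#   for step, batch, lr_repl in zip(range(1, total_steps + 1),
#                                   train_iter, lr_iter):
#     opt_repl, rngs, measurements = update_fn(
#         opt_repl, lr_repl, batch['image'], batch['labels'], rngs)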