def create_train_state(config, rng, learning_rate_fn, example_batch):
  """Create and initialize the model.

  Args:
    config: Configuration for model.
    rng: JAX PRNG Key.
    learning_rate_fn: Learning rate function.
    example_batch: An example batch used for model initialization.

  Returns:
    The model and the initialized TrainState with the optimizer.
  """
  model, variables = create_model(config, rng, example_batch)
  params = variables['params']
  parameter_overview.log_parameter_overview(params)
  state = train_state.TrainState.create(
      apply_fn=model.apply,
      params=variables['params'],
      tx=optax.adamw(
          learning_rate=learning_rate_fn,
          b1=0.9,
          b2=0.98,
          eps=1e-9,
          weight_decay=config.train.weight_decay),
  )
  return model, state
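# --- Usage sketch (not from the original source) ---
# A minimal, hypothetical call to the create_train_state above, assuming an
# ml_collections.ConfigDict with a `train.weight_decay` field and an optax
# warmup-cosine schedule as `learning_rate_fn`. Shapes and values are made up.
import jax
import jax.numpy as jnp
import ml_collections
import optax

config = ml_collections.ConfigDict()
config.train = ml_collections.ConfigDict({'weight_decay': 0.1})

learning_rate_fn = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1e-3, warmup_steps=1_000, decay_steps=100_000)

rng = jax.random.PRNGKey(0)
example_batch = {'inputs': jnp.ones((8, 32, 32, 3))}  # illustrative batch
model, state = create_train_state(config, rng, learning_rate_fn, example_batch)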
def test_count_parameters_on_module_with_duplicate_names(self):
  module = snt.Module()
  # Weights of a 2D convolution with 2 filters.
  module.conv = snt.Conv2D(output_channels=2, kernel_shape=3, name="conv")
  module.conv(tf.ones((2, 5, 5, 3)))  # 3 * 3*3 * 2 + 2 (bias) = 56 parameters
  module.conv2 = snt.Conv2D(output_channels=2, kernel_shape=3, name="conv")
  module.conv2(tf.ones((2, 5, 5, 3)))  # 3 * 3*3 * 2 + 2 (bias) = 56 parameters
  parameter_overview.log_parameter_overview(module)
  self.assertEqual(112, parameter_overview.count_parameters(module))
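# --- Usage sketch (not from the original source) ---
# The same overview utilities also work directly on a JAX/Flax params pytree,
# which is how the snippets below use them. This assumes `parameter_overview`
# is clu.parameter_overview; the params dict and shapes here are made up.
import jax.numpy as jnp
from clu import parameter_overview

params = {'dense': {'kernel': jnp.zeros((16, 4)), 'bias': jnp.zeros((4,))}}
print(parameter_overview.count_parameters(params))        # 68
print(parameter_overview.get_parameter_overview(params))  # per-variable table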
def __init__(self, mode, config):
  """Initializes experiment."""
  self.mode = mode
  self.config = config
  self.vdvae = vdvae.Vdvae(config)
  self.rng = jax.random.PRNGKey(config.seed)

  # setup eval
  _, self._eval_ds = dataset.create_eval_dataset(
      config.data.task, config.evaluation.batch_size,
      config.evaluation.subset)
  self.rng, eval_rng = jax.random.split(self.rng)
  self._eval_batch = functools.partial(self._eval_batch, base_rng=eval_rng)
  self._eval_batch = jax.pmap(self._eval_batch, axis_name='batch')

  if mode == 'train':
    self.rng, data_rng = jax.random.split(self.rng)
    data_rng = jax.random.fold_in(data_rng, jax.process_index())
    _, train_ds = dataset.create_train_dataset(
        config.data.task, config.training.batch_size,
        config.training.substeps, data_rng)
    self._train_iter = iter(train_ds)

    self.rng, init_rng, sample_rng = jax.random.split(self.rng, num=3)
    input_shape = tuple(train_ds.element_spec.shape[2:])
    params = self.vdvae.init(init_rng, sample_rng, input_shape[0],
                             jnp.ones(input_shape, dtype=jnp.uint8))
    parameter_overview.log_parameter_overview(params)

    # Use the same rng to init with the same params.
    ema_params = jax.tree_map(jnp.array, params)

    opt_init, _ = self.optimizer(
        learning_rate=self.config.optimizer.base_learning_rate)
    opt_state = opt_init(params)

    # create train state
    self._train_state = TrainState(
        step=0, params=params, ema_params=ema_params, opt_state=opt_state)

    self.rng, update_rng = jax.random.split(self.rng)
    self._update_func = functools.partial(self._update_func, update_rng)
    self._update_func = functools.partial(jax.lax.scan, self._update_func)
    self._update_func = jax.pmap(self._update_func, axis_name='batch')
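# --- Usage sketch (not from the original source) ---
# The last three lines above compose the per-substep update with jax.lax.scan
# (run `substeps` updates per call) and jax.pmap (parallelise over devices).
# A toy version of that composition with a made-up carry and a batch layout of
# [devices, substeps, per-device batch]:
import functools
import jax
import jax.numpy as jnp

def update(carry, batch):
  # carry: running scalar state; batch: one substep's data.
  return carry + batch.mean(), batch.mean()

scanned_update = functools.partial(jax.lax.scan, update)
p_update = jax.pmap(scanned_update, axis_name='batch')

n_dev = jax.local_device_count()
carry = jnp.zeros((n_dev,))
batches = jnp.ones((n_dev, 4, 8))  # [devices, substeps, per-device batch]
carry, per_substep_outputs = p_update(carry, batches)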
def create_train_state(model, config, rng, inputs):
  """Create and initialize the model.

  Args:
    model: Flax nn module.
    config: Configuration for model.
    rng: JAX PRNG Key.
    inputs: The init inputs fed into the model.

  Returns:
    The initialized TrainState with the optimizer.
  """
  rng, params_rng, dropout_rng = jax.random.split(rng, 3)
  variables = model.init({
      'params': params_rng,
      'dropout': dropout_rng
  }, inputs)
  params = variables['params']
  parameter_overview.log_parameter_overview(params)
  # pytype: disable=wrong-arg-types
  if config.optimizer == 'AdamW':

    def w_decay_fn(path):
      return all(pn not in path for pn in config.no_weight_decay)

    def wo_decay_fn(path):
      return any(pn in path for pn in config.no_weight_decay)

    optimizer_w_decay = flax.optim.Adam(
        learning_rate=config.learning_rate,
        weight_decay=config.learning_rate_weight_decay)
    optimizer_wo_decay = flax.optim.Adam(
        learning_rate=config.learning_rate, weight_decay=0)
    params_w_decay = flax.optim.ModelParamTraversal(
        lambda path, _: w_decay_fn(path))
    params_wo_decay = flax.optim.ModelParamTraversal(
        lambda path, _: wo_decay_fn(path))
    optimizer = flax.optim.MultiOptimizer(
        (params_w_decay, optimizer_w_decay),
        (params_wo_decay, optimizer_wo_decay)).create(params)
  else:
    raise NotImplementedError
  return TrainState(step=0, optimizer=optimizer)
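# --- Alternative sketch (not from the original source) ---
# flax.optim has since been deprecated. For comparison only, the same
# "no weight decay on selected parameters" split can be expressed with
# optax.adamw's `mask` argument; `config.no_weight_decay`,
# `config.learning_rate` and `config.learning_rate_weight_decay` are the
# fields used above, and params is assumed to be a plain (unfrozen) dict.
import optax
from flax import traverse_util

def decay_mask_fn(params):
  # True where weight decay should apply: any path that does not contain one
  # of the names listed in config.no_weight_decay.
  flat = traverse_util.flatten_dict(params)
  mask = {
      path: all(pn not in '/'.join(path) for pn in config.no_weight_decay)
      for path in flat
  }
  return traverse_util.unflatten_dict(mask)

tx = optax.adamw(
    learning_rate=config.learning_rate,
    weight_decay=config.learning_rate_weight_decay,
    mask=decay_mask_fn)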
def create_train_state(config, rng, input_shape, num_classes):
  """Create and initialize the model.

  Args:
    config: Configuration for model.
    rng: JAX PRNG Key.
    input_shape: Shape of the inputs fed into the model.
    num_classes: Number of classes in the output layer.

  Returns:
    The model and initialized TrainState with the optimizer.
  """
  # Use of getattr could simplify this but is discouraged by the style guide.
  # We should consider using it anyway if the sequence of ifs grows too big.
  if config.model_name == "tiny_classifier":
    get_model = models.tiny_classifier
  elif config.model_name == "spin_classifier_6_layers":
    get_model = models.spin_classifier_6_layers
  elif config.model_name == "spherical_classifier_6_layers":
    get_model = models.spherical_classifier_6_layers
  elif config.model_name == "cnn_classifier_6_layers":
    get_model = models.cnn_classifier_6_layers
  else:
    raise ValueError(f"Model {config.model_name} not supported.")
  model = get_model(num_classes, axis_name=_PMAP_AXIS_NAME)

  variables = model.init(rng, jnp.ones(input_shape), train=False)
  params = variables["params"]
  batch_stats = variables.get("batch_stats", {})
  abs_if_complex = lambda x: jnp.abs(x) if x.dtype == jnp.complex64 else x
  parameter_overview.log_parameter_overview(
      jax.tree_util.tree_map(abs_if_complex, params))
  optimizer = flax.optim.Adam().create(params)
  return model, TrainState(step=0, optimizer=optimizer,
                           batch_stats=batch_stats)
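# --- Design-choice sketch (not from the original source) ---
# The comment above mentions a more compact dispatch. A table-driven
# alternative to the if/elif chain, using the same model constructors:
_MODEL_CONSTRUCTORS = {
    "tiny_classifier": models.tiny_classifier,
    "spin_classifier_6_layers": models.spin_classifier_6_layers,
    "spherical_classifier_6_layers": models.spherical_classifier_6_layers,
    "cnn_classifier_6_layers": models.cnn_classifier_6_layers,
}

def get_model_constructor(model_name):
  if model_name not in _MODEL_CONSTRUCTORS:
    raise ValueError(f"Model {model_name} not supported.")
  return _MODEL_CONSTRUCTORS[model_name]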
def train_and_evaluate(config, workdir):
  """Performs training and evaluation with the given configuration."""
  # Set up logging.
  summary_writer = metric_writers.create_default_writer(workdir)
  summary_writer.write_hparams(config.to_dict())

  time_delta = config.time_delta
  train_time_jump_range = config.train_time_jump_range
  test_time_jumps = config.test_time_jumps
  num_trajectories = config.num_trajectories
  num_samples = config.num_samples
  train_split_proportion = config.train_split_proportion
  num_train_steps = config.num_train_steps
  regularizations = config.regularizations.to_dict()
  eval_cadence = config.eval_cadence

  # Get simulation functions.
  generate_canonical_coordinates_fn = get_generate_canonical_coordinates_fn(
      config)
  compute_hamiltonian_fn = get_compute_hamiltonian_fn(config)

  # Generate data.
  rng = jax.random.PRNGKey(config.rng_seed)
  rng, simulation_parameters_rng = jax.random.split(rng)
  simulation_parameters = sample_simulation_parameters(
      config.simulation_parameter_ranges.to_dict(), num_trajectories,
      simulation_parameters_rng)
  times = jnp.arange(num_samples) * time_delta
  all_positions, all_momentums = generate_canonical_coordinates_fn(
      times, simulation_parameters)

  # Train-test split.
  if config.split_on == 'times':
    num_train_samples = int(num_samples * train_split_proportion)
    train_positions = all_positions[:num_train_samples]
    test_positions = all_positions[num_train_samples:]
    train_momentums = all_momentums[:num_train_samples]
    test_momentums = all_momentums[num_train_samples:]
    train_simulation_parameters = simulation_parameters
    test_simulation_parameters = simulation_parameters
  else:
    raise ValueError(f'Unsupported feature for split: {config.split_on}.')

  # Rescale.
  scaler = create_scaler(config)
  scaler = fit_scaler(train_positions, train_momentums, scaler)
  train_positions, train_momentums = transform_with_scaler(
      train_positions, train_momentums, scaler)
  test_positions, test_momentums = transform_with_scaler(
      test_positions, test_momentums, scaler)

  # Initialize model.
  state = create_train_state(
      config, rng, (train_positions[:1], train_momentums[:1], time_delta))
  best_state = state
  parameter_overview.log_parameter_overview(state.params)

  # Setup for coordinates and time deltas.
  coordinates_fn = get_coordinates_fn(config)
  time_deltas_fn = get_time_deltas_fn(config)

  # Setup sampling for time jumps.
  sample_time_jump_fn = functools.partial(
      sample_time_jump_with_linear_increase,
      min_jump=train_time_jump_range[0],
      max_jump=train_time_jump_range[1],
      num_train_steps=num_train_steps)

  min_train_loss = jnp.inf
  all_train_metrics = {}
  all_test_metrics = {}

  for step in range(num_train_steps):
    step_rng = jax.random.fold_in(rng, step)

    # Sample time jump.
    step_rng, jump_rng = jax.random.split(step_rng)
    jump = sample_time_jump_fn(step, rng=jump_rng)

    # Setup inputs and targets on all trajectories.
    (train_curr_positions, train_curr_momentums, train_target_positions,
     train_target_momentums) = coordinates_fn(train_positions,
                                              train_momentums, jump)
    time_deltas = time_deltas_fn(jump)

    # Sample indices.
    num_samples_on_trajectory = train_curr_positions.shape[0]
    sample_indices = jax.random.choice(step_rng, num_samples_on_trajectory,
                                       (config.batch_size,))
    batch_curr_positions = train_curr_positions[sample_indices]
    batch_curr_momentums = train_curr_momentums[sample_indices]
    batch_target_positions = train_target_positions[sample_indices]
    batch_target_momentums = train_target_momentums[sample_indices]

    # Update parameters.
    grads = compute_updates(state, batch_curr_positions, batch_curr_momentums,
                            time_deltas, batch_target_positions,
                            batch_target_momentums, regularizations)
    state = state.apply_gradients(grads=grads)

    # Evaluate, if required.
    is_last_step = (step == num_train_steps - 1)
    if step % eval_cadence == (eval_cadence - 1) or is_last_step:
      train_metrics = compute_metrics_helper(
          state, train_positions, train_momentums, jump, time_delta, scaler,
          compute_hamiltonian_fn, train_simulation_parameters,
          regularizations)
      log_metrics(step, train_metrics, summary_writer, prefix='train_')
      all_train_metrics[step] = train_metrics

      test_metrics = {}
      for test_jump in test_time_jumps:
        test_metrics[test_jump] = compute_metrics_helper(
            state, test_positions, test_momentums, test_jump, time_delta,
            scaler, compute_hamiltonian_fn, test_simulation_parameters,
            regularizations)
        log_metrics(step, test_metrics[test_jump], summary_writer,
                    prefix=f'test_jump_{test_jump}_')
      all_test_metrics[step] = test_metrics

      # Save best state seen so far.
      if train_metrics['total_loss'] < min_train_loss:
        min_train_loss = train_metrics['total_loss']
        best_state = state

  auxiliary_data = {
      'train': {
          'positions': train_positions,
          'momentums': train_momentums,
          'simulation_parameters': train_simulation_parameters,
          'metrics': all_train_metrics,
      },
      'test': {
          'positions': test_positions,
          'momentums': test_momentums,
          'simulation_parameters': test_simulation_parameters,
          'metrics': all_test_metrics,
      },
  }
  return scaler, best_state, auxiliary_data
def main(config, output_dir):

  seed = config.get('seed', 0)
  rng = jax.random.PRNGKey(seed)
  tf.random.set_seed(seed)

  if config.get('data_dir'):
    logging.info('data_dir=%s', config.data_dir)
  logging.info('Output dir: %s', output_dir)

  save_checkpoint_path = None
  if config.get('checkpoint_steps'):
    gfile.makedirs(output_dir)
    save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

  # Create an asynchronous multi-metric writer.
  writer = metric_writers.create_default_writer(
      output_dir, just_logging=jax.process_index() > 0)

  # The pool is used to perform misc operations such as logging in an async way.
  pool = multiprocessing.pool.ThreadPool()

  def write_note(note):
    if jax.process_index() == 0:
      logging.info('NOTE: %s', note)

  write_note('Initializing...')

  # Verify settings to make sure no checkpoints are accidentally missed.
  if config.get('keep_checkpoint_steps'):
    assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
    assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
        f'`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should be '
        f'divisible by `checkpoint_steps` ({config.checkpoint_steps}).')

  batch_size = config.batch_size
  batch_size_eval = config.get('batch_size_eval', batch_size)
  if (batch_size % jax.device_count() != 0 or
      batch_size_eval % jax.device_count() != 0):
    raise ValueError(f'Batch sizes ({batch_size} and {batch_size_eval}) must '
                     f'be divisible by device number ({jax.device_count()})')

  local_batch_size = batch_size // jax.process_count()
  local_batch_size_eval = batch_size_eval // jax.process_count()
  logging.info(
      'Global batch size %d on %d hosts results in %d local batch size. '
      'With %d devices per host (%d devices total), that\'s a %d per-device '
      'batch size.', batch_size, jax.process_count(), local_batch_size,
      jax.local_device_count(), jax.device_count(),
      local_batch_size // jax.local_device_count())

  write_note('Initializing train dataset...')
  rng, train_ds_rng = jax.random.split(rng)
  train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index())
  train_ds = input_utils.get_data(
      dataset=config.dataset,
      split=config.train_split,
      rng=train_ds_rng,
      process_batch_size=local_batch_size,
      preprocess_fn=preprocess_spec.parse(
          spec=config.pp_train, available_ops=preprocess_utils.all_ops()),
      shuffle_buffer_size=config.shuffle_buffer_size,
      prefetch_size=config.get('prefetch_to_host', 2),
      data_dir=config.get('data_dir'))

  # Start prefetching already.
  train_iter = input_utils.start_input_pipeline(
      train_ds, config.get('prefetch_to_device', 1))

  write_note('Initializing val dataset(s)...')

  def _get_val_split(dataset, split, pp_eval, data_dir=None):
    # We do ceil rounding such that we include the last incomplete batch.
    nval_img = input_utils.get_num_examples(
        dataset,
        split=split,
        process_batch_size=local_batch_size_eval,
        drop_remainder=False,
        data_dir=data_dir)
    val_steps = int(np.ceil(nval_img / batch_size_eval))
    logging.info('Running validation for %d steps for %s, %s', val_steps,
                 dataset, split)

    if isinstance(pp_eval, str):
      pp_eval = preprocess_spec.parse(
          spec=pp_eval, available_ops=preprocess_utils.all_ops())

    val_ds = input_utils.get_data(
        dataset=dataset,
        split=split,
        rng=None,
        process_batch_size=local_batch_size_eval,
        preprocess_fn=pp_eval,
        cache=config.get('val_cache', 'batched'),
        repeat_after_batching=True,
        shuffle=False,
        prefetch_size=config.get('prefetch_to_host', 2),
        drop_remainder=False,
        data_dir=data_dir)
    val_iter = input_utils.start_input_pipeline(
        val_ds, config.get('prefetch_to_device', 1))

    return (val_iter, val_steps)

  val_iter_splits = {
      'val':
          _get_val_split(
              config.dataset,
              split=config.val_split,
              pp_eval=config.pp_eval,
              data_dir=config.get('data_dir'))
  }

  if config.get('eval_on_cifar_10h'):
    cifar10_to_cifar10h_fn = data_uncertainty_utils.create_cifar10_to_cifar10h_fn(
        config.get('data_dir', None))
    preprocess_fn = preprocess_spec.parse(
        spec=config.pp_eval_cifar_10h,
        available_ops=preprocess_utils.all_ops())
    pp_eval = lambda ex: preprocess_fn(cifar10_to_cifar10h_fn(ex))
    val_iter_splits['cifar_10h'] = _get_val_split(
        'cifar10',
        split=config.get('cifar_10h_split') or 'test',
        pp_eval=pp_eval,
        data_dir=config.get('data_dir'))
  elif config.get('eval_on_imagenet_real'):
    imagenet_to_real_fn = data_uncertainty_utils.create_imagenet_to_real_fn()
    preprocess_fn = preprocess_spec.parse(
        spec=config.pp_eval_imagenet_real,
        available_ops=preprocess_utils.all_ops())
    pp_eval = lambda ex: preprocess_fn(imagenet_to_real_fn(ex))
    val_iter_imagenet_real, val_steps = _get_val_split(
        'imagenet2012_real',
        split=config.get('imagenet_real_split') or 'validation',
        pp_eval=pp_eval,
        data_dir=config.get('data_dir'))
    val_iter_splits['imagenet_real'] = (val_iter_imagenet_real, val_steps)

  ood_ds = {}
  if config.get('ood_datasets') and config.get('ood_methods'):
    if config.get('ood_methods'):  # config.ood_methods is not an empty list
      logging.info('loading OOD dataset = %s', config.get('ood_datasets'))
      ood_ds, ood_ds_names = ood_utils.load_ood_datasets(
          config.dataset,
          config.ood_datasets,
          config.ood_split,
          config.pp_eval,
          config.pp_eval_ood,
          config.ood_methods,
          config.train_split,
          config.get('data_dir'),
          _get_val_split,
      )

  ntrain_img = input_utils.get_num_examples(
      config.dataset,
      split=config.train_split,
      process_batch_size=local_batch_size,
      data_dir=config.get('data_dir'))
  steps_per_epoch = int(ntrain_img / batch_size)

  if config.get('num_epochs'):
    total_steps = int(config.num_epochs * steps_per_epoch)
    assert not config.get('total_steps'), 'Set either num_epochs or total_steps'
  else:
    total_steps = config.total_steps

  logging.info('Total train data points: %d', ntrain_img)
  logging.info(
      'Running for %d steps, that means %f epochs and %d steps per epoch',
      total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)

  write_note('Initializing model...')
  logging.info('config.model = %s', config.get('model'))
  model = ub.models.vision_transformer(
      num_classes=config.num_classes, **config.get('model', {}))

  # We want all parameters to be created in host RAM, not on any device, they'll
  # be sent there later as needed, otherwise we already encountered two
  # situations where we allocate them twice.
  @partial(jax.jit, backend='cpu')
  def init(rng):
    image_size = tuple(train_ds.element_spec['image'].shape[2:])
    logging.info('image_size = %s', image_size)
    dummy_input = jnp.zeros((local_batch_size,) + image_size, jnp.float32)
    params = flax.core.unfreeze(model.init(rng, dummy_input,
                                           train=False))['params']

    # Set bias in the head to a low value, such that loss is small initially.
    params['head']['bias'] = jnp.full_like(params['head']['bias'],
                                           config.get('init_head_bias', 0))

    # init head kernel to all zeros for fine-tuning
    if config.get('model_init'):
      params['head']['kernel'] = jnp.full_like(params['head']['kernel'], 0)

    return params

  rng, rng_init = jax.random.split(rng)
  params_cpu = init(rng_init)

  if jax.process_index() == 0:
    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
    parameter_overview.log_parameter_overview(params_cpu)
    writer.write_scalars(step=0, scalars={'num_params': num_params})

  @partial(jax.pmap, axis_name='batch')
  def evaluation_fn(params, images, labels, mask):
    # Ignore the entries with all zero labels for evaluation.
    mask *= labels.max(axis=1)
    logits, out = model.apply({'params': flax.core.freeze(params)},
                              images,
                              train=False)
    label_indices = config.get('label_indices')
    logging.info('!!! mask %s, label_indices %s', mask, label_indices)
    if label_indices:
      logits = logits[:, label_indices]

    # Note that logits and labels are usually of the shape [batch, num_classes].
    # But for OOD data, when num_classes_ood > num_classes_ind, we need to
    # adjust labels to labels[:, :config.num_classes] to match the shape of
    # logits. That is just to avoid shape mismatch. The output losses do not
    # have any meaning for OOD data, because OOD does not belong to any IND
    # class.
    losses = getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
        logits=logits,
        labels=labels[:, :(len(label_indices)
                           if label_indices else config.num_classes)],
        reduction=False)
    loss = jax.lax.psum(losses * mask, axis_name='batch')

    top1_idx = jnp.argmax(logits, axis=1)
    # Extracts the label at the highest logit index for each image.
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch')
    n = jax.lax.psum(mask, axis_name='batch')

    metric_args = jax.lax.all_gather([logits, labels, out['pre_logits'], mask],
                                     axis_name='batch')
    return ncorrect, loss, n, metric_args

  @partial(jax.pmap, axis_name='batch')
  def cifar_10h_evaluation_fn(params, images, labels, mask):
    logits, out = model.apply({'params': flax.core.freeze(params)},
                              images,
                              train=False)
    label_indices = config.get('label_indices')
    if label_indices:
      logits = logits[:, label_indices]

    losses = getattr(train_utils, config.get('loss', 'softmax_xent'))(
        logits=logits, labels=labels, reduction=False)
    loss = jax.lax.psum(losses, axis_name='batch')

    top1_idx = jnp.argmax(logits, axis=1)
    # Extracts the label at the highest logit index for each image.
    one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)]

    top1_correct = jnp.take_along_axis(one_hot_labels, top1_idx[:, None],
                                       axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
    n = jax.lax.psum(one_hot_labels, axis_name='batch')

    metric_args = jax.lax.all_gather([logits, labels, out['pre_logits'], mask],
                                     axis_name='batch')
    return ncorrect, loss, n, metric_args

  # Setup function for computing representation.
  @partial(jax.pmap, axis_name='batch')
  def representation_fn(params, images, labels, mask):
    _, outputs = model.apply({'params': flax.core.freeze(params)},
                             images,
                             train=False)
    representation = outputs[config.fewshot.representation_layer]
    representation = jax.lax.all_gather(representation, 'batch')
    labels = jax.lax.all_gather(labels, 'batch')
    mask = jax.lax.all_gather(mask, 'batch')
    return representation, labels, mask

  # Load the optimizer from flax.
  opt_name = config.get('optim_name')
  write_note(f'Initializing {opt_name} optimizer...')
  opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

  # We jit this, such that the arrays that are created are created on the same
  # device as the input is, in this case the CPU. Else they'd be on device[0].
  opt_cpu = jax.jit(opt_def.create)(params_cpu)

  weight_decay_rules = config.get('weight_decay', []) or []
  rescale_value = config.lr.base if config.get('weight_decay_decouple') else 1.
  weight_decay_fn = train_utils.get_weight_decay_fn(
      weight_decay_rules=weight_decay_rules, rescale_value=rescale_value)

  @partial(jax.pmap, axis_name='batch', donate_argnums=(0,))
  def update_fn(opt, lr, images, labels, rng):
    """Update step."""
    measurements = {}

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model,
                                         jax.lax.axis_index('batch'))

    def loss_fn(params, images, labels):
      logits, _ = model.apply({'params': flax.core.freeze(params)},
                              images,
                              train=True,
                              rngs={'dropout': rng_model_local})
      label_indices = config.get('label_indices')
      if label_indices:
        logits = logits[:, label_indices]
      return getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
          logits=logits, labels=labels)

    # Implementation considerations compared and summarized at
    # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
    l, g = train_utils.accumulate_gradient(jax.value_and_grad(loss_fn),
                                           opt.target, images, labels,
                                           config.get('grad_accum_steps'))
    l, g = jax.lax.pmean((l, g), axis_name='batch')

    # Log the gradient norm only if we need to compute it anyways (clipping)
    # or if we don't use grad_accum_steps, as they interact badly.
    if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
      grads, _ = jax.tree_flatten(g)
      l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
      measurements['l2_grads'] = l2_g

    # Optionally resize the global gradient to a maximum norm. We found this
    # useful in some cases across optimizers, hence it's in the main loop.
    if config.get('grad_clip_norm'):
      g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
      g = jax.tree_util.tree_map(lambda p: g_factor * p, g)
    opt = opt.apply_gradient(g, learning_rate=lr)
    opt = opt.replace(target=weight_decay_fn(opt.target, lr))

    params, _ = jax.tree_flatten(opt.target)
    measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

    return opt, l, rng, measurements

  rng, train_loop_rngs = jax.random.split(rng)
  reint_params = ('head/kernel', 'head/bias')
  if config.get('only_eval', False) or not config.get('reint_head', True):
    reint_params = []
  checkpoint_data = checkpoint_utils.maybe_load_checkpoint(
      train_loop_rngs=train_loop_rngs,
      save_checkpoint_path=save_checkpoint_path,
      init_optimizer=opt_cpu,
      init_params=params_cpu,
      init_fixed_model_states=None,
      default_reinit_params=reint_params,
      config=config,
  )
  train_loop_rngs = checkpoint_data.train_loop_rngs
  opt_cpu = checkpoint_data.optimizer
  accumulated_train_time = checkpoint_data.accumulated_train_time

  write_note('Adapting the checkpoint model...')
  adapted_params = checkpoint_utils.adapt_upstream_architecture(
      init_params=params_cpu, loaded_params=opt_cpu.target)
  opt_cpu = opt_cpu.replace(target=adapted_params)

  write_note('Kicking off misc stuff...')
  first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
  if first_step == 0 and jax.process_index() == 0:
    writer.write_hparams(dict(config))
  chrono = train_utils.Chrono(first_step, total_steps, batch_size,
                              accumulated_train_time)
  # Note: switch to ProfileAllHosts() if you need to profile all hosts.
  # (Xprof data become much larger and take longer to load for analysis)
  profiler = periodic_actions.Profile(
      # Create profile after every restart to analyze pre-emption related
      # problems and assure we get similar performance in every run.
      logdir=output_dir,
      first_profile=first_step + 10)

  # Prepare the learning-rate and pre-fetch it to device to avoid delays.
  lr_fn = train_utils.create_learning_rate_schedule(total_steps,
                                                    **config.get('lr', {}))
  # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be
  # necessary for TPUs.
  lr_iter = train_utils.prefetch_scalar(
      map(lr_fn, range(total_steps)), config.get('prefetch_to_device', 1))

  write_note(f'Replicating...\n{chrono.note}')
  opt_repl = flax_utils.replicate(opt_cpu)

  write_note(f'Initializing few-shotters...\n{chrono.note}')
  fewshotter = None
  if 'fewshot' in config and fewshot is not None:
    fewshotter = fewshot.FewShotEvaluator(
        representation_fn, config.fewshot,
        config.fewshot.get('batch_size') or batch_size_eval)

  checkpoint_writer = None

  # Note: we return the train loss, val loss, and fewshot best l2s for use in
  # reproducibility unit tests.
  train_loss = -jnp.inf
  val_loss = {val_name: -jnp.inf for val_name, _ in val_iter_splits.items()}
  fewshot_results = {'dummy': {(0, 1): -jnp.inf}}

  write_note(f'First step compilations...\n{chrono.note}')
  logging.info('first_step = %s', first_step)
  # Advance the iterators if we are restarting from an earlier checkpoint.
  # TODO(dusenberrymw): Look into checkpointing dataset state instead.
  if first_step > 0:
    write_note('Advancing iterators after resuming from a checkpoint...')
    lr_iter = itertools.islice(lr_iter, first_step, None)
    train_iter = itertools.islice(train_iter, first_step, None)
    # NOTE: Validation eval is only run on certain steps, so determine how many
    # times it was run previously.
    num_val_runs = sum(
        map(
            lambda i: train_utils.itstime(i, config.log_eval_steps,
                                          total_steps),
            range(1, first_step + 1)))
    for val_name, (val_iter, val_steps) in val_iter_splits.items():
      val_iter = itertools.islice(val_iter, num_val_runs * val_steps, None)
      val_iter_splits[val_name] = (val_iter, val_steps)

  # Using a python integer for step here, because opt.state.step is allocated
  # on TPU during replication.
  for step, train_batch, lr_repl in zip(
      range(first_step + 1, total_steps + 1), train_iter, lr_iter):

    with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1):
      if not config.get('only_eval', False):
        opt_repl, loss_value, train_loop_rngs, extra_measurements = update_fn(
            opt_repl,
            lr_repl,
            train_batch['image'],
            train_batch['labels'],
            rng=train_loop_rngs)

    if jax.process_index() == 0:
      profiler(step)

    # Checkpoint saving
    if not config.get('only_eval', False) and train_utils.itstime(
        step, config.get('checkpoint_steps'), total_steps, process=0):
      write_note('Checkpointing...')
      chrono.pause()
      train_utils.checkpointing_timeout(checkpoint_writer,
                                        config.get('checkpoint_timeout', 1))
      accumulated_train_time = chrono.accum_train_time
      # We need to transfer the weights over now or else we risk keeping them
      # alive while they'll be updated in a future step, creating hard to debug
      # memory errors (see b/160593526). Also, takes device 0's params only.
      opt_cpu = jax.tree_util.tree_map(lambda x: np.array(x[0]), opt_repl)

      # Check whether we want to keep a copy of the current checkpoint.
      copy_step = None
      if train_utils.itstime(step, config.get('keep_checkpoint_steps'),
                             total_steps):
        write_note('Keeping a checkpoint copy...')
        copy_step = step

      # Checkpoint should be a nested dictionary or FLAX dataclasses from
      # `flax.struct`. Both can be present in a checkpoint.
      checkpoint_data = checkpoint_utils.CheckpointData(
          train_loop_rngs=train_loop_rngs,
          optimizer=opt_cpu,
          accumulated_train_time=accumulated_train_time)
      checkpoint_writer = pool.apply_async(
          checkpoint_utils.checkpoint_trained_model,
          (checkpoint_data, save_checkpoint_path, copy_step))
      chrono.resume()

    # Report training progress
    if not config.get('only_eval', False) and train_utils.itstime(
        step, config.log_training_steps, total_steps, process=0):
      write_note('Reporting training progress...')
      train_loss = loss_value[0]  # Keep to return for reproducibility tests.
      timing_measurements, note = chrono.tick(step)
      write_note(note)
      train_measurements = {}
      train_measurements.update({
          'learning_rate': lr_repl[0],
          'training_loss': train_loss,
      })
      train_measurements.update(flax.jax_utils.unreplicate(extra_measurements))
      train_measurements.update(timing_measurements)
      writer.write_scalars(step, train_measurements)

    # Report validation performance
    if train_utils.itstime(step, config.log_eval_steps, total_steps):
      write_note('Evaluating on the validation set...')
      chrono.pause()
      for val_name, (val_iter, val_steps) in val_iter_splits.items():
        # Sets up evaluation metrics.
        ece_num_bins = config.get('ece_num_bins', 15)
        auc_num_bins = config.get('auc_num_bins', 1000)
        ece = rm.metrics.ExpectedCalibrationError(num_bins=ece_num_bins)
        calib_auc = rm.metrics.CalibrationAUC(correct_pred_as_pos_label=False)
        oc_auc_0_5 = rm.metrics.OracleCollaborativeAUC(
            oracle_fraction=0.005, num_bins=auc_num_bins)
        oc_auc_1 = rm.metrics.OracleCollaborativeAUC(
            oracle_fraction=0.01, num_bins=auc_num_bins)
        oc_auc_2 = rm.metrics.OracleCollaborativeAUC(
            oracle_fraction=0.02, num_bins=auc_num_bins)
        oc_auc_5 = rm.metrics.OracleCollaborativeAUC(
            oracle_fraction=0.05, num_bins=auc_num_bins)
        label_diversity = tf.keras.metrics.Mean()
        sample_diversity = tf.keras.metrics.Mean()
        ged = tf.keras.metrics.Mean()

        # Runs evaluation loop.
        ncorrect, loss, nseen = 0, 0, 0
        for _, batch in zip(range(val_steps), val_iter):
          if val_name == 'cifar_10h':
            batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                cifar_10h_evaluation_fn(opt_repl.target, batch['image'],
                                        batch['labels'], batch['mask']))
          else:
            batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                evaluation_fn(opt_repl.target, batch['image'],
                              batch['labels'], batch['mask']))
          # All results are a replicated array shaped as follows:
          # (local_devices, per_device_batch_size, elem_shape...)
          # with each local device's entry being identical as they got psum'd.
          # So let's just take the first one to the host as numpy.
          ncorrect += np.sum(np.array(batch_ncorrect[0]))
          loss += np.sum(np.array(batch_losses[0]))
          nseen += np.sum(np.array(batch_n[0]))
          if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
            # Here we parse batch_metric_args to compute uncertainty metrics.
            # (e.g., ECE or Calibration AUC).
            logits, labels, _, masks = batch_metric_args
            masks = np.array(masks[0], dtype=np.bool)
            logits = np.array(logits[0])
            probs = jax.nn.softmax(logits)
            # From one-hot to integer labels, as required by ECE.
            int_labels = np.argmax(np.array(labels[0]), axis=-1)
            int_preds = np.argmax(logits, axis=-1)
            confidence = np.max(probs, axis=-1)
            for p, c, l, d, m, label in zip(probs, confidence, int_labels,
                                            int_preds, masks, labels[0]):
              ece.add_batch(p[m, :], label=l[m])
              calib_auc.add_batch(d[m], label=l[m], confidence=c[m])
              # TODO(jereliu): Extend to support soft multi-class probabilities.
              oc_auc_0_5.add_batch(d[m], label=l[m], custom_binning_score=c[m])
              oc_auc_1.add_batch(d[m], label=l[m], custom_binning_score=c[m])
              oc_auc_2.add_batch(d[m], label=l[m], custom_binning_score=c[m])
              oc_auc_5.add_batch(d[m], label=l[m], custom_binning_score=c[m])

              if val_name == 'cifar_10h' or val_name == 'imagenet_real':
                batch_label_diversity, batch_sample_diversity, batch_ged = (
                    data_uncertainty_utils.generalized_energy_distance(
                        label[m], p[m, :], config.num_classes))
                label_diversity.update_state(batch_label_diversity)
                sample_diversity.update_state(batch_sample_diversity)
                ged.update_state(batch_ged)

        val_loss[val_name] = loss / nseen  # Keep for reproducibility tests.
        val_measurements = {
            f'{val_name}_prec@1': ncorrect / nseen,
            f'{val_name}_loss': val_loss[val_name],
        }
        if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
          val_measurements[f'{val_name}_ece'] = ece.result()['ece']
          val_measurements[f'{val_name}_calib_auc'] = calib_auc.result()[
              'calibration_auc']
          val_measurements[f'{val_name}_oc_auc_0.5%'] = oc_auc_0_5.result()[
              'collaborative_auc']
          val_measurements[f'{val_name}_oc_auc_1%'] = oc_auc_1.result()[
              'collaborative_auc']
          val_measurements[f'{val_name}_oc_auc_2%'] = oc_auc_2.result()[
              'collaborative_auc']
          val_measurements[f'{val_name}_oc_auc_5%'] = oc_auc_5.result()[
              'collaborative_auc']
        writer.write_scalars(step, val_measurements)

        if val_name == 'cifar_10h' or val_name == 'imagenet_real':
          cifar_10h_measurements = {
              f'{val_name}_label_diversity': label_diversity.result(),
              f'{val_name}_sample_diversity': sample_diversity.result(),
              f'{val_name}_ged': ged.result(),
          }
          writer.write_scalars(step, cifar_10h_measurements)

      # OOD eval
      # Entries in the ood_ds dict include:
      # (ind_dataset, ood_dataset1, ood_dataset2, ...).
      # OOD metrics are computed using ind_dataset paired with each of the
      # ood_dataset. When Mahalanobis distance method is applied, train_ind_ds
      # is also included in the ood_ds.
      if ood_ds and config.ood_methods:
        ood_measurements = ood_utils.eval_ood_metrics(ood_ds, ood_ds_names,
                                                      config.ood_methods,
                                                      evaluation_fn, opt_repl)
        writer.write_scalars(step, ood_measurements)
      chrono.resume()

    if 'fewshot' in config and fewshotter is not None:
      # Compute few-shot on-the-fly evaluation.
      if train_utils.itstime(step, config.fewshot.log_steps, total_steps):
        chrono.pause()
        write_note(f'Few-shot evaluation...\n{chrono.note}')
        # Keep `results` to return for reproducibility tests.
        fewshot_results, best_l2 = fewshotter.run_all(opt_repl.target,
                                                      config.fewshot.datasets)

        # TODO(dusenberrymw): Remove this once fewshot.py is updated.
        def make_writer_measure_fn(step):

          def writer_measure(name, value):
            writer.write_scalars(step, {name: value})

          return writer_measure

        fewshotter.walk_results(make_writer_measure_fn(step), fewshot_results,
                                best_l2)
        chrono.resume()

    # End of step.
    if config.get('testing_failure_step'):
      # Break early to simulate infra failures in test cases.
      if config.testing_failure_step == step:
        break

  write_note(f'Done!\n{chrono.note}')

  pool.close()
  pool.join()
  writer.close()

  # Return final training loss, validation loss, and fewshot results for
  # reproducibility test cases.
  return train_loss, val_loss, fewshot_results
def training_loop(
    *,
    module,
    rng,
    train_ds,
    eval_ds,
    loss_fn,
    optimizer,
    train_metrics_dict,
    eval_metrics_dict,
    stats_aggregators,
    config,
    workdir,
):
  """Runs a training and evaluation loop.

  Args:
    module: The module that should be trained.
    rng: A jax pseudo-random number generator key.
    train_ds: Dataset used for training.
    eval_ds: Dataset used for evaluation.
    loss_fn: Loss function to use for training.
    optimizer: Optax optimizer to use for training.
    train_metrics_dict: Collection of metrics to be collected during training.
    eval_metrics_dict: Collection of metrics to be collected during evaluation.
    stats_aggregators: Dictionary of statistics aggregator functions to be run
      on the first evaluation batch. These functions ingest the stats returned
      by the model and output a Dict[str, image/scalar] that will be logged.
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.

  Raises:
    RuntimeError: If a training metric is NaN or inf.

  Returns:
    Training state.
  """
  rng, model_rng = jax.random.split(rng)
  input_shape = tuple(train_ds.element_spec["image"].shape[1:])
  model, init_params, init_state = create_model(module, input_shape, model_rng)
  parameter_overview.log_parameter_overview(model.params)

  # Load pretrained model parameters and state. Ignore the step and the
  # optimizer state in the checkpoint.
  pretrained_path = config.get("pretrained_checkpoint", "")
  if pretrained_path:
    logging.info("Load pretrained weights from '%s'", pretrained_path)
    state_dict = checkpoint.load_state_dict(pretrained_path)
    flatten_model_params = utils.flatten_dict(state_dict["model_params"],
                                              sep="/")
    model_state = state_dict["model_state"]

    # A prefix can be used to replace only a subpart of the network (e.g. the
    # encoder). Prepend the prefix (if any) to model parameters and states.
    prefix = config.get("pretrained_prefix", "")
    if prefix:
      flatten_model_params = utils.add_prefix_to_dict_keys(
          flatten_model_params, f"{prefix}/")
      model_state = utils.add_prefix_to_dict_keys(model_state, f"/{prefix}")

    # Merge the params/state from the checkpoint into the initial params/state.
    flatten_init_params = utils.flatten_dict(init_params, sep="/")
    flatten_init_params, ignored_params = utils.override_dict(
        flatten_init_params, flatten_model_params)
    init_params = utils.unflatten_dict(flatten_init_params, delimiter="/")
    init_state, _ = utils.override_dict(init_state, model_state)

    if ignored_params:
      logging.warning(
          "%d/%d parameters from the pretrained checkpoint "
          "were ignored: %s", len(ignored_params), len(flatten_init_params),
          ignored_params)

  optimizer_state = optimizer.init(init_params)

  state = TrainState(
      step=1,
      model_params=init_params,
      model_state=init_state,
      optimizer_state=optimizer_state)  # type: ignore
  # Do not keep a copy of the initial model.
  del init_params, init_state, optimizer_state

  train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types

  checkpoint_dir = os.path.join(workdir, "checkpoints")
  ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
  state = ckpt.restore_or_initialize(state)
  initial_step = int(state.step)

  # Replicate our parameters.
  state = flax.jax_utils.replicate(state)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)
  step_timer = utils.StepTimer(
      batch_size=config.batch_size, initial_step=initial_step)

  # Write config to the summary files. This makes the hyperparameters available
  # in TensorBoard and makes comparison of runs with tensorboard/ easier.
  if initial_step == 1:
    writer.write_hparams(utils.flatten_dict(config.to_dict()))

  # Generate per-device PRNG keys for the training loop.
  rng, train_rng = jax.random.split(rng)
  train_rngs = jax.random.split(train_rng, jax.local_device_count())

  # Generate per-device PRNG keys for model evaluation.
  rng, eval_rng = jax.random.split(rng)
  eval_rngs = jax.random.split(eval_rng, jax.local_device_count())

  logging.info("Starting training loop at step %d.", initial_step)
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)

  train_metrics = utils.Means()

  do_eval_only = config.get("do_eval_only", False)
  if do_eval_only:
    config.num_train_steps = 1

  debug_enabled = config.get("debug", False)
  previous_grads = grads = None
  previous_updates = updates = None
  previous_state = None

  for step in range(initial_step, config.num_train_steps + 1):
    is_last_step = step == config.num_train_steps
    if debug_enabled:
      previous_grads = grads
      previous_updates = updates
      previous_state = state

    # Skip the training if only do the eval.
    if not do_eval_only:
      # Use ._numpy() to avoid copy.
      batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))  # pylint: disable=protected-access
      state, grads, updates, metrics, training_stats, train_rngs = train_step(
          state, batch, module, loss_fn, optimizer, train_metrics_dict,
          train_rngs)
      train_metrics.append(flax.jax_utils.unreplicate(metrics))

      # Update topk temperature with linearly decreasing schedule if enabled.
      if (config.get("linear_decrease_perturbed_sigma", False) and
          config.get("selection_method", "") == "perturbed-topk"):

        model_state = state.model_state.as_dict()
        if "/PatchNet_0" in model_state:
          net_str = "/PatchNet_0"
        else:
          net_str = "/"

        progress = step / config.num_train_steps
        sigma_multiplier = 1. - progress
        previous_mult = model_state[net_str]["sigma_mutiplier"]
        sigma_multiplier = sigma_multiplier + jnp.zeros_like(previous_mult)
        model_state[net_str]["sigma_mutiplier"] = sigma_multiplier
        state = state.replace(model_state=nn.Collection(model_state))

      if debug_enabled:
        if utils.has_any_inf_or_nan(metrics):
          # Save checkpoint
          if previous_state:
            ckpt.save(flax.jax_utils.unreplicate(previous_state))
          ckpt.save(flax.jax_utils.unreplicate(state))

          # Log gradients and updates.
          if previous_grads or previous_updates:
            write_gradient_histogram(
                writer, step, grads=previous_grads, updates=previous_updates)
          write_gradient_histogram(writer, step + 1, grads=grads,
                                   updates=updates)

          raise RuntimeError("A training metric took an invalid value: "
                             f"{metrics}.")

    logging.log_first_n(logging.INFO, "Finished training step %d.", 3, step)
    report_progress(step)

    if step % config.log_loss_every_steps == 0 or is_last_step:
      results = train_metrics.result()
      writer.write_scalars(step, results)
      writer.write_scalars(step, step_timer.get_and_reset(step))
      if utils.has_any_inf_or_nan(results):
        raise ValueError("A training metric took an invalid value.")
      train_metrics.reset()

    if (step % config.checkpoint_every_steps == 0 or is_last_step):
      with step_timer.paused():
        ckpt.save(flax.jax_utils.unreplicate(state))

    # Evaluation
    if step % config.eval_every_steps == 0 or is_last_step:
      with step_timer.paused():
        eval_metrics, first_batch_stats, eval_rngs = evaluate(
            state, module, eval_ds, eval_metrics_dict, eval_rngs)

      if jax.host_id() == 0:
        log_histograms = config.get("log_histograms", False)
        log_images = config.get("log_images", True)
        # Log the last gradients and updates histograms.
        if not do_eval_only:
          write_stats_results(writer, step, training_stats, stats_aggregators,
                              prefix="train/", log_images=log_images)
          if log_histograms:
            write_gradient_histogram(writer, step, grads=grads,
                                     updates=updates)

        write_stats_results(writer, step, first_batch_stats,
                            stats_aggregators, prefix="eval/",
                            log_images=log_images)

        # write patch representation histograms
        if (log_histograms and first_batch_stats and
            "patch_representations" in first_batch_stats):
          patch_representations = first_batch_stats["patch_representations"]
          writer.write_histograms(
              step, {"patch_representations": patch_representations})

        if eval_metrics:
          writer.write_scalars(step, eval_metrics)

        writer.flush()

  return state
def train(config, model_def, device_batch_size, eval_ds, num_steps,
          steps_per_epoch, steps_per_eval, train_ds, image_size, data_source,
          workdir):
  """Train model."""
  make_lr_fn = schedulers.get_make_lr_fn(config)
  make_temp_fn = schedulers.get_make_temp_fn(config)
  make_step_size_fn = schedulers.get_make_step_size_fn(config)
  if jax.host_count() > 1:
    raise ValueError('CIFAR10 example should not be run on '
                     'more than 1 host due to preconditioner updating.')

  initial_step = 0  # TODO(basv): load from checkpoint.

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)

  # Write config to the summary files. This makes the hyperparameters available
  # in TensorBoard and makes comparison of runs in TensorBoard easier.
  # with writer.summary_writer.as_default():
  writer.write_hparams(dict(config))

  rng = random.PRNGKey(config.seed)
  rng, opt_rng, init_key, sampler_rng = jax.random.split(rng, 4)

  base_learning_rate = config.learning_rate

  # Create the model.
  model, state = create_model(rng, device_batch_size, image_size, model_def)
  parameter_overview.log_parameter_overview(model.params)

  state = jax_utils.replicate(state)

  train_size = data_source.TRAIN_IMAGES

  with flax.deprecated.nn.stochastic(init_key):
    optimizer = create_optimizer(config, model, base_learning_rate, train_size,
                                 sampler_rng)
  del model  # Don't keep a copy of the initial model.

  # Learning rate schedule
  learning_rate_fn = make_lr_fn(base_learning_rate, steps_per_epoch)
  temperature_fn = make_temp_fn(config.base_temp, steps_per_epoch)
  step_size_fn = make_step_size_fn(steps_per_epoch)

  p_eval_step, _, p_train_step, p_update_grad_vars = make_step_functions(
      config, config.l2_reg, learning_rate_fn, train_size, temperature_fn,
      step_size_fn)

  # Create dataset batch iterators.
  train_iter = iter(train_ds)
  eval_iter = iter(eval_ds)

  # Gather metrics.
  train_metrics = []
  epoch = 0

  # Ensemble.
  ensemble = []
  ensemble_logits = []
  ensemble_labels = []
  ensemble_probs = []

  def ensemble_add_step(step):
    if config.lr_schedule == 'cosine':
      # Add if learning rate jumps up again in the next step.
      increase = step_size_fn(step) < step_size_fn(step + 1) - 1e-8
      _, temp_end = ast.literal_eval(config.temp_ramp)
      past_burn_in = step >= steps_per_epoch * temp_end
      return increase and past_burn_in
    elif config.lr_schedule == 'constant':
      if (step + 1) % steps_per_epoch == 0:
        return True
    return False

  logging.info('Starting training loop at step %d.', initial_step)
  for step in range(initial_step, num_steps):
    if config.optimizer in ['sym_euler'] and step % steps_per_epoch == 0:
      optimizer, rng = update_preconditioner(config, optimizer,
                                             p_update_grad_vars, rng, state,
                                             train_iter)
    # Generate a PRNG key that will be rolled into the batch
    step_key = jax.random.fold_in(rng, step)
    opt_step_rng = jax.random.fold_in(opt_rng, step)

    # Load and shard the TF batch
    batch = next(train_iter)
    batch = input_pipeline.load_and_shard_tf_batch(config, batch)

    if not config.debug_run:
      # Shard the step PRNG key
      # Don't shard the optimizer rng, as it should be equal among all machines.
      sharded_keys = common_utils.shard_prng_key(step_key)
    else:
      sharded_keys = step_key

    # Train step
    optimizer, state, metrics = p_train_step(optimizer, state, batch,
                                             sharded_keys, opt_step_rng)
    train_metrics.append(metrics)

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)

    if step == initial_step:
      initial_train_metrics = get_metrics(config, train_metrics)
      train_summary = jax.tree_map(lambda x: x.mean(), initial_train_metrics)
      train_summary = {'train_' + k: v for k, v in train_summary.items()}
      logging.log(logging.INFO, 'initial metrics = %s',
                  str(train_summary.items()))

    if (step + 1) % steps_per_epoch == 0:  # We've finished an epoch
      # Save model params/state.
      train_metrics = get_metrics(config, train_metrics)
      # Get training epoch summary for logging
      train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
      train_summary = {'train_' + k: v for k, v in train_summary.items()}
      writer.write_scalars(epoch, train_summary)

      # Reset train metrics
      train_metrics = []

      # Evaluation
      if config.do_eval:
        eval_metrics = []
        eval_logits = []
        eval_labels = []
        for _ in range(steps_per_eval):
          eval_batch = next(eval_iter)
          # Load and shard the TF batch
          eval_batch = input_pipeline.load_and_shard_tf_batch(
              config, eval_batch)
          # Step
          logits, labels, metrics = p_eval_step(optimizer.target, state,
                                                eval_batch)
          eval_metrics.append(metrics)
          eval_logits.append(logits)
          eval_labels.append(labels)
        eval_metrics = get_metrics(config, eval_metrics)
        # Get eval epoch summary for logging
        eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
        eval_summary = {'eval_' + k: v for k, v in eval_summary.items()}
        writer.write_scalars(epoch, eval_summary)

      if config.algorithm == 'sgmcmc' and ensemble_add_step(step):
        ensemble.append((serialization.to_state_dict(optimizer.target), state))

      if (config.algorithm == 'sgmcmc' and ensemble_add_step(step) and
          len(ensemble) >= 1):
        # Gather predictions for this ensemble sample.
        eval_logits = jnp.concatenate(eval_logits, axis=0)
        eval_probs = jax.nn.softmax(eval_logits, axis=-1)
        eval_labels = jnp.concatenate(eval_labels, axis=0)

        # Ensure that labels are consistent between predict runs.
        if ensemble_labels:
          assert jnp.allclose(
              eval_labels,
              ensemble_labels[0]), 'Labels unordered between eval runs.'

        ensemble_logits.append(eval_logits)
        ensemble_probs.append(eval_probs)
        ensemble_labels.append(eval_labels)

        # Compute ensemble predictions over last config.ensemble_size samples.
        ensemble_last_probs = jnp.mean(
            jnp.array(ensemble_probs[-config.ensemble_size:]), axis=0)
        ensemble_metrics = train_functions.compute_metrics_probs(
            ensemble_last_probs, ensemble_labels[0])
        ensemble_summary = jax.tree_map(lambda x: x.mean(), ensemble_metrics)
        ensemble_summary = {'ens_' + k: v for k, v in ensemble_summary.items()}
        ensemble_summary['ensemble_size'] = min(config.ensemble_size,
                                                len(ensemble_probs))
        writer.write_scalars(epoch, ensemble_summary)

      epoch += 1

  return ensemble, optimizer
def train_and_evaluate(config, workdir, strategy):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
    strategy: Distribution strategy to use for distributing the model.
  """
  tf.io.gfile.makedirs(workdir)

  tf_rng, data_rng = tf.random.experimental.stateless_split((config.seed, 0), 2)
  tf.random.set_seed(tf_rng.numpy()[0])

  # Input pipeline.
  ds_info, train_ds, val_ds, test_ds = input_pipeline.create_datasets(
      config, data_rng, strategy=strategy)
  train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types

  # Learning rate schedule.
  num_train_steps = config.num_train_steps
  if num_train_steps == -1:
    num_train_steps = (ds_info.splits["train"].num_examples //
                       config.global_batch_size * config.num_epochs)
  steps_per_epoch = num_train_steps // config.num_epochs
  logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
               steps_per_epoch)

  # We treat the learning rate in the config as the learning rate for batch
  # size 256 but scale it according to our batch size.
  base_learning_rate = config.learning_rate * config.global_batch_size / 256.0

  # Initialize model.
  num_classes = ds_info.features["label"].num_classes

  if config.distill_teacher:
    do_distill = True
    teacher_file_list = (config.distill_teacher).split(",")
    teacher_models = load_teacher_models(teacher_file_list, num_classes,
                                         config, strategy)
    distill_params = {}
    distill_params["alpha"] = config.distill_alpha
    distill_params["beta"] = config.distill_fd_beta
    distill_params["teacher_model"] = TeacherModel(teacher_models,
                                                   name="teacher")
  else:
    do_distill = False
    distill_params = None

  state = create_state(config, num_classes=num_classes, strategy=strategy)

  ckpt_manager = tf.train.CheckpointManager(
      checkpoint=state, directory=workdir, max_to_keep=5)
  if ckpt_manager.latest_checkpoint:
    state.restore(ckpt_manager.latest_checkpoint)
    logging.info("Restored from %s", ckpt_manager.latest_checkpoint)
  else:
    logging.info("Initializing from scratch.")
  initial_step = state.global_step.numpy().item()

  learning_rate_fn = functools.partial(
      get_learning_rate,
      base_learning_rate=base_learning_rate,
      steps_per_epoch=steps_per_epoch,
      num_epochs=config.num_epochs,
      warmup_epochs=config.warmup_epochs)

  writer = metric_writers.create_default_writer(workdir)
  writer.write_hparams(dict(config))

  logging.info("Starting training loop at step %d.", initial_step)
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=num_train_steps, writer=writer)

  with metric_writers.ensure_flushes(writer):
    for step in range(initial_step, num_train_steps + 1):
      state.model.trainable = True

      # `step` is a Python integer. `global_step` is a TF variable on the
      # GPU/TPU devices.
      is_last_step = step == num_train_steps

      train_step(state, train_iter, config.weight_decay, learning_rate_fn,
                 do_distill, distill_params, strategy)
      state.train_metrics.update_state_lr(
          learning_rate_fn(state.global_step.numpy().item()))

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      report_progress(step)

      if step == initial_step:
        parameter_overview.log_parameter_overview(state.model)

      if step % config.log_loss_every_steps == 0 or is_last_step:
        writer.write_scalars(step, state.train_metrics.result())
        state.train_metrics.reset_states()
        state.train_metrics.reset_lr()

      if step % config.eval_every_steps == 0 or is_last_step:
        state.model.trainable = False

        if config.dataset == "imagenet-lt":
          evaluate(state, val_ds, state.val_metrics, strategy)
          writer.write_scalars(step, state.val_metrics.result())
          logging.info("Num val images %d",
                       state.val_metrics.accuracy.count.numpy())

        evaluate(state, test_ds, state.test_metrics, strategy)
        writer.write_scalars(step, state.test_metrics.result())
        logging.info("Num test images %d",
                     state.test_metrics.accuracy.count.numpy())

      if step % config.checkpoint_every_steps == 0 or is_last_step:
        checkpoint_path = ckpt_manager.save(step)
        logging.info("Saved checkpoint %s", checkpoint_path)

  logging.info("Finishing training at step %d", step)
  logging.info("Saving the final weights")
  file_path = "%s/final_weights" % workdir
  state.model.save_weights(file_path, save_format="tf")
def main(argv):
  del argv  # unused arg

  config = FLAGS.config

  # Unpack total and warmup steps
  # TODO(nband): revert this to separate arguments.
  total_steps = config.total_and_warmup_steps[0]
  warmup_steps = config.total_and_warmup_steps[1]
  del config.total_and_warmup_steps
  config.total_steps = total_steps
  config.lr.warmup_steps = warmup_steps

  # Wandb and Checkpointing Setup
  output_dir = FLAGS.output_dir
  wandb_run, output_dir = vit_utils.maybe_setup_wandb(config)
  tf.io.gfile.makedirs(output_dir)
  logging.info('Saving checkpoints at %s', output_dir)

  # Dataset Split Flags
  dist_shift = config.distribution_shift
  print(f'Distribution Shift: {dist_shift}.')
  dataset_names, split_names = vit_utils.get_dataset_and_split_names(dist_shift)

  # LR / Optimization Flags
  batch_size = config.batch_size
  grad_clip_norm = config.grad_clip_norm
  weight_decay = config.weight_decay
  print('Standard wandb hyperparameters:')
  print({
      'batch_size': batch_size,
      'grad_clip_norm': grad_clip_norm,
      'weight_decay': weight_decay,
      'total_steps': config.total_steps,
      'lr': config.lr
  })
  print('SNGP Params:', config.gp_layer)

  # Reweighting loss for class imbalance
  # class_reweight_mode = config.class_reweight_mode
  # if class_reweight_mode == 'constant':
  #   class_weights = utils.get_diabetic_retinopathy_class_balance_weights()
  # else:
  #   class_weights = None

  # Shows the number of available devices.
  # In a CPU/GPU runtime this will be a single device.
  # In a TPU runtime this will be 8 cores.
  print('Number of Jax local devices:', jax.local_devices())

  # TODO(nband): fix sigmoid loss issues.
  assert config.get('loss', None) == 'softmax_xent'

  seed = config.seed
  rng = jax.random.PRNGKey(seed)
  tf.random.set_seed(seed)

  if config.get('data_dir'):
    logging.info('data_dir=%s', config.data_dir)
  logging.info('Output dir: %s', output_dir)

  save_checkpoint_path = None
  if config.get('checkpoint_steps'):
    tf.io.gfile.makedirs(output_dir)
    save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

  # Create an asynchronous multi-metric writer.
  writer = metric_writers.create_default_writer(
      output_dir, just_logging=jax.process_index() > 0)

  # The pool is used to perform misc operations such as logging in an async way.
  pool = multiprocessing.pool.ThreadPool()

  def write_note(note):
    if jax.process_index() == 0:
      logging.info('NOTE: %s', note)

  write_note('Initializing...')

  # Verify settings to make sure no checkpoints are accidentally missed.
  if config.get('keep_checkpoint_steps'):
    assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
    assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
        f'`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should be '
        f'divisible by `checkpoint_steps` ({config.checkpoint_steps}).')

  batch_size_eval = config.get('batch_size_eval', batch_size)
  if (batch_size % jax.device_count() != 0 or
      batch_size_eval % jax.device_count() != 0):
    raise ValueError(f'Batch sizes ({batch_size} and {batch_size_eval}) must '
                     f'be divisible by device number ({jax.device_count()})')

  local_batch_size = batch_size // jax.process_count()
  local_batch_size_eval = batch_size_eval // jax.process_count()
  logging.info(
      'Global batch size %d on %d hosts results in %d local batch size. '
      'With %d dev per host (%d dev total), that is a %d per-device batch size.',
      batch_size, jax.process_count(), local_batch_size,
      jax.local_device_count(), jax.device_count(),
      local_batch_size // jax.local_device_count())

  write_note('Initializing preprocessing function...')
  # Same preprocessing function for training and evaluation
  preproc_fn = preprocess_spec.parse(
      spec=config.pp_train, available_ops=preprocess_utils.all_ops())

  write_note('Initializing train dataset...')
  rng, train_ds_rng = jax.random.split(rng)
  train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index())
  train_base_dataset = ub.datasets.get(
      dataset_names['in_domain_dataset'],
      split=split_names['train_split'],
      data_dir=config.get('data_dir'))
  train_dataset_builder = train_base_dataset._dataset_builder  # pylint: disable=protected-access
  train_ds = input_utils.get_data(
      dataset=train_dataset_builder,
      split=split_names['train_split'],
      rng=train_ds_rng,
      process_batch_size=local_batch_size,
      preprocess_fn=preproc_fn,
      shuffle_buffer_size=config.shuffle_buffer_size,
      prefetch_size=config.get('prefetch_to_host', 2),
      data_dir=config.get('data_dir'))
  logging.info('image_size = %s', train_ds.element_spec['image'].shape[1:])

  # Start prefetching already.
  train_iter = input_utils.start_input_pipeline(
      train_ds, config.get('prefetch_to_device', 1))

  write_note('Initializing val dataset(s)...')

  # Load in-domain and OOD validation and/or test datasets.
  # Please specify the desired shift (Country Shift or Severity Shift)
  # in the config.
  eval_iter_splits = vit_utils.init_evaluation_datasets(
      use_validation=config.use_validation,
      use_test=config.use_test,
      dataset_names=dataset_names,
      split_names=split_names,
      config=config,
      preproc_fn=preproc_fn,
      batch_size_eval=batch_size_eval,
      local_batch_size_eval=local_batch_size_eval)

  ntrain_img = input_utils.get_num_examples(
      train_dataset_builder,
      split=split_names['train_split'],
      process_batch_size=local_batch_size,
      data_dir=config.get('data_dir'))
  steps_per_epoch = ntrain_img / batch_size

  if config.get('num_epochs'):
    total_steps = int(config.num_epochs * steps_per_epoch)
    assert not config.get('total_steps'), 'Set either num_epochs or total_steps'
  else:
    total_steps = config.total_steps

  logging.info('Total train data points: %d', ntrain_img)
  logging.info(
      'Running for %d steps, that means %f epochs and %d steps per epoch',
      total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)

  write_note('Initializing model...')
  logging.info('config.model = %s', config.get('model'))

  # Specify Gaussian process layer configs.
  gp_config = config.get('gp_layer', {})
  model_dict = vit_utils.initialize_model('sngp', config)
  model, use_gp_layer = model_dict['model'], model_dict['use_gp_layer']

  # We want all parameters to be created in host RAM, not on any device, they'll
  # be sent there later as needed, otherwise we already encountered two
  # situations where we allocate them twice.
  @functools.partial(jax.jit, backend='cpu')
  def init(rng):
    image_size = tuple(train_ds.element_spec['image'].shape[2:])
    logging.info('image_size = %s', image_size)
    dummy_input = jnp.zeros((local_batch_size,) + image_size, jnp.float32)
    variables = model.init(rng, dummy_input, train=False)
    # Split model parameters into trainable and untrainable collections.
    states, params = variables.pop('params')
    del variables

    # Set bias in the head to a low value, such that loss is small initially.
    params = flax.core.unfreeze(params)
    if use_gp_layer:
      # Modify the head parameter in the GP head.
params['head']['output_layer']['bias'] = jnp.full_like( params['head']['output_layer']['bias'], config.get('init_head_bias', 0)) else: params['head']['bias'] = jnp.full_like( params['head']['bias'], config.get('init_head_bias', 0)) return params, states rng, rng_init = jax.random.split(rng) params_cpu, states_cpu = init(rng_init) if jax.process_index() == 0: num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0]) parameter_overview.log_parameter_overview(params_cpu) writer.write_scalars(step=0, scalars={'num_params': num_params}) @functools.partial(jax.pmap, axis_name='batch') def evaluation_fn(params, states, images, labels): variable_dict = {'params': flax.core.freeze(params), **states} logits, out = model.apply( variable_dict, images, train=False, mean_field_factor=gp_config.get('mean_field_factor', -1.)) losses = getattr(train_utils, config.get('loss', 'softmax_xent'))( logits=logits, labels=labels, reduction=False) loss = jax.lax.psum(losses, axis_name='batch') top1_idx = jnp.argmax(logits, axis=1) # Extracts the label at the highest logit index for each image. top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0] ncorrect = jax.lax.psum(top1_correct, axis_name='batch') n = batch_size_eval metric_args = jax.lax.all_gather([ logits, labels, out['pre_logits']], axis_name='batch') return ncorrect, loss, n, metric_args # Load the optimizer from flax. opt_name = config.get('optim_name') write_note(f'Initializing {opt_name} optimizer...') opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {})) # We jit this, such that the arrays that are created are created on the same # device as the input is, in this case the CPU. Else they'd be on device[0]. opt_cpu = jax.jit(opt_def.create)(params_cpu) weight_decay_rules = config.get('weight_decay', []) or [] rescale_value = config.lr.base if config.get('weight_decay_decouple') else 1. weight_decay_fn = train_utils.get_weight_decay_fn( weight_decay_rules=weight_decay_rules, rescale_value=rescale_value) @functools.partial(jax.pmap, axis_name='batch', donate_argnums=(0,)) def update_fn(opt, states, lr, reset_covmat, images, labels, rng): """Update step.""" measurements = {} # Get device-specific loss rng. rng, rng_model = jax.random.split(rng, 2) rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index('batch')) def loss_fn(params, states, images, labels): # Specify mutable collection to update untrainable GP parameters. variable_dict = {'params': flax.core.freeze(params), **states} model_results, updated_states = model.apply( variable_dict, images, train=True, rngs={'dropout': rng_model_local}, mutable=list(states.keys()), mean_field_factor=gp_config.get('mean_field_factor', -1.)) logits, _ = model_results loss = getattr(train_utils, config.get('loss', 'sigmoid_xent'))( logits=logits, labels=labels) return loss, updated_states # Performs exact covariance update (i.e., reset precision matrix resetting # at begining of new epoch) if covmat_momentum is a null value. if use_gp_layer and gp_config.get('covmat_momentum', -1.) < 0: # Resets precision matrix to Identity * ridge_penalty if at the begining # of a new epoch. This should be done before accumulate gradient. ridge_penalty = gp_config.get('ridge_penalty', 1.) prec_mat_old = states['laplace_covariance']['head']['covmat_layer'][ 'precision_matrix'] prec_mat_new = ( (1. 
- reset_covmat) * prec_mat_old + reset_covmat * jnp.eye(prec_mat_old.shape[0]) * ridge_penalty) states = flax.core.unfreeze(states) states['laplace_covariance']['head']['covmat_layer'][ 'precision_matrix'] = prec_mat_new states = flax.core.freeze(states) # Implementation considerations compared and summarized at # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en# (l, s), g = vit_utils.accumulate_gradient_with_states( jax.value_and_grad(loss_fn, has_aux=True), opt.target, states, images, labels, config.get('grad_accum_steps')) l, g = jax.lax.pmean((l, g), axis_name='batch') # Log the gradient norm only if we need to compute it anyways (clipping) # or if we don't use grad_accum_steps, as they interact badly. if config.get('grad_accum_steps', 1) == 1 or grad_clip_norm is not None: grads, _ = jax.tree_flatten(g) l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads])) measurements['l2_grads'] = l2_g # Optionally resize the global gradient to a maximum norm. We found this # useful in some cases across optimizers, hence it's in the main loop. if grad_clip_norm is not None: g_factor = jnp.minimum(1.0, grad_clip_norm / l2_g) g = jax.tree_map(lambda p: g_factor * p, g) opt = opt.apply_gradient(g, learning_rate=lr) opt = opt.replace(target=weight_decay_fn(opt.target, lr)) params, _ = jax.tree_flatten(opt.target) measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params])) measurements['reset_covmat'] = reset_covmat return opt, s, l, rng, measurements # Set config checkpoint resume path, if provided in args. if config.resume_checkpoint_path is not None: config.resume = config.resume_checkpoint_path default_reinit_params = ('head/output_layer/kernel', 'head/output_layer/bias', 'head/kernel', 'head/bias') rng, train_loop_rngs = jax.random.split(rng) checkpoint_data = checkpoint_utils.maybe_load_checkpoint( train_loop_rngs=train_loop_rngs, save_checkpoint_path=save_checkpoint_path, init_optimizer=opt_cpu, init_params=params_cpu, init_fixed_model_states=states_cpu, default_reinit_params=default_reinit_params, config=config) train_loop_rngs = checkpoint_data.train_loop_rngs opt_cpu = checkpoint_data.optimizer states_cpu = checkpoint_data.fixed_model_states accumulated_train_time = checkpoint_data.accumulated_train_time write_note('Adapting the checkpoint model...') adapted_params = checkpoint_utils.adapt_upstream_architecture( init_params=params_cpu, loaded_params=opt_cpu.target) opt_cpu = opt_cpu.replace(target=adapted_params) write_note('Kicking off misc stuff...') first_step = int(opt_cpu.state.step) # Might be a DeviceArray type. if first_step == 0 and jax.process_index() == 0: writer.write_hparams(dict(config)) chrono = train_utils.Chrono(first_step, total_steps, batch_size, accumulated_train_time) # Note: switch to ProfileAllHosts() if you need to profile all hosts. # (Xprof data become much larger and take longer to load for analysis) profiler = periodic_actions.Profile( # Create profile after every restart to analyze pre-emption related # problems and assure we get similar performance in every run. logdir=output_dir, first_profile=first_step + 10) # Prepare the learning-rate and pre-fetch it to device to avoid delays. lr_fn = train_utils.create_learning_rate_schedule(total_steps, **config.get('lr', {})) # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be # necessary for TPUs. 
lr_iter = train_utils.prefetch_scalar( map(lr_fn, range(total_steps)), config.get('prefetch_to_device', 1)) # Prepare the precision matrix resetting schedule, and pre-fetch it to device. reset_covmat_fn = lambda step: float(step % steps_per_epoch == 0) reset_covmat_iter = train_utils.prefetch_scalar( map(reset_covmat_fn, range(first_step, total_steps)), nprefetch=config.get('prefetch_to_device', 1)) write_note(f'Replicating...\n{chrono.note}') opt_repl = flax.jax_utils.replicate(opt_cpu) states_repl = flax.jax_utils.replicate(states_cpu) checkpoint_writer = None # Note: we return the train loss, val loss, and fewshot best l2s for use in # reproducibility unit tests. # train_loss = -jnp.inf # val_loss = -jnp.inf # results = {'dummy': {(0, 1): -jnp.inf}} write_note(f'First step compilations...\n{chrono.note}') logging.info('first_step = %s', first_step) # Advance the iterators if we are restarting from an earlier checkpoint. # TODO(dusenberrymw): Look into checkpointing dataset state instead. # Makes sure log_eval_steps is same as steps_per_epoch. This is because # the precision matrix needs to be updated fully (at the end of each epoch) # when eval takes place. log_eval_steps = steps_per_epoch if first_step > 0: write_note('Advancing iterators after resuming from a checkpoint...') lr_iter = itertools.islice(lr_iter, first_step, None) train_iter = itertools.islice(train_iter, first_step, None) # Using a python integer for step here, because opt.state.step is allocated # on TPU during replication. for step, train_batch, lr_repl, reset_covmat_repl in zip( range(first_step + 1, total_steps + 1), train_iter, lr_iter, reset_covmat_iter): with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1): # TODO(jereliu): Expand to allow precision matrix resetting. (opt_repl, states_repl, loss_value, train_loop_rngs, extra_measurements) = update_fn( opt_repl, states_repl, lr_repl, reset_covmat_repl, train_batch['image'], train_batch['labels'], rng=train_loop_rngs) if jax.process_index() == 0: profiler(step) # Checkpoint saving if train_utils.itstime( step, config.get('checkpoint_steps'), total_steps, process=0): write_note('Checkpointing...') chrono.pause() train_utils.checkpointing_timeout(checkpoint_writer, config.get('checkpoint_timeout', 1)) accumulated_train_time = chrono.accum_train_time # We need to transfer the weights over now or else we risk keeping them # alive while they'll be updated in a future step, creating hard to debug # memory errors (see b/160593526). Also, takes device 0's params only. # For GP layer, we will also do the same for untrainable parameters # (`states`). This is ok since `random features` are frozen throughout # pre-training, and `precision matrix` is a finetuning-specific parameters # that will be re-learned in the finetuning task. opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl) states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl) # Check whether we want to keep a copy of the current checkpoint. copy_step = None if train_utils.itstime(step, config.get('keep_checkpoint_steps'), total_steps): write_note('Keeping a checkpoint copy...') copy_step = step # Checkpoint should be a nested dictionary or FLAX datataclasses from # `flax.struct`. Both can be present in a checkpoint. 
checkpoint_data = checkpoint_utils.CheckpointData( optimizer=opt_cpu, fixed_model_states=states_cpu, train_loop_rngs=train_loop_rngs, accumulated_train_time=accumulated_train_time) checkpoint_writer = pool.apply_async( checkpoint_utils.checkpoint_trained_model, (checkpoint_data, save_checkpoint_path, copy_step)) chrono.resume() # Report training progress if train_utils.itstime( step, config.log_training_steps, total_steps, process=0): write_note('Reporting training progress...') train_loss = loss_value[0] # Keep to return for reproducibility tests. timing_measurements, note = chrono.tick(step) write_note(note) train_measurements = {} train_measurements.update({ 'learning_rate': lr_repl[0], 'training_loss': train_loss, }) train_measurements.update(flax.jax_utils.unreplicate(extra_measurements)) train_measurements.update(timing_measurements) writer.write_scalars(step, train_measurements) # Report validation performance if train_utils.itstime(step, log_eval_steps, total_steps): write_note('Evaluating on the validation set...') chrono.pause() all_eval_results = {} for eval_name, (eval_iter, eval_steps) in eval_iter_splits.items(): start_time = time.time() # Runs evaluation loop. results_arrs = { 'y_true': [], 'y_pred': [], 'y_pred_entropy': [] } for _, batch in zip(range(eval_steps), eval_iter): batch_ncorrect, batch_losses, batch_n, batch_metric_args = ( # pylint: disable=unused-variable evaluation_fn( opt_repl.target, states_repl, batch['image'], batch['labels'])) # All results are a replicated array shaped as follows: # (local_devices, per_device_batch_size, elem_shape...) # with each local device's entry being identical as they got psum'd. # So let's just take the first one to the host as numpy. # Here we parse batch_metric_args to compute uncertainty metrics. logits, labels, _ = batch_metric_args logits = np.array(logits[0]) probs = jax.nn.softmax(logits) # From one-hot to integer labels. int_labels = np.argmax(np.array(labels[0]), axis=-1) probs = np.reshape(probs, (probs.shape[0] * probs.shape[1], -1)) int_labels = int_labels.flatten() y_pred = probs[:, 1] results_arrs['y_true'].append(int_labels) results_arrs['y_pred'].append(y_pred) # Entropy is computed at the per-epoch level (see below). results_arrs['y_pred_entropy'].append(probs) results_arrs['y_true'] = np.concatenate(results_arrs['y_true'], axis=0) results_arrs['y_pred'] = np.concatenate( results_arrs['y_pred'], axis=0).astype('float64') results_arrs['y_pred_entropy'] = vit_utils.entropy( np.concatenate(results_arrs['y_pred_entropy'], axis=0), axis=-1) time_elapsed = time.time() - start_time results_arrs['total_ms_elapsed'] = time_elapsed * 1e3 results_arrs['dataset_size'] = eval_steps * batch_size_eval all_eval_results[eval_name] = results_arrs per_pred_results, metrics_results = vit_utils.evaluate_vit_predictions( # pylint: disable=unused-variable dataset_split_to_containers=all_eval_results, is_deterministic=True, num_bins=15, return_per_pred_results=True ) # `metrics_results` is a dict of {str: jnp.ndarray} dicts, one for each # dataset. Flatten this dict so we can pass to the writer and remove empty # entries. 
flattened_metric_results = {} for dic in metrics_results.values(): for key, value in dic.items(): if value is not None: flattened_metric_results[key] = value writer.write_scalars(step, flattened_metric_results) # Optionally log to wandb if config.use_wandb: wandb.log(metrics_results, step=step) # Save per-prediction metrics results_storage_utils.save_per_prediction_results( output_dir, step, per_pred_results, verbose=False) chrono.resume() # End of step. if config.get('testing_failure_step'): # Break early to simulate infra failures in test cases. if config.testing_failure_step == step: break write_note(f'Done!\n{chrono.note}') pool.close() pool.join() writer.close() if wandb_run is not None: wandb_run.finish()
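# A minimal, self-contained sketch of the per-example predictive entropy that the
# evaluation loop above accumulates (it stores softmax probabilities and reduces
# them with `vit_utils.entropy(probs, axis=-1)`). The helper below is hypothetical
# and only illustrates the standard H(p) = -sum_k p_k log p_k computation; the
# actual `vit_utils` implementation may differ in details such as clipping or the
# base of the logarithm.
import numpy as np


def predictive_entropy(probs: np.ndarray, axis: int = -1, eps: float = 1e-12) -> np.ndarray:
  """Shannon entropy of categorical distributions along `axis`."""
  probs = np.clip(probs, eps, 1.0)
  return -np.sum(probs * np.log(probs), axis=axis)


# Example: one confident and one maximally uncertain binary prediction.
example_probs = np.array([[0.99, 0.01], [0.5, 0.5]])
print(predictive_entropy(example_probs))  # approx. [0.056, 0.693]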
def train_and_evaluate(config, workdir): """Execute model training and evaluation loop. Args: config: Hyperparameter configuration for training and evaluation. workdir: Directory where the tensorboard summaries are written to. Returns: The train state (which includes the `.params`). """ # Seed for reproducibility. rng = jax.random.PRNGKey(config.rng_seed) # Set up logging. summary_writer = metric_writers.create_default_writer(workdir) summary_writer.write_hparams(dict(config)) # Get datasets. rng, dataset_rng = jax.random.split(rng) dataset = input_pipeline.get_dataset(config, dataset_rng) graph, labels, masks = jax.tree_map(jnp.asarray, dataset) labels = jax.nn.one_hot(labels, config.num_classes) train_mask = masks['train'] train_indices = jnp.where(train_mask)[0] train_labels = labels[train_indices] num_training_nodes = len(train_indices) # Get subgraphs. if config.differentially_private_training: graph = jax.tree_map(np.asarray, graph) subgraphs = get_subgraphs(graph, pad_to=config.pad_subgraphs_to) graph = jax.tree_map(jnp.asarray, graph) # We only need the subgraphs for training nodes. train_subgraphs = subgraphs[train_indices] del subgraphs else: train_subgraphs = None # Initialize privacy accountant. training_privacy_accountant = privacy_accountants.get_training_privacy_accountant( config, num_training_nodes, compute_max_terms_per_node(config)) # Construct and initialize model. rng, init_rng = jax.random.split(rng) estimation_indices = get_estimation_indices(train_indices, config) state = create_train_state(init_rng, config, graph, train_labels, train_subgraphs, estimation_indices) # Set up checkpointing of the model. checkpoint_dir = os.path.join(workdir, 'checkpoints') ckpt = checkpoint.Checkpoint(checkpoint_dir, max_to_keep=2) state = ckpt.restore_or_initialize(state) initial_step = int(state.step) + 1 # Log overview of parameters. parameter_overview.log_parameter_overview(state.params) # Log metrics after initialization. logits = compute_logits(state, graph) metrics_after_init = compute_metrics(logits, labels, masks) metrics_after_init['epsilon'] = 0 log_metrics(0, metrics_after_init, summary_writer, postfix='_after_init') # Train model. rng, train_rng = jax.random.split(rng) max_training_epsilon = get_max_training_epsilon(config) # Hooks called periodically during training. report_progress = periodic_actions.ReportProgress( num_train_steps=config.num_training_steps, writer=summary_writer) profiler = periodic_actions.Profile(num_profile_steps=5, logdir=workdir) hooks = [report_progress, profiler] for step in range(initial_step, config.num_training_steps): # Perform one step of training. with jax.profiler.StepTraceAnnotation('train', step_num=step): # Sample batch. step_rng = jax.random.fold_in(train_rng, step) indices = jax.random.choice(step_rng, num_training_nodes, (config.batch_size, )) # Compute gradients. if config.differentially_private_training: grads = compute_updates_for_dp(state, graph, train_labels, train_subgraphs, indices, config.adjacency_normalization) else: grads = compute_updates(state, graph, train_labels, indices) # Update parameters. state = update_model(state, grads) # Quick indication that training is happening. logging.log_first_n(logging.INFO, 'Finished training step %d.', 10, step) for hook in hooks: hook(step) # Evaluate, if required. is_last_step = (step == config.num_training_steps - 1) if step % config.evaluate_every_steps == 0 or is_last_step: with report_progress.timed('eval'): # Check if privacy budget exhausted. 
training_epsilon = training_privacy_accountant(step + 1) if max_training_epsilon is not None and training_epsilon >= max_training_epsilon: break # Compute metrics. logits = compute_logits(state, graph) metrics_during_training = compute_metrics( logits, labels, masks) metrics_during_training['epsilon'] = training_epsilon log_metrics(step, metrics_during_training, summary_writer) # Checkpoint, if required. if step % config.checkpoint_every_steps == 0 or is_last_step: with report_progress.timed('checkpoint'): ckpt.save(state) return state
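# Minimal sketch of the per-step minibatch sampling used in the training loop
# above: a fresh PRNG key is derived deterministically from the loop step with
# `jax.random.fold_in`, and training-node indices are then drawn with
# `jax.random.choice`. The sizes below are made up for illustration only.
import jax

train_rng = jax.random.PRNGKey(0)
num_training_nodes = 1000
batch_size = 32

for step in range(3):
  step_rng = jax.random.fold_in(train_rng, step)  # reproducible per-step key
  indices = jax.random.choice(step_rng, num_training_nodes, (batch_size,))
  print(step, indices[:4])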
def main(config, output_dir): # Note: switch to ProfileAllHosts() if you need to profile all hosts. # (Xprof data become much larger and take longer to load for analysis) profiler = periodic_actions.Profile( # Create profile after every restart to analyze pre-emption related # problems and assure we get similar performance in every run. logdir=output_dir, first_profile=10) logging.info(config) acquisition_method = config.get('acquisition_method') # Create an asynchronous multi-metric writer. writer = metric_writers.create_default_writer( output_dir, just_logging=jax.process_index() > 0) writer.write_hparams(dict(config)) # The pool is used to perform misc operations such as logging in async way. pool = multiprocessing.pool.ThreadPool() def write_note(note): if jax.process_index() == 0: logging.info('NOTE: %s', note) write_note(f'Initializing for {acquisition_method}') # Download dataset data_builder = tfds.builder(config.dataset) data_builder.download_and_prepare() seed = config.get('seed', 0) rng = jax.random.PRNGKey(seed) batch_size = config.batch_size batch_size_eval = config.get('batch_size_eval', batch_size) local_batch_size = batch_size // jax.process_count() local_batch_size_eval = batch_size_eval // jax.process_count() val_ds = input_utils.get_data( dataset=config.dataset, split=config.val_split, rng=None, process_batch_size=local_batch_size_eval, preprocess_fn=preprocess_spec.parse( spec=config.pp_eval, available_ops=preprocess_utils.all_ops()), shuffle=False, prefetch_size=config.get('prefetch_to_host', 2), num_epochs=1, # Only repeat once. ) test_ds = input_utils.get_data( dataset=config.dataset, split=config.test_split, rng=None, process_batch_size=local_batch_size_eval, preprocess_fn=preprocess_spec.parse( spec=config.pp_eval, available_ops=preprocess_utils.all_ops()), shuffle=False, prefetch_size=config.get('prefetch_to_host', 2), num_epochs=1, # Only repeat once. ) # Init model if config.model_type == 'deterministic': model_utils = deterministic_utils reinit_params = config.get('model_reinit_params', ('head/kernel', 'head/bias')) model = ub.models.vision_transformer(num_classes=config.num_classes, **config.get('model', {})) elif config.model_type == 'batchensemble': model_utils = batchensemble_utils reinit_params = ('batchensemble_head/bias', 'batchensemble_head/kernel', 'batchensemble_head/fast_weight_alpha', 'batchensemble_head/fast_weight_gamma') model = ub.models.PatchTransformerBE(num_classes=config.num_classes, **config.model) else: raise ValueError('Expect config.model_type to be "deterministic" or' f'"batchensemble", but received {config.model_type}.') init = model_utils.create_init(model, config, test_ds) rng, rng_init = jax.random.split(rng) params_cpu = init(rng_init) if jax.process_index() == 0: num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0]) parameter_overview.log_parameter_overview(params_cpu) writer.write_scalars(step=0, scalars={'num_params': num_params}) # Load the optimizer from flax. opt_name = config.get('optim_name') opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {})) # We jit this, such that the arrays that are created on the same # device as the input is, in this case the CPU. Else they'd be on device[0]. 
opt_cpu = jax.jit(opt_def.create)(params_cpu) loaded_params = checkpoint_utils.load_checkpoint(tree=None, path=config.model_init) loaded = checkpoint_utils.restore_from_pretrained_params( params_cpu, loaded_params, config.model.representation_size, config.model.classifier, reinit_params, ) opt_cpu = opt_cpu.replace(target=loaded) # TODO(joost,andreas): This shouldn't be needed but opt_cpu is being # donated otherwise. Ensure opt_cpu is really on the cpu this way. opt_cpu = jax.device_get(opt_cpu) update_fn = model_utils.create_update_fn(model, config) evaluation_fn = model_utils.create_evaluation_fn(model, config) # NOTE: We need this because we need an Id field of type int. # TODO(andreas): Rename to IdSubsetDatasetBuilder? pool_subset_data_builder = al_utils.SubsetDatasetBuilder(data_builder, subset_ids=None) rng, pool_ds_rng = jax.random.split(rng) # NOTE: below line is necessary on multi host setup # pool_ds_rng = jax.random.fold_in(pool_ds_rng, jax.process_index()) pool_train_ds = input_utils.get_data( dataset=pool_subset_data_builder, split=config.train_split, rng=pool_ds_rng, process_batch_size=local_batch_size, preprocess_fn=preprocess_spec.parse( spec=config.pp_eval, available_ops=preprocess_utils.all_ops()), shuffle=False, drop_remainder=False, prefetch_size=config.get('prefetch_to_host', 2), num_epochs=1, # Don't repeat ) # Potentially acquire an initial training set. initial_training_set_size = config.get('initial_training_set_size', 10) if initial_training_set_size > 0: current_opt_repl = flax_utils.replicate(opt_cpu) pool_ids, _, _, pool_masks = get_ids_logits_masks( model=model, opt_repl=current_opt_repl, ds=pool_train_ds, config=config) rng, initial_uniform_rng = jax.random.split(rng) pool_scores = get_uniform_scores(pool_masks, initial_uniform_rng) initial_training_set_batch_ids, _ = select_acquisition_batch_indices( acquisition_batch_size=initial_training_set_size, scores=pool_scores, ids=pool_ids, ignored_ids=set(), ) else: initial_training_set_batch_ids = [] # NOTE: if we could `enumerate` before `filter` in `create_dataset` of CLU # then this dataset creation could be simplified. # https://github.com/google/CommonLoopUtils/blob/main/clu/deterministic_data.py#L340 # CLU is explicitly not accepting outside contributions at the moment. train_subset_data_builder = al_utils.SubsetDatasetBuilder( data_builder, subset_ids=set(initial_training_set_batch_ids)) test_accuracies = [] training_sizes = [] rng, rng_loop = jax.random.split(rng) rngs_loop = flax_utils.replicate(rng_loop) if config.model_type == 'batchensemble': rngs_loop = {'dropout': rngs_loop} # TODO(joost,andreas): double check if below is still necessary # (train_split is independent of this) # NOTE: train_ds_rng is re-used for all train_ds creations rng, train_ds_rng = jax.random.split(rng) measurements = {} accumulated_steps = 0 while True: current_train_ds_length = len(train_subset_data_builder.subset_ids) if current_train_ds_length >= config.get('max_training_set_size', 150): break write_note(f'Training set size: {current_train_ds_length}') current_opt_repl = flax_utils.replicate(opt_cpu) # Only fine-tune if there is anything to fine-tune with. 
if current_train_ds_length > 0: # Repeat dataset to have oversampled epochs and bootstrap more batches number_of_batches = current_train_ds_length / config.batch_size num_repeats = math.ceil(config.total_steps / number_of_batches) write_note(f'Repeating dataset {num_repeats} times') # We repeat the dataset several times, such that we can obtain batches # of size batch_size, even at start of training. These batches will be # effectively 'bootstrap' sampled, meaning they are sampled with # replacement from the original training set. repeated_train_ds = input_utils.get_data( dataset=train_subset_data_builder, split=config.train_split, rng=train_ds_rng, process_batch_size=local_batch_size, preprocess_fn=preprocess_spec.parse( spec=config.pp_train, available_ops=preprocess_utils.all_ops()), shuffle_buffer_size=config.shuffle_buffer_size, prefetch_size=config.get('prefetch_to_host', 2), # TODO(joost,andreas): double check if below leads to bootstrap # sampling. num_epochs=num_repeats, ) # We use this dataset to evaluate how well we perform on the training set. # We need this to evaluate if we fit well within max_steps budget. train_eval_ds = input_utils.get_data( dataset=train_subset_data_builder, split=config.train_split, rng=train_ds_rng, process_batch_size=local_batch_size, preprocess_fn=preprocess_spec.parse( spec=config.pp_eval, available_ops=preprocess_utils.all_ops()), shuffle=False, drop_remainder=False, prefetch_size=config.get('prefetch_to_host', 2), num_epochs=1, ) # NOTE: warmup and decay are not a good fit for the small training set # lr_fn = train_utils.create_learning_rate_schedule(config.total_steps, # **config.get('lr', {}) # ) lr_fn = lambda x: config.lr.base early_stopping_patience = config.get('early_stopping_patience', 15) current_opt_repl, rngs_loop, measurements = finetune( update_fn=update_fn, opt_repl=current_opt_repl, lr_fn=lr_fn, ds=repeated_train_ds, rngs_loop=rngs_loop, total_steps=config.total_steps, train_eval_ds=train_eval_ds, val_ds=val_ds, evaluation_fn=evaluation_fn, early_stopping_patience=early_stopping_patience, profiler=profiler) train_val_accuracies = measurements.pop('train_val_accuracies') current_steps = 0 for step, train_acc, val_acc in train_val_accuracies: writer.write_scalars(accumulated_steps + step, { 'train_accuracy': train_acc, 'val_accuracy': val_acc }) current_steps = step accumulated_steps += current_steps + 10 test_accuracy = get_accuracy(evaluation_fn=evaluation_fn, opt_repl=current_opt_repl, ds=test_ds) write_note(f'Accuracy at {current_train_ds_length}: {test_accuracy}') test_accuracies.append(test_accuracy) training_sizes.append(current_train_ds_length) pool_ids, pool_outputs, _, pool_masks = get_ids_logits_masks( model=model, opt_repl=current_opt_repl, ds=pool_train_ds, use_pre_logits=acquisition_method == 'density', config=config) if acquisition_method == 'uniform': rng_loop, rng_acq = jax.random.split(rng_loop, 2) pool_scores = get_uniform_scores(pool_masks, rng_acq) elif acquisition_method == 'entropy': pool_scores = get_entropy_scores(pool_outputs, pool_masks) elif acquisition_method == 'margin': pool_scores = get_margin_scores(pool_outputs, pool_masks) elif acquisition_method == 'density': if current_train_ds_length > 0: pool_scores = get_density_scores(model=model, opt_repl=current_opt_repl, train_ds=train_eval_ds, pool_pre_logits=pool_outputs, pool_masks=pool_masks, config=config) else: rng_loop, rng_acq = jax.random.split(rng_loop, 2) pool_scores = get_uniform_scores(pool_masks, rng_acq) else: raise ValueError('Acquisition 
method not found.') acquisition_batch_ids, _ = select_acquisition_batch_indices( acquisition_batch_size=config.get('acquisition_batch_size', 10), scores=pool_scores, ids=pool_ids, ignored_ids=train_subset_data_builder.subset_ids) train_subset_data_builder.subset_ids.update(acquisition_batch_ids) measurements.update({'test_accuracy': test_accuracy}) writer.write_scalars(current_train_ds_length, measurements) write_note(f'Final acquired training ids: ' f'{train_subset_data_builder.subset_ids}' f'Accuracies: {test_accuracies}') pool.close() pool.join() writer.close() # TODO(joost,andreas): save the final checkpoint return (train_subset_data_builder.subset_ids, test_accuracies)
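# A rough, self-contained sketch of margin-based acquisition scoring in the
# spirit of `get_margin_scores` used above. The real helper lives elsewhere in
# the codebase and also handles padded examples via `pool_masks`; this only
# illustrates the core idea: examples whose top-two class probabilities are
# close together (small margin) are the most informative to label next. The
# sign convention (higher score == acquired first) is an assumption here.
import jax
import jax.numpy as jnp


def margin_scores(logits: jnp.ndarray, masks: jnp.ndarray) -> jnp.ndarray:
  """Returns higher scores for examples with a smaller top-1/top-2 margin."""
  probs = jax.nn.softmax(logits, axis=-1)
  top2 = jax.lax.top_k(probs, k=2)[0]            # (n, 2) largest probabilities
  margin = top2[:, 0] - top2[:, 1]               # confidence margin per example
  scores = 1.0 - margin                          # small margin -> high score
  return jnp.where(masks > 0, scores, -jnp.inf)  # never pick padded examples


example_logits = jnp.array([[2.0, 1.9, -1.0], [5.0, -2.0, -3.0]])
print(margin_scores(example_logits, jnp.array([1, 1])))  # first example scores higher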
def main(_): config = FLAGS.config # Unpack total and warmup steps total_steps = config.total_and_warmup_steps[0] warmup_steps = config.total_and_warmup_steps[1] del config.total_and_warmup_steps config.total_steps = total_steps config.lr.warmup_steps = warmup_steps # Wandb and Checkpointing Setup output_dir = FLAGS.output_dir wandb_run, output_dir = vit_utils.maybe_setup_wandb(config) tf.io.gfile.makedirs(output_dir) logging.info('Saving checkpoints at %s', output_dir) # Dataset Split Flags dist_shift = config.distribution_shift print(f'Distribution Shift: {dist_shift}.') dataset_names, split_names = vit_utils.get_dataset_and_split_names( dist_shift) # LR / Optimization Flags print('wandb hyperparameters:') print({ 'batch_size': config.batch_size, 'grad_clip_norm': config.grad_clip_norm, 'weight_decay': config.weight_decay, 'total_steps': config.total_steps, 'lr': config.lr, 'fast_weight_lr_multiplier': config.fast_weight_lr_multiplier }) # Reweighting loss for class imbalance # class_reweight_mode = config.class_reweight_mode # if class_reweight_mode == 'constant': # class_weights = utils.get_diabetic_retinopathy_class_balance_weights() # else: # class_weights = None # Shows the number of available devices. # In a CPU/GPU runtime this will be a single device. # In a TPU runtime this will be 8 cores. print('Number of Jax local devices:', jax.local_devices()) # TODO(nband): fix sigmoid loss issues. assert config.get('loss', None) == 'softmax_xent' seed = config.get('seed', 0) rng = jax.random.PRNGKey(seed) tf.random.set_seed(seed) if config.get('data_dir'): logging.info('data_dir=%s', config.data_dir) logging.info('Output dir: %s', output_dir) tf.io.gfile.makedirs(output_dir) save_checkpoint_path = None if config.get('checkpoint_steps'): save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz') # Create an asynchronous multi-metric writer. writer = metric_writers.create_default_writer( output_dir, just_logging=jax.process_index() > 0) # The pool is used to perform misc operations such as logging in async way. pool = multiprocessing.pool.ThreadPool() def write_note(note): if jax.process_index() == 0: logging.info('NOTE: %s', note) write_note('Initializing...') # Verify settings to make sure no checkpoints are accidentally missed. if config.get('keep_checkpoint_steps'): assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.' assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, ( f'`keep_checkpoint_steps` ({config.checkpoint_steps}) should be' f'divisible by `checkpoint_steps ({config.checkpoint_steps}).`') batch_size = config.batch_size batch_size_eval = config.get('batch_size_eval', batch_size) if (batch_size % jax.device_count() != 0 or batch_size_eval % jax.device_count() != 0): raise ValueError( f'Batch sizes ({batch_size} and {batch_size_eval}) must ' f'be divisible by device number ({jax.device_count()})') local_batch_size = batch_size // jax.process_count() local_batch_size_eval = batch_size_eval // jax.process_count() logging.info( 'Global batch size %d on %d hosts results in %d local batch size. 
' 'With %d devices per host (%d devices total), that\'s a %d per-device ' 'batch size.', batch_size, jax.process_count(), local_batch_size, jax.local_device_count(), jax.device_count(), local_batch_size // jax.local_device_count()) write_note('Initializing preprocessing function...') # Same preprocessing function for training and evaluation preproc_fn = preprocess_spec.parse( spec=config.pp_train, available_ops=preprocess_utils.all_ops()) write_note('Initializing train dataset...') rng, train_ds_rng = jax.random.split(rng) train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index()) train_base_dataset = ub.datasets.get(dataset_names['in_domain_dataset'], split=split_names['train_split'], data_dir=config.get('data_dir')) train_dataset_builder = train_base_dataset._dataset_builder # pylint:disable=protected-access train_ds = input_utils.get_data( dataset=train_dataset_builder, split=split_names['train_split'], rng=train_ds_rng, process_batch_size=local_batch_size, preprocess_fn=preproc_fn, shuffle_buffer_size=config.shuffle_buffer_size, prefetch_size=config.get('prefetch_to_host', 2), data_dir=config.get('data_dir')) # Start prefetching already. train_iter = input_utils.start_input_pipeline( train_ds, config.get('prefetch_to_device', 1)) write_note('Initializing val dataset(s)...') # Load in-domain and OOD validation and/or test datasets. # Please specify the desired shift (Country Shift or Severity Shift) # in the config. eval_iter_splits = vit_utils.init_evaluation_datasets( use_validation=config.use_validation, use_test=config.use_test, dataset_names=dataset_names, split_names=split_names, config=config, preproc_fn=preproc_fn, batch_size_eval=batch_size_eval, local_batch_size_eval=local_batch_size_eval) ntrain_img = input_utils.get_num_examples( train_dataset_builder, split=split_names['train_split'], process_batch_size=local_batch_size, data_dir=config.get('data_dir')) steps_per_epoch = ntrain_img // batch_size if config.get('num_epochs'): total_steps = int(config.num_epochs * steps_per_epoch) assert not config.get( 'total_steps'), 'Set either num_epochs or total_steps' else: total_steps = config.total_steps logging.info('Total train data points: %d', ntrain_img) logging.info( 'Running for %d steps, that means %f epochs and %d steps per epoch', total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch) write_note('Initializing model...') model_dict = vit_utils.initialize_model('batchensemble', config) model = model_dict['model'] ens_size = model_dict['ens_size'] # We want all parameters to be created in host RAM, not on any device, they'll # be sent there later as needed, otherwise we already encountered two # situations where we allocate them twice. @functools.partial(jax.jit, backend='cpu') def init(rng): image_size = tuple(train_ds.element_spec['image'].shape[2:]) logging.info('image_size = %s', image_size) dummy_input = jnp.zeros((local_batch_size, ) + image_size, jnp.float32) params = flax.core.unfreeze(model.init(rng, dummy_input, train=False))['params'] # Set bias in the head to a low value, such that loss is small initially. 
params['batchensemble_head']['bias'] = jnp.full_like( params['batchensemble_head']['bias'], config.get('init_head_bias', 0)) # init head kernel to all zeros for fine-tuning if config.get('model_init'): params['batchensemble_head']['kernel'] = jnp.full_like( params['batchensemble_head']['kernel'], 0) return params rng, rng_init = jax.random.split(rng) params_cpu = init(rng_init) if jax.process_index() == 0: num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0]) parameter_overview.log_parameter_overview(params_cpu) writer.write_scalars(step=0, scalars={'num_params': num_params}) @functools.partial(jax.pmap, axis_name='batch') def evaluation_fn(params, images, labels): tiled_logits, out = model.apply({'params': flax.core.freeze(params)}, images, train=False) loss_name = config.get('loss', 'sigmoid_xent') # TODO(dusenberrymw,zmariet): Clean up and generalize this. if loss_name == 'sigmoid_xent': ens_logits = batchensemble_utils.log_average_sigmoid_probs( jnp.asarray(jnp.split(tiled_logits, ens_size))) pre_logits = batchensemble_utils.log_average_sigmoid_probs( jnp.asarray(jnp.split(out['pre_logits'], ens_size))) else: # softmax ens_logits = batchensemble_utils.log_average_softmax_probs( jnp.asarray(jnp.split(tiled_logits, ens_size))) pre_logits = batchensemble_utils.log_average_softmax_probs( jnp.asarray(jnp.split(out['pre_logits'], ens_size))) losses = getattr(train_utils, loss_name)(logits=ens_logits, labels=labels[:, :config.num_classes], reduction=False) loss = jax.lax.psum(losses, axis_name='batch') top1_idx = jnp.argmax(ens_logits, axis=1) top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0] ncorrect = jax.lax.psum(top1_correct, axis_name='batch') n = batch_size_eval metric_args = jax.lax.all_gather([ens_logits, labels, pre_logits], axis_name='batch') return ncorrect, loss, n, metric_args # Load the optimizer from flax. opt_name = config.get('optim_name') write_note(f'Initializing {opt_name} optimizer...') opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {})) # We jit this, such that the arrays that are created are created on the same # device as the input is, in this case the CPU. Else they'd be on device[0]. opt_cpu = jax.jit(opt_def.create)(params_cpu) weight_decay_rules = config.get('weight_decay', []) or [] rescale_value = config.lr.base if config.get( 'weight_decay_decouple') else 1. weight_decay_fn = train_utils.get_weight_decay_fn( weight_decay_rules=weight_decay_rules, rescale_value=rescale_value) def batch_loss_fn(params, images, labels, rngs): logits, _ = model.apply({'params': flax.core.freeze(params)}, images, train=True, rngs=rngs) labels = jnp.tile(labels, (ens_size, 1)) loss_fn = getattr(train_utils, config.get('loss', 'sigmoid_xent')) loss = jnp.mean(loss_fn(logits=logits, labels=labels)) return loss, dict() @functools.partial(jax.pmap, axis_name='batch', donate_argnums=(0, 1)) def update_fn(opt, rngs, lr, images, labels): return batchensemble_utils.update_fn_be( opt=opt, rngs=rngs, lr=lr, images=images, labels=labels, batch_loss_fn=batch_loss_fn, weight_decay_fn=weight_decay_fn, max_grad_norm_global=config.get('grad_clip_norm', None), fast_weight_lr_multiplier=config.get('fast_weight_lr_multiplier', None)) # Set config checkpoint resume path, if provided in args. 
if config.resume_checkpoint_path is not None: config.resume = config.resume_checkpoint_path reint_params = ('batchensemble_head/bias', 'batchensemble_head/kernel', 'batchensemble_head/fast_weight_alpha', 'batchensemble_head/fast_weight_gamma') if config.get('only_eval', False) or not config.get('reint_head', True): reint_params = [] checkpoint_data = checkpoint_utils.maybe_load_checkpoint( train_loop_rngs=rng, save_checkpoint_path=save_checkpoint_path, init_optimizer=opt_cpu, init_params=params_cpu, init_fixed_model_states=None, default_reinit_params=reint_params, config=config) train_loop_rngs = {'dropout': checkpoint_data.train_loop_rngs} opt_cpu = checkpoint_data.optimizer accumulated_train_time = checkpoint_data.accumulated_train_time write_note('Adapting the checkpoint model...') adapted_params = checkpoint_utils.adapt_upstream_architecture( init_params=params_cpu, loaded_params=opt_cpu.target) opt_cpu = opt_cpu.replace(target=adapted_params) write_note('Kicking off misc stuff...') first_step = int(opt_cpu.state.step) # Might be a DeviceArray type. if first_step == 0 and jax.process_index() == 0: writer.write_hparams(dict(config)) chrono = train_utils.Chrono(first_step, total_steps, batch_size, accumulated_train_time) # Note: switch to ProfileAllHosts() if you need to profile all hosts. # (Xprof data become much larger and take longer to load for analysis) profiler = periodic_actions.Profile( # Create profile after every restart to analyze pre-emption related # problems and assure we get similar performance in every run. logdir=output_dir, first_profile=first_step + 10) # Prepare the learning-rate and pre-fetch it to device to avoid delays. lr_fn = train_utils.create_learning_rate_schedule(total_steps, **config.get('lr', {})) # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be # necessary for TPUs. lr_iter = train_utils.prefetch_scalar(map(lr_fn, range(total_steps)), config.get('prefetch_to_device', 1)) write_note(f'Replicating...\n{chrono.note}') opt_repl = flax.jax_utils.replicate(opt_cpu) checkpoint_writer = None # Note: we return the train loss, val loss, and fewshot best l2s for use in # reproducibility unit tests. # train_loss = -jnp.inf # eval_loss = { # eval_name: -jnp.inf for eval_name, _ in eval_iter_splits.items()} # fewshot_results = {'dummy': {(0, 1): -jnp.inf}} write_note(f'First step compilations...\n{chrono.note}') logging.info('first_step = %s', first_step) # Advance the iterators if we are restarting from an earlier checkpoint. # TODO(dusenberrymw): Look into checkpointing dataset state instead. if first_step > 0: write_note('Advancing iterators after resuming from a checkpoint...') lr_iter = itertools.islice(lr_iter, first_step, None) # TODO(zmariet): Find better way to cut down iteration advancement cost. if not config.get('disable_preemption_reproducibility', False): train_iter = itertools.islice(train_iter, first_step, None) # Using a python integer for step here, because opt.state.step is allocated # on TPU during replication. 
for step, train_batch, lr_repl in zip( range(first_step + 1, total_steps + 1), train_iter, lr_iter): with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1): if not config.get('only_eval', False): opt_repl, train_loop_rngs, extra_measurements = update_fn( opt_repl, train_loop_rngs, lr_repl, train_batch['image'], train_batch['labels']) if jax.process_index() == 0: profiler(step) # Checkpoint saving if not config.get('only_eval', False) and train_utils.itstime( step, config.get('checkpoint_steps'), total_steps, process=0): write_note('Checkpointing...') chrono.pause() train_utils.checkpointing_timeout( checkpoint_writer, config.get('checkpoint_timeout', 1)) accumulated_train_time = chrono.accum_train_time # We need to transfer the weights over now or else we risk keeping them # alive while they'll be updated in a future step, creating hard to debug # memory errors (see b/160593526). Also, takes device 0's params only. opt_cpu = jax.tree_util.tree_map(lambda x: np.array(x[0]), opt_repl) # Check whether we want to keep a copy of the current checkpoint. copy_step = None if train_utils.itstime(step, config.get('keep_checkpoint_steps'), total_steps): write_note('Keeping a checkpoint copy...') copy_step = step # Checkpoint should be a nested dictionary or FLAX datataclasses from # `flax.struct`. Both can be present in a checkpoint. checkpoint_data = checkpoint_utils.CheckpointData( train_loop_rngs=train_loop_rngs, optimizer=opt_cpu, accumulated_train_time=accumulated_train_time) checkpoint_writer = pool.apply_async( checkpoint_utils.checkpoint_trained_model, (checkpoint_data, save_checkpoint_path, copy_step)) chrono.resume() # Report training progress if not config.get('only_eval', False) and train_utils.itstime( step, config.log_training_steps, total_steps, process=0): write_note('Reporting training progress...') timing_measurements, note = chrono.tick(step) write_note(note) train_measurements = {} train_measurements.update( flax.jax_utils.unreplicate(extra_measurements)) train_measurements.update(timing_measurements) writer.write_scalars(step, train_measurements) # Keep to return for reproducibility tests. # train_loss = train_measurements['training_loss'] # Report validation performance if config.get('only_eval', False) or train_utils.itstime( step, config.log_eval_steps, total_steps): write_note('Evaluating on the validation sets...') chrono.pause() all_eval_results = {} for eval_name, (eval_iter, eval_steps) in eval_iter_splits.items(): start_time = time.time() # Runs evaluation loop. results_arrs = { 'y_true': [], 'y_pred': [], 'y_pred_entropy': [] } for _, batch in zip(range(eval_steps), eval_iter): batch_ncorrect, batch_losses, batch_n, batch_metric_args = ( # pylint: disable=unused-variable evaluation_fn(opt_repl.target, batch['image'], batch['labels'])) # All results are a replicated array shaped as follows: # (local_devices, per_device_batch_size, elem_shape...) # with each local device's entry being identical as they got psum'd. # So let's just take the first one to the host as numpy. # Here we parse batch_metric_args to compute uncertainty metrics. logits, labels, _ = batch_metric_args logits = np.array(logits[0]) probs = jax.nn.softmax(logits) # From one-hot to integer labels. 
int_labels = np.argmax(np.array(labels[0]), axis=-1) probs = np.reshape(probs, (probs.shape[0] * probs.shape[1], -1)) int_labels = int_labels.flatten() y_pred = probs[:, 1] results_arrs['y_true'].append(int_labels) results_arrs['y_pred'].append(y_pred) # Entropy is computed at the per-epoch level (see below). results_arrs['y_pred_entropy'].append(probs) results_arrs['y_true'] = np.concatenate(results_arrs['y_true'], axis=0) results_arrs['y_pred'] = np.concatenate( results_arrs['y_pred'], axis=0).astype('float64') results_arrs['y_pred_entropy'] = vit_utils.entropy( np.concatenate(results_arrs['y_pred_entropy'], axis=0), axis=-1) time_elapsed = time.time() - start_time results_arrs['total_ms_elapsed'] = time_elapsed * 1e3 results_arrs['dataset_size'] = eval_steps * batch_size_eval all_eval_results[eval_name] = results_arrs per_pred_results, metrics_results = vit_utils.evaluate_vit_predictions( # pylint: disable=unused-variable dataset_split_to_containers=all_eval_results, is_deterministic=True, num_bins=15, return_per_pred_results=True) # `metrics_results` is a dict of {str: jnp.ndarray} dicts, one for each # dataset. Flatten this dict so we can pass to the writer and remove empty # entries. flattened_metric_results = {} for dic in metrics_results.values(): for key, value in dic.items(): if value is not None: flattened_metric_results[key] = value writer.write_scalars(step, flattened_metric_results) # Optionally log to wandb if config.use_wandb: wandb.log(metrics_results, step=step) # Save per-prediction metrics results_storage_utils.save_per_prediction_results(output_dir, step, per_pred_results, verbose=False) chrono.resume() # End of step. if config.get('testing_failure_step'): # Break early to simulate infra failures in test cases. if config.testing_failure_step == step: break if config.get('only_eval', False): break write_note(f'Done!\n{chrono.note}') pool.close() pool.join() writer.close()
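# Sketch of the ensemble averaging performed by
# `batchensemble_utils.log_average_softmax_probs` in the evaluation function
# above: member logits of shape [ens_size, batch, num_classes] are converted to
# log-probabilities and averaged in probability space, staying in log space for
# numerical stability. This is a plausible reconstruction for illustration, not
# the exact library code.
import jax
import jax.numpy as jnp


def log_average_softmax_probs(logits: jnp.ndarray) -> jnp.ndarray:
  """log(mean_m softmax(logits_m)) over the leading ensemble axis."""
  log_probs = jax.nn.log_softmax(logits, axis=-1)  # [ens, batch, classes]
  return jax.nn.logsumexp(log_probs, axis=0) - jnp.log(logits.shape[0])


member_logits = jnp.array([[[2.0, 0.0]], [[0.0, 2.0]]])   # 2 members, 1 example
print(jnp.exp(log_average_softmax_probs(member_logits)))  # approx. [[0.5, 0.5]]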
def maybe_load_checkpoint(train_loop_rngs: jnp.ndarray, save_checkpoint_path: str, init_optimizer: flax.optim.Optimizer, init_params: Params, init_fixed_model_states: Optional[Params], default_reinit_params: Iterable[str], config: ml_collections.ConfigDict) -> CheckpointData: """Loads a model from an existing checkpoint if so indicated by the config. Whether to resume training, initialize from a previous checkpoint, or do nothing is set by the `config` ConfigDict, based on the existence of fields `resume` (resume training) or `model_init` (initialize from pretrained checkpoint). When resuming training, both the model weights and optimizer state (including the training step) are restored. When initializing, only the model parameters are updated. Initialization is prioritized in the following order: 1. Always resume from an existing checkpoint, e.g. resume a finetune job. 2. Resume from a previous checkpoint, e.g. start a cooldown training job. 3. Initialize model from something, e.g. start a fine-tuning job. 4. Do nothing (training from scratch). Args: train_loop_rngs: unreplicated jax.PRNGKey. save_checkpoint_path: File pointing to a pretrained checkpoint stored in a NumPy `.npz` file. init_optimizer: flax.Optimizer to be updated. init_params: Tree of (possibly randomly) initialized parameters for the model. init_fixed_model_states: Optional pytree of non-trainable parameters. Currently only passed when using SNGP models. default_reinit_params: List of parameter names to reinitialize if not provided by the config file. config: ConfigDict which contains fields indicating if, and how, to load an available checkpoint into the optimizer. If resuming from a previous checkpoint *to start a cooldown job*, the flag `resume` must be set. If initializing a (subset of) model parameters to start a fine-tuning job, fields `model_init`, `representation_size` and `classifier` must be set. Returns: A CheckpointData instance containing a new rng key, the new optimizer state, the new untrainable parameters (if resuming from a checkpoint), and a dictionary of information about the reloaded state. """ optimizer = init_optimizer fixed_model_states = init_fixed_model_states accum_train_time = 0.0 # TODO(dusenberrymw, zmariet): Directly return an unreplicated rng and the # cumulative training time instead of storing them in `checkpoint_extra`. checkpoint_extra = dict( accum_train_time=accum_train_time, rngs_loop=flax_utils.replicate(train_loop_rngs)) # Parse config file to figure out which setting we are in. resume_from_checkpoint = ( (save_checkpoint_path is not None and tf.io.gfile.exists(save_checkpoint_path)) or config.get("resume") is not None) reinitialize_model = config.get( "model_init") is not None and not resume_from_checkpoint if resume_from_checkpoint: logging.info("Resume training from checkpoint...") # Always prioritize loading from a checkpoint from the current training job. if save_checkpoint_path and tf.io.gfile.exists(save_checkpoint_path): resume_checkpoint_path = save_checkpoint_path # Otherwise, we reload from a previous checkpoint provided by the config. 
else: resume_checkpoint_path = config.resume checkpoint_tree = {"opt": init_optimizer, "extra": checkpoint_extra} if init_fixed_model_states is not None: checkpoint_tree["states"] = init_fixed_model_states checkpoint = load_checkpoint(checkpoint_tree, resume_checkpoint_path) optimizer, checkpoint_extra = checkpoint["opt"], checkpoint["extra"] fixed_model_states = checkpoint.get("states", None) elif reinitialize_model: logging.info("Initialize model...") reinit_params = config.get("model_reinit_params", default_reinit_params) logging.info("Reinitializing these parameters: %s", reinit_params) loader = lambda path: load_checkpoint(tree=None, path=path) loaded_params = loader(config.model_init) loaded_params = restore_from_pretrained_params( init_params=init_params, loaded_params=loaded_params, model_representation_size=config.model.representation_size, model_classifier=config.model.classifier, reinit_params=reinit_params) optimizer = init_optimizer.replace(target=loaded_params) if jax.process_index() == 0: logging.info("Restored parameter overview:") parameter_overview.log_parameter_overview(loaded_params) else: logging.info("No checkpoint to recover from; using default initialization.") return CheckpointData( optimizer=optimizer, fixed_model_states=fixed_model_states, train_loop_rngs=checkpoint_extra["rngs_loop"], accumulated_train_time=checkpoint_extra["accum_train_time"])
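# Minimal sketch of the branch selection performed by `maybe_load_checkpoint`
# above, useful when reasoning about which config fields matter. It mirrors the
# documented priority (resume an existing run > resume from `config.resume` >
# initialize from `config.model_init` > train from scratch), but `describe_init`
# is a hypothetical standalone helper written for illustration and is not part
# of the library.
import tensorflow as tf


def describe_init(config, save_checkpoint_path):
  resume_from_checkpoint = (
      (save_checkpoint_path is not None
       and tf.io.gfile.exists(save_checkpoint_path))
      or config.get("resume") is not None)
  if resume_from_checkpoint:
    return "resume: restore optimizer state and training step from a checkpoint"
  if config.get("model_init") is not None:
    return "initialize: load pretrained params and reinitialize the head"
  return "scratch: keep the randomly initialized params"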
def main(argv): del argv config = FLAGS.config workdir = FLAGS.workdir logging.info("Workdir: %s", workdir) save_checkpoint_path = None if config.get("checkpoint_steps"): tf.io.gfile.makedirs(workdir) save_checkpoint_path = os.path.join(workdir, "checkpoint.npz") # The pool is used to perform misc operations such as logging in async way. pool = multiprocessing.pool.ThreadPool() # This seed makes the Jax part of things (like model init) deterministic. # However, full training still won't be deterministic, for example due to the # tf.data pipeline not being deterministic even if we would set TF seed. rng = jax.random.PRNGKey(config.get("seed", 0)) def write_note(note): if jax.host_id() == 0: logging.info("NOTE: %s", note) write_note("Initializing...") # Verify settings to make sure no checkpoints are accidentally missed. if config.get("keep_checkpoint_steps"): assert config.get("checkpoint_steps"), "Specify `checkpoint_steps`." assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, ( f"`keep_checkpoint_steps` ({config.checkpoint_steps}) should be" f"divisible by `checkpoint_steps ({config.checkpoint_steps}).`") batch_size = config.batch_size batch_size_eval = config.get("batch_size_eval", batch_size) if (batch_size % jax.device_count() != 0 or batch_size_eval % jax.device_count() != 0): raise ValueError(f"Batch sizes ({batch_size} and {batch_size_eval}) must " f"be divisible by device number ({jax.device_count()})") local_batch_size = batch_size // jax.host_count() local_batch_size_eval = batch_size_eval // jax.host_count() logging.info( "Global batch size %d on %d hosts results in %d local batch size. " "With %d dev per host (%d dev total), that's a %d per-device batch size.", batch_size, jax.host_count(), local_batch_size, jax.local_device_count(), jax.device_count(), local_batch_size // jax.local_device_count()) write_note("Initializing train dataset...") train_ds = input_pipeline.get_data( dataset=config.dataset, split=config.train_split, data_dir=fillin(config.get("dataset_dir")), batch_size=local_batch_size, preprocess_fn=pp_builder.get_preprocess_fn(config.pp_train), shuffle_buffer_size=config.shuffle_buffer_size, prefetch=config.get("prefetch_to_host", 2), cache=False) # Start prefetching already. train_iter = u.start_input_pipeline( train_ds, config.get("prefetch_to_device", 1), pad=local_batch_size) # We always pad to local_batch_size_eval even when less would be enough in # order to minimize memory fragmentation. write_note("Initializing val dataset(s)...") def _get_val_split(dataset, split, pp_eval, data_dir=None): # We do ceil rounding such that we include the last incomplete batch. nval_img = input_pipeline.get_num_examples( dataset, split, data_dir=fillin(data_dir)) val_steps = int(np.ceil(nval_img / batch_size_eval)) logging.info("Running validation for %d steps for %s, %s", val_steps, dataset, split) val_it = input_pipeline.get_data( dataset=dataset, split=split, data_dir=fillin(data_dir), batch_size=local_batch_size_eval, preprocess_fn=pp_builder.get_preprocess_fn(pp_eval), cache=config.get("val_cache", "batched"), repeat_after_batching=True, prefetch=0, # Save memory since we cache. 
drop_remainder=False, shuffle_files=False) val_it = u.start_input_pipeline( val_it, config.get("prefetch_to_device", 1), pad=local_batch_size_eval) return (val_it, val_steps) if isinstance(config.val_split, str): val_ds = {"val": _get_val_split(config.dataset, config.val_split, config.pp_eval, config.get("dataset_dir"))} else: val_ds = {t[0]: _get_val_split(*t[1:]) for t in config.val_split} ntrain_img = input_pipeline.get_num_examples( config.dataset, config.train_split, data_dir=fillin(config.get("dataset_dir"))) steps_per_epoch = ntrain_img / batch_size if config.get("num_epochs"): total_steps = int(config.num_epochs * steps_per_epoch) assert not config.get("total_steps"), "Set either num_epochs or total_steps" else: total_steps = config.total_steps logging.info( "Running for %d steps, that means %f epochs and %f steps per epoch", total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch) mw = u.BigVisionMetricWriter(xm_xp.id, xm_wu.id, steps_per_epoch) write_note(f"Initializing {config.model_name} model...") model_mod = importlib.import_module(f"{BASEDIR}.models.{config.model_name}") model = model_mod.Model( num_classes=config.num_classes, **config.get("model", {})) # We want all parameters to be created in host RAM, not on any device, they'll # be sent there later as needed, otherwise we already encountered two # situations where we allocate them twice. @partial(jax.jit, backend="cpu") def init(rng): image_size = tuple(train_ds.element_spec["image"].shape[1:]) dummy_input = jnp.zeros((local_batch_size,) + image_size, jnp.float32) params = flax.core.unfreeze(model.init(rng, dummy_input))["params"] # Set bias in the head to a low value, such that loss is small initially. params["head"]["bias"] = jnp.full_like( params["head"]["bias"], config.get("init_head_bias", 0)) return params rng, rng_init = jax.random.split(rng) params_cpu = init(rng_init) if jax.host_id() == 0: num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0]) parameter_overview.log_parameter_overview(params_cpu) mw.measure("num_params", num_params) @partial(jax.pmap, axis_name="batch") def evaluation_fn(params, images, labels, mask): # Ignore the entries with all zero labels for evaluation. mask *= labels.max(axis=1) logits, _ = model.apply({"params": flax.core.freeze(params)}, images) losses = getattr(u, config.get("loss", "sigmoid_xent"))( logits=logits, labels=labels, reduction=False) loss = jax.lax.psum(losses * mask, axis_name="batch") top1_idx = jnp.argmax(logits, axis=1) # Extracts the label at the highest logit index for each image. top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0] ncorrect = jax.lax.psum(top1_correct * mask, axis_name="batch") n = jax.lax.psum(mask, axis_name="batch") return ncorrect, loss, n # Setup function for computing representation. @partial(jax.pmap, axis_name="batch") def representation_fn(params, images, labels, mask): _, outputs = model.apply({"params": flax.core.freeze(params)}, images) representation = outputs[config.fewshot.representation_layer] representation = jax.lax.all_gather(representation, "batch") labels = jax.lax.all_gather(labels, "batch") mask = jax.lax.all_gather(mask, "batch") return representation, labels, mask # Load the optimizer either from our folder or from flax. 
  # Load the optimizer either from our folder or from flax.
  opt_name = config.get("optim_name", "momentum_hp")
  write_note(f"Initializing {opt_name} optimizer...")
  try:
    opt_mod = importlib.import_module(f"{BASEDIR}.optims.{opt_name}")
    opt_def = opt_mod.Optimizer(**config.get("optim", {}))
  except ModuleNotFoundError:
    opt_def = getattr(flax.optim, opt_name)(**config.get("optim", {}))

  # We jit this, such that the arrays that are created are created on the same
  # device as the input is, in this case the CPU. Else they'd be on device[0].
  opt_cpu = jax.jit(opt_def.create)(params_cpu)

  @partial(jax.pmap, axis_name="batch", donate_argnums=(0,))
  def update_fn(opt, lr, images, labels, rng):
    """Update step."""

    measurements = {}

    if config.get("mixup") and config.mixup.p:
      rng, (images, labels), _ = u.mixup(rng, images, labels, **config.mixup)

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model,
                                         jax.lax.axis_index("batch"))

    def loss_fn(params, images, labels):
      logits, _ = model.apply(
          {"params": flax.core.freeze(params)}, images,
          train=True, rngs={"dropout": rng_model_local})
      return getattr(u, config.get("loss", "sigmoid_xent"))(
          logits=logits, labels=labels)

    # Implementation considerations compared and summarized at
    # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
    l, g = u.accumulate_gradient(jax.value_and_grad(loss_fn), opt.target,
                                 images, labels,
                                 config.get("grad_accum_steps"))
    l, g = jax.lax.pmean((l, g), axis_name="batch")

    # Log the gradient norm only if we need to compute it anyways (clipping)
    # or if we don't use grad_accum_steps, as they interact badly.
    if config.get("grad_accum_steps", 1) == 1 or config.get("grad_clip_norm"):
      grads, _ = jax.tree_flatten(g)
      l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
      measurements["l2_grads"] = l2_g

    # Optionally resize the global gradient to a maximum norm. We found this
    # useful in some cases across optimizers, hence it's in the main loop.
    if config.get("grad_clip_norm"):
      g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
      g = jax.tree_map(lambda p: g_factor * p, g)
    opt = opt.apply_gradient(g, learning_rate=lr)

    decay_rules = config.get("weight_decay", []) or []
    if isinstance(decay_rules, numbers.Number):
      decay_rules = [(".*kernel.*", decay_rules)]
    sched_m = lr / config.lr.base if config.get("weight_decay_decouple") else lr
    def decay_fn(v, wd):
      return (1.0 - sched_m * wd) * v
    opt = opt.replace(target=u.tree_map_with_regex(
        decay_fn, opt.target, decay_rules, name="weight decay"))

    params, _ = jax.tree_flatten(opt.target)
    measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

    return opt, l, rng, measurements

  # Other things besides optimizer state to be stored.
  checkpoint_extra = dict(accum_train_time=0.0)
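  # Worked example with hypothetical config values: a scalar weight_decay=1e-4
  # is expanded to decay_rules = [(".*kernel.*", 1e-4)], i.e. only parameters
  # whose path matches ".*kernel.*" are decayed. With weight_decay_decouple
  # enabled and lr.base = 1e-3, a current lr of 5e-4 gives sched_m = 0.5, so
  # matching weights are scaled by 1.0 - 0.5 * 1e-4 = 0.99995 after the
  # gradient update.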
  # Decide how to initialize training. The order is important.
  # 1. Always resume from an existing checkpoint, e.g. when resuming a
  #    preempted fine-tuning job.
  # 2. Resume from a previous checkpoint, e.g. start a cooldown training job.
  # 3. Initialize the model from something, e.g. start a fine-tuning job.
  # 4. Train from scratch.
  resume_checkpoint_path = None
  if save_checkpoint_path and tf.io.gfile.exists(save_checkpoint_path):
    resume_checkpoint_path = save_checkpoint_path
  elif config.get("resume"):
    resume_checkpoint_path = fillin(config.resume)
  if resume_checkpoint_path:
    write_note("Resume training from checkpoint...")
    checkpoint = {"opt": opt_cpu, "extra": checkpoint_extra}
    _, checkpoint_tree = jax.tree_flatten(checkpoint)
    loaded = u.load_checkpoint(checkpoint_tree, resume_checkpoint_path)
    # bfloat16 type gets lost when data is saved to disk, so we recover it.
    checkpoint = jax.tree_map(u.recover_dtype, loaded)
    opt_cpu, checkpoint_extra = checkpoint["opt"], checkpoint["extra"]
  elif config.get("model_init"):
    write_note(f"Initialize model from {config.model_init}...")
    loaded = model_mod.load(params_cpu, config.model_init, config.get("model"))
    opt_cpu = opt_cpu.replace(target=loaded)
    if jax.host_id() == 0:
      logging.info("Restored parameter overview:")
      parameter_overview.log_parameter_overview(loaded)

  write_note("Kicking off misc stuff...")
  first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
  chrono = u.Chrono(first_step, total_steps, batch_size,
                    checkpoint_extra["accum_train_time"])
  # Note: switch to ProfileAllHosts() if you need to profile all hosts.
  # (Xprof data become much larger and take longer to load for analysis.)
  profiler = periodic_actions.Profile(
      # Create profile after every restart to analyze pre-emption related
      # problems and assure we get similar performance in every run.
      logdir=workdir, first_profile=first_step + 10)

  # Prepare the learning-rate schedule and pre-fetch it to device to avoid
  # delays.
  lr_fn = u.create_learning_rate_schedule(
      batch_size, total_steps, steps_per_epoch, **config.get("lr", {}))
  lr_iter = u.prefetch_scalar(map(lr_fn, range(first_step, total_steps)),
                              config.get("prefetch_to_device", 1))

  write_note(f"Replicating...\n{chrono.note}")
  opt_repl = flax_utils.replicate(opt_cpu)

  write_note(f"Initializing few-shotters...\n{chrono.note}")
  if "fewshot" in config:
    fewshotter = fewshot.FewShotEvaluator(
        representation_fn, config.fewshot,
        config.fewshot.get("batch_size") or batch_size_eval)

  rng, rng_loop = jax.random.split(rng, 2)
  rngs_loop = flax_utils.replicate(rng_loop)
  checkpoint_writer = None

  write_note(f"First step compilations...\n{chrono.note}")
  # Using a python integer for step here, because opt.state.step is allocated
  # on TPU during replication.
  for step, train_batch, lr_repl in zip(
      range(first_step + 1, total_steps + 1), train_iter, lr_iter):
    mw.step_start(step)

    with jax.profiler.TraceContext("train_step", step_num=step, _r=1):
      opt_repl, loss_value, rngs_loop, extra_measurements = update_fn(
          opt_repl, lr_repl,
          train_batch["image"], train_batch["labels"], rng=rngs_loop)

    if jax.host_id() == 0:
      profiler(step)

    # Checkpoint saving
    if u.itstime(step, config.get("checkpoint_steps"), total_steps, host=0):
      chrono.pause()
      u.checkpointing_timeout(checkpoint_writer,
                              config.get("checkpoint_timeout", 1))
      checkpoint_extra["accum_train_time"] = chrono.accum_train_time
      # We need to transfer the weights over now or else we risk keeping them
      # alive while they'll be updated in a future step, creating hard to debug
      # memory errors (see b/160593526). Also, takes device 0's params only.
      opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)

      # Check whether we want to keep a copy of the current checkpoint.
      copy_step = None
      if u.itstime(step, config.get("keep_checkpoint_steps"), total_steps):
        copy_step = step
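      # For reference, the pytree written to disk is roughly
      # {"opt": <flax.optim optimizer with params and state>,
      #  "extra": {"accum_train_time": <float>}},
      # i.e. the same structure that the resume path above expects to load.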
      # Checkpoint should be a nested dictionary or FLAX dataclasses from
      # `flax.struct`. Both can be present in a checkpoint.
      checkpoint = {"opt": opt_cpu, "extra": checkpoint_extra}
      checkpoint_writer = pool.apply_async(
          u.save_checkpoint, (checkpoint, save_checkpoint_path, copy_step))
      chrono.resume()

    # Report training progress
    if u.itstime(step, config.log_training_steps, total_steps, host=0):
      mw.measure("learning_rate", lr_repl[0])
      mw.measure("training_loss", loss_value[0])
      for name, value in extra_measurements.items():
        mw.measure(name, value[0])
      chrono.tick(step, mw.measure, write_note)

    # Report validation performance
    if u.itstime(step, config.log_eval_steps, total_steps):
      chrono.pause()
      for val_name, (val_iter, val_steps) in val_ds.items():
        ncorrect, loss, nseen = 0, 0, 0
        for _, batch in zip(range(val_steps), val_iter):
          batch_ncorrect, batch_losses, batch_n = evaluation_fn(
              opt_repl.target, batch["image"], batch["labels"], batch["mask"])
          # All results are a replicated array shaped as follows:
          # (local_devices, per_device_batch_size, elem_shape...)
          # with each local device's entry being identical as they got psum'd.
          # So let's just take the first one to the host as numpy.
          ncorrect += np.sum(np.array(batch_ncorrect[0]))
          loss += np.sum(np.array(batch_losses[0]))
          nseen += np.sum(np.array(batch_n[0]))
        mw.measure(f"{val_name}_prec@1", ncorrect / nseen)
        mw.measure(f"{val_name}_loss", loss / nseen)
      chrono.resume()

    if "fewshot" in config:
      # Compute few-shot on-the-fly evaluation.
      if u.itstime(step, config.fewshot.log_steps, total_steps):
        chrono.pause()
        write_note(f"Few-shot evaluation...\n{chrono.note}")
        r = fewshotter.run_all(opt_repl.target, config.fewshot.datasets)
        fewshotter.walk_results(mw.measure, *r)
        chrono.resume()

    mw.step_end()

  write_note(f"Done!\n{chrono.note}")

  pool.close()
  pool.join()
  mw.close()
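
# A minimal entry-point sketch, not part of the original excerpt: it assumes
# `from absl import app` at the top of this file and that the `config` and
# `workdir` flags used above are defined elsewhere (e.g. via ml_collections'
# config_flags). Adjust to however this script is actually launched.
if __name__ == "__main__":
  app.run(main)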