Example #1
def create_train_state(
    config, rng, learning_rate_fn,
    example_batch):
  """Create and initialize the model.

  Args:
    config: Configuration for model.
    rng: JAX PRNG Key.
    learning_rate_fn: Learning rate function.
    example_batch: An example batch used for model initialization.

  Returns:
    The model and the initialized TrainState with the optimizer.
  """
  model, variables = create_model(config, rng, example_batch)
  params = variables['params']
  parameter_overview.log_parameter_overview(params)

  state = train_state.TrainState.create(
      apply_fn=model.apply,
      params=variables['params'],
      tx=optax.adamw(
          learning_rate=learning_rate_fn,
          b1=0.9,
          b2=0.98,
          eps=1e-9,
          weight_decay=config.train.weight_decay),
  )
  return model, state
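For context, here is a self-contained sketch of the same pattern on a toy model (the module, shapes, and hyperparameters below are illustrative assumptions, not part of the example above): initialize a Flax module, log the parameter overview, and build a TrainState with an optax AdamW optimizer.

import jax
import jax.numpy as jnp
import optax
from clu import parameter_overview
from flax import linen as nn
from flax.training import train_state

class ToyModel(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.Dense(features=4)(x)

model = ToyModel()
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
parameter_overview.log_parameter_overview(variables['params'])  # Logs a name/shape/size table.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=optax.adamw(learning_rate=1e-3, b1=0.9, b2=0.98, eps=1e-9,
                   weight_decay=1e-4))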
Example #2
 def test_count_parameters_on_module_with_duplicate_names(self):
   module = snt.Module()
   # Weights of a 2D convolution with 2 filters.
   module.conv = snt.Conv2D(output_channels=2, kernel_shape=3, name="conv")
   module.conv(tf.ones((2, 5, 5, 3)))  # 3 * 3*3 * 2 + 2 (bias) = 56 parameters
   module.conv2 = snt.Conv2D(output_channels=2, kernel_shape=3, name="conv")
   module.conv2(tf.ones(
       (2, 5, 5, 3)))  # 3 * 3*3 * 2 + 2 (bias) = 56 parameters
   parameter_overview.log_parameter_overview(module)
   self.assertEqual(112, parameter_overview.count_parameters(module))
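The same CLU helpers also work on a plain nested dict of arrays; a minimal standalone sketch (the dict below mirrors the 3x3 convolution with 2 filters from the test above):

import numpy as np
from clu import parameter_overview

params = {'conv': {'w': np.zeros((3, 3, 3, 2)), 'b': np.zeros((2,))}}
print(parameter_overview.count_parameters(params))        # 3*3*3*2 + 2 = 56
print(parameter_overview.get_parameter_overview(params))  # Name/shape/size/mean/std table.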
Example #3
    def __init__(self, mode, config):
        """Initializes experiment."""

        self.mode = mode
        self.config = config
        self.vdvae = vdvae.Vdvae(config)
        self.rng = jax.random.PRNGKey(config.seed)

        # setup eval
        _, self._eval_ds = dataset.create_eval_dataset(
            config.data.task, config.evaluation.batch_size,
            config.evaluation.subset)
        self.rng, eval_rng = jax.random.split(self.rng)
        self._eval_batch = functools.partial(self._eval_batch,
                                             base_rng=eval_rng)
        self._eval_batch = jax.pmap(self._eval_batch, axis_name='batch')

        if mode == 'train':
            self.rng, data_rng = jax.random.split(self.rng)
            data_rng = jax.random.fold_in(data_rng, jax.process_index())
            _, train_ds = dataset.create_train_dataset(
                config.data.task, config.training.batch_size,
                config.training.substeps, data_rng)
            self._train_iter = iter(train_ds)

            self.rng, init_rng, sample_rng = jax.random.split(self.rng, num=3)
            input_shape = tuple(train_ds.element_spec.shape[2:])
            params = self.vdvae.init(init_rng, sample_rng, input_shape[0],
                                     jnp.ones(input_shape, dtype=jnp.uint8))
            parameter_overview.log_parameter_overview(params)
            # Initialize the EMA parameters as a copy of the model parameters.
            ema_params = jax.tree_map(jnp.array, params)

            opt_init, _ = self.optimizer(
                learning_rate=self.config.optimizer.base_learning_rate)
            opt_state = opt_init(params)

            # create train state
            self._train_state = TrainState(step=0,
                                           params=params,
                                           ema_params=ema_params,
                                           opt_state=opt_state)

            self.rng, update_rng = jax.random.split(self.rng)
            self._update_func = functools.partial(self._update_func,
                                                  update_rng)
            self._update_func = functools.partial(jax.lax.scan,
                                                  self._update_func)
            self._update_func = jax.pmap(self._update_func, axis_name='batch')
Example #4
def create_train_state(model, config, rng, inputs):
    """Create and initialize the model.

  Args:
    model: Flax nn module.
    config: Configuration for model.
    rng: JAX PRNG Key.
    inputs: The init inputs fed into the model.

  Returns:
    The initialized TrainState with the optimizer.
  """
    rng, params_rng, dropout_rng = jax.random.split(rng, 3)
    variables = model.init({
        'params': params_rng,
        'dropout': dropout_rng
    }, inputs)
    params = variables['params']
    parameter_overview.log_parameter_overview(params)  # pytype: disable=wrong-arg-types

    if config.optimizer == 'AdamW':

        def w_decay_fn(path):
            return all(pn not in path for pn in config.no_weight_decay)

        def wo_decay_fn(path):
            return any(pn in path for pn in config.no_weight_decay)

        optimizer_w_decay = flax.optim.Adam(
            learning_rate=config.learning_rate,
            weight_decay=config.learning_rate_weight_decay)
        optimizer_wo_decay = flax.optim.Adam(
            learning_rate=config.learning_rate, weight_decay=0)
        params_w_decay = flax.optim.ModelParamTraversal(
            lambda path, _: w_decay_fn(path))
        params_wo_decay = flax.optim.ModelParamTraversal(
            lambda path, _: wo_decay_fn(path))
        optimizer = flax.optim.MultiOptimizer(
            (params_w_decay, optimizer_w_decay),
            (params_wo_decay, optimizer_wo_decay)).create(params)
    else:
        raise NotImplementedError
    return TrainState(step=0, optimizer=optimizer)
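Note that flax.optim (used above) has since been removed from Flax. A rough optax-based sketch of the same selective weight-decay idea, under the assumption that parameters are excluded by matching names in their '/'-joined path (the helper name below is ours, not from the example, and it expects a plain nested params dict):

import optax
from flax import traverse_util

def make_selective_adamw(learning_rate, weight_decay, no_weight_decay):
  """Applies AdamW weight decay to all parameters except the excluded ones."""
  def label_fn(params):
    flat = traverse_util.flatten_dict(params)
    labels = {
        path: ('no_decay' if any(pn in '/'.join(path) for pn in no_weight_decay)
               else 'decay')
        for path in flat
    }
    return traverse_util.unflatten_dict(labels)

  return optax.multi_transform(
      {'decay': optax.adamw(learning_rate, weight_decay=weight_decay),
       'no_decay': optax.adamw(learning_rate, weight_decay=0.0)},
      label_fn)

Usage would then look like tx = make_selective_adamw(1e-3, 1e-4, ('bias', 'LayerNorm')) followed by opt_state = tx.init(params).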
Example #5
def create_train_state(config, rng, input_shape, num_classes):
    """Create and initialize the model.

  Args:
    config: Configuration for model.
    rng: JAX PRNG Key.
    input_shape: Shape of the inputs fed into the model.
    num_classes: Number of classes in the output layer.

  Returns:
    The model and initialized TrainState with the optimizer.
  """
    # Use of getattr could simplify this, but it is discouraged by the style
    # guide. We should consider using it anyway if the sequence of ifs grows
    # too big.
    if config.model_name == "tiny_classifier":
        get_model = models.tiny_classifier
    elif config.model_name == "spin_classifier_6_layers":
        get_model = models.spin_classifier_6_layers
    elif config.model_name == "spherical_classifier_6_layers":
        get_model = models.spherical_classifier_6_layers
    elif config.model_name == "cnn_classifier_6_layers":
        get_model = models.cnn_classifier_6_layers
    else:
        raise ValueError(f"Model {config.model_name} not supported.")
    model = get_model(num_classes, axis_name=_PMAP_AXIS_NAME)

    variables = model.init(rng, jnp.ones(input_shape), train=False)
    params = variables["params"]
    batch_stats = variables.get("batch_stats", {})

    abs_if_complex = lambda x: jnp.abs(x) if x.dtype == jnp.complex64 else x
    parameter_overview.log_parameter_overview(
        jax.tree_util.tree_map(abs_if_complex, params))
    optimizer = flax.optim.Adam().create(params)
    return model, TrainState(step=0,
                             optimizer=optimizer,
                             batch_stats=batch_stats)
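The abs_if_complex mapping above exists because the overview reports real-valued statistics (mean/std). A minimal standalone sketch of the same trick on toy parameters (not from the example):

import jax
import jax.numpy as jnp
from clu import parameter_overview

params = {'dense': {'kernel': jnp.ones((4, 4), dtype=jnp.complex64),
                    'bias': jnp.zeros((4,))}}
abs_if_complex = lambda x: jnp.abs(x) if jnp.iscomplexobj(x) else x
parameter_overview.log_parameter_overview(
    jax.tree_util.tree_map(abs_if_complex, params))  # Logs magnitudes for complex leaves.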
Example #6
def train_and_evaluate(config, workdir):
    """Performs training and evaluation with the given configuration."""
    # Set up logging.
    summary_writer = metric_writers.create_default_writer(workdir)
    summary_writer.write_hparams(config.to_dict())

    time_delta = config.time_delta
    train_time_jump_range = config.train_time_jump_range
    test_time_jumps = config.test_time_jumps
    num_trajectories = config.num_trajectories
    num_samples = config.num_samples
    train_split_proportion = config.train_split_proportion
    num_train_steps = config.num_train_steps
    regularizations = config.regularizations.to_dict()
    eval_cadence = config.eval_cadence

    # Get simulation functions.
    generate_canonical_coordinates_fn = get_generate_canonical_coordinates_fn(
        config)
    compute_hamiltonian_fn = get_compute_hamiltonian_fn(config)

    # Generate data.
    rng = jax.random.PRNGKey(config.rng_seed)
    rng, simulation_parameters_rng = jax.random.split(rng)
    simulation_parameters = sample_simulation_parameters(
        config.simulation_parameter_ranges.to_dict(), num_trajectories,
        simulation_parameters_rng)
    times = jnp.arange(num_samples) * time_delta
    all_positions, all_momentums = generate_canonical_coordinates_fn(
        times, simulation_parameters)

    # Train-test split.
    if config.split_on == 'times':
        num_train_samples = int(num_samples * train_split_proportion)
        train_positions = all_positions[:num_train_samples]
        test_positions = all_positions[num_train_samples:]
        train_momentums = all_momentums[:num_train_samples]
        test_momentums = all_momentums[num_train_samples:]

        train_simulation_parameters = simulation_parameters
        test_simulation_parameters = simulation_parameters
    else:
        raise ValueError(f'Unsupported feature for split: {config.split_on}.')

    # Rescale.
    scaler = create_scaler(config)
    scaler = fit_scaler(train_positions, train_momentums, scaler)
    train_positions, train_momentums = transform_with_scaler(
        train_positions, train_momentums, scaler)
    test_positions, test_momentums = transform_with_scaler(
        test_positions, test_momentums, scaler)

    # Initialize model.
    state = create_train_state(
        config, rng, (train_positions[:1], train_momentums[:1], time_delta))
    best_state = state
    parameter_overview.log_parameter_overview(state.params)

    # Setup for coordinates and time deltas.
    coordinates_fn = get_coordinates_fn(config)
    time_deltas_fn = get_time_deltas_fn(config)

    # Setup sampling for time jumps.
    sample_time_jump_fn = functools.partial(
        sample_time_jump_with_linear_increase,
        min_jump=train_time_jump_range[0],
        max_jump=train_time_jump_range[1],
        num_train_steps=num_train_steps)

    min_train_loss = jnp.inf
    all_train_metrics = {}
    all_test_metrics = {}

    for step in range(num_train_steps):
        step_rng = jax.random.fold_in(rng, step)

        # Sample time jump.
        step_rng, jump_rng = jax.random.split(step_rng)
        jump = sample_time_jump_fn(step, rng=jump_rng)

        # Setup inputs and targets on all trajectories.
        train_curr_positions, train_curr_momentums, train_target_positions, train_target_momentums = coordinates_fn(
            train_positions, train_momentums, jump)
        time_deltas = time_deltas_fn(jump)

        # Sample indices.
        num_samples_on_trajectory = train_curr_positions.shape[0]
        sample_indices = jax.random.choice(step_rng, num_samples_on_trajectory,
                                           (config.batch_size, ))
        batch_curr_positions = train_curr_positions[sample_indices]
        batch_curr_momentums = train_curr_momentums[sample_indices]
        batch_target_positions = train_target_positions[sample_indices]
        batch_target_momentums = train_target_momentums[sample_indices]

        # Update parameters.
        grads = compute_updates(state, batch_curr_positions,
                                batch_curr_momentums, time_deltas,
                                batch_target_positions, batch_target_momentums,
                                regularizations)
        state = state.apply_gradients(grads=grads)

        # Evaluate, if required.
        is_last_step = (step == num_train_steps - 1)
        if step % eval_cadence == (eval_cadence - 1) or is_last_step:
            train_metrics = compute_metrics_helper(
                state, train_positions, train_momentums, jump, time_delta,
                scaler, compute_hamiltonian_fn, train_simulation_parameters,
                regularizations)
            log_metrics(step, train_metrics, summary_writer, prefix='train_')
            all_train_metrics[step] = train_metrics

            test_metrics = {}
            for test_jump in test_time_jumps:
                test_metrics[test_jump] = compute_metrics_helper(
                    state, test_positions, test_momentums, test_jump,
                    time_delta, scaler, compute_hamiltonian_fn,
                    test_simulation_parameters, regularizations)
                log_metrics(step,
                            test_metrics[test_jump],
                            summary_writer,
                            prefix=f'test_jump_{test_jump}_')
            all_test_metrics[step] = test_metrics

            # Save best state seen so far.
            if train_metrics['total_loss'] < min_train_loss:
                min_train_loss = train_metrics['total_loss']
                best_state = state

    auxiliary_data = {
        'train': {
            'positions': train_positions,
            'momentums': train_momentums,
            'simulation_parameters': train_simulation_parameters,
            'metrics': all_train_metrics,
        },
        'test': {
            'positions': test_positions,
            'momentums': test_momentums,
            'simulation_parameters': test_simulation_parameters,
            'metrics': all_test_metrics,
        },
    }
    return scaler, best_state, auxiliary_data
Example #7
def main(config, output_dir):

    seed = config.get('seed', 0)
    rng = jax.random.PRNGKey(seed)
    tf.random.set_seed(seed)

    if config.get('data_dir'):
        logging.info('data_dir=%s', config.data_dir)
    logging.info('Output dir: %s', output_dir)

    save_checkpoint_path = None
    if config.get('checkpoint_steps'):
        gfile.makedirs(output_dir)
        save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

    # Create an asynchronous multi-metric writer.
    writer = metric_writers.create_default_writer(
        output_dir, just_logging=jax.process_index() > 0)

    # The pool is used to perform misc operations such as logging in an async way.
    pool = multiprocessing.pool.ThreadPool()

    def write_note(note):
        if jax.process_index() == 0:
            logging.info('NOTE: %s', note)

    write_note('Initializing...')

    # Verify settings to make sure no checkpoints are accidentally missed.
    if config.get('keep_checkpoint_steps'):
        assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
        assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
            f'`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should '
            f'be divisible by `checkpoint_steps` ({config.checkpoint_steps}).')

    batch_size = config.batch_size
    batch_size_eval = config.get('batch_size_eval', batch_size)
    if (batch_size % jax.device_count() != 0
            or batch_size_eval % jax.device_count() != 0):
        raise ValueError(
            f'Batch sizes ({batch_size} and {batch_size_eval}) must '
            f'be divisible by device number ({jax.device_count()})')

    local_batch_size = batch_size // jax.process_count()
    local_batch_size_eval = batch_size_eval // jax.process_count()
    logging.info(
        'Global batch size %d on %d hosts results in %d local batch size. '
        'With %d devices per host (%d devices total), that\'s a %d per-device '
        'batch size.', batch_size, jax.process_count(), local_batch_size,
        jax.local_device_count(), jax.device_count(),
        local_batch_size // jax.local_device_count())

    write_note('Initializing train dataset...')
    rng, train_ds_rng = jax.random.split(rng)
    train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index())
    train_ds = input_utils.get_data(
        dataset=config.dataset,
        split=config.train_split,
        rng=train_ds_rng,
        process_batch_size=local_batch_size,
        preprocess_fn=preprocess_spec.parse(
            spec=config.pp_train, available_ops=preprocess_utils.all_ops()),
        shuffle_buffer_size=config.shuffle_buffer_size,
        prefetch_size=config.get('prefetch_to_host', 2),
        data_dir=config.get('data_dir'))

    # Start prefetching already.
    train_iter = input_utils.start_input_pipeline(
        train_ds, config.get('prefetch_to_device', 1))

    write_note('Initializing val dataset(s)...')

    def _get_val_split(dataset, split, pp_eval, data_dir=None):
        # We do ceil rounding such that we include the last incomplete batch.
        nval_img = input_utils.get_num_examples(
            dataset,
            split=split,
            process_batch_size=local_batch_size_eval,
            drop_remainder=False,
            data_dir=data_dir)
        val_steps = int(np.ceil(nval_img / batch_size_eval))
        logging.info('Running validation for %d steps for %s, %s', val_steps,
                     dataset, split)

        if isinstance(pp_eval, str):
            pp_eval = preprocess_spec.parse(
                spec=pp_eval, available_ops=preprocess_utils.all_ops())

        val_ds = input_utils.get_data(dataset=dataset,
                                      split=split,
                                      rng=None,
                                      process_batch_size=local_batch_size_eval,
                                      preprocess_fn=pp_eval,
                                      cache=config.get('val_cache', 'batched'),
                                      repeat_after_batching=True,
                                      shuffle=False,
                                      prefetch_size=config.get(
                                          'prefetch_to_host', 2),
                                      drop_remainder=False,
                                      data_dir=data_dir)
        val_iter = input_utils.start_input_pipeline(
            val_ds, config.get('prefetch_to_device', 1))

        return (val_iter, val_steps)

    val_iter_splits = {
        'val':
        _get_val_split(config.dataset,
                       split=config.val_split,
                       pp_eval=config.pp_eval,
                       data_dir=config.get('data_dir'))
    }

    if config.get('eval_on_cifar_10h'):
        cifar10_to_cifar10h_fn = data_uncertainty_utils.create_cifar10_to_cifar10h_fn(
            config.get('data_dir', None))
        preprocess_fn = preprocess_spec.parse(
            spec=config.pp_eval_cifar_10h,
            available_ops=preprocess_utils.all_ops())
        pp_eval = lambda ex: preprocess_fn(cifar10_to_cifar10h_fn(ex))
        val_iter_splits['cifar_10h'] = _get_val_split(
            'cifar10',
            split=config.get('cifar_10h_split') or 'test',
            pp_eval=pp_eval,
            data_dir=config.get('data_dir'))
    elif config.get('eval_on_imagenet_real'):
        imagenet_to_real_fn = data_uncertainty_utils.create_imagenet_to_real_fn(
        )
        preprocess_fn = preprocess_spec.parse(
            spec=config.pp_eval_imagenet_real,
            available_ops=preprocess_utils.all_ops())
        pp_eval = lambda ex: preprocess_fn(imagenet_to_real_fn(ex))
        val_iter_imagenet_real, val_steps = _get_val_split(
            'imagenet2012_real',
            split=config.get('imagenet_real_split') or 'validation',
            pp_eval=pp_eval,
            data_dir=config.get('data_dir'))
        val_iter_splits['imagenet_real'] = (val_iter_imagenet_real, val_steps)

    ood_ds = {}
    if config.get('ood_datasets') and config.get('ood_methods'):
        if config.get(
                'ood_methods'):  # config.ood_methods is not an empty list.
            logging.info('Loading OOD datasets = %s', config.get('ood_datasets'))
            ood_ds, ood_ds_names = ood_utils.load_ood_datasets(
                config.dataset,
                config.ood_datasets,
                config.ood_split,
                config.pp_eval,
                config.pp_eval_ood,
                config.ood_methods,
                config.train_split,
                config.get('data_dir'),
                _get_val_split,
            )

    ntrain_img = input_utils.get_num_examples(
        config.dataset,
        split=config.train_split,
        process_batch_size=local_batch_size,
        data_dir=config.get('data_dir'))
    steps_per_epoch = int(ntrain_img / batch_size)

    if config.get('num_epochs'):
        total_steps = int(config.num_epochs * steps_per_epoch)
        assert not config.get(
            'total_steps'), 'Set either num_epochs or total_steps'
    else:
        total_steps = config.total_steps

    logging.info('Total train data points: %d', ntrain_img)
    logging.info(
        'Running for %d steps, that means %f epochs and %d steps per epoch',
        total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)

    write_note('Initializing model...')
    logging.info('config.model = %s', config.get('model'))
    model = ub.models.vision_transformer(num_classes=config.num_classes,
                                         **config.get('model', {}))

    # We want all parameters to be created in host RAM, not on any device; they
    # will be sent there later as needed. Otherwise they may get allocated
    # twice, which we have already encountered in two situations.
    @partial(jax.jit, backend='cpu')
    def init(rng):
        image_size = tuple(train_ds.element_spec['image'].shape[2:])
        logging.info('image_size = %s', image_size)
        dummy_input = jnp.zeros((local_batch_size, ) + image_size, jnp.float32)
        params = flax.core.unfreeze(model.init(rng, dummy_input,
                                               train=False))['params']

        # Set bias in the head to a low value, such that loss is small initially.
        params['head']['bias'] = jnp.full_like(params['head']['bias'],
                                               config.get('init_head_bias', 0))

        # init head kernel to all zeros for fine-tuning
        if config.get('model_init'):
            params['head']['kernel'] = jnp.full_like(params['head']['kernel'],
                                                     0)

        return params

    rng, rng_init = jax.random.split(rng)
    params_cpu = init(rng_init)

    if jax.process_index() == 0:
        num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
        parameter_overview.log_parameter_overview(params_cpu)
        writer.write_scalars(step=0, scalars={'num_params': num_params})

    @partial(jax.pmap, axis_name='batch')
    def evaluation_fn(params, images, labels, mask):
        # Ignore the entries with all zero labels for evaluation.
        mask *= labels.max(axis=1)
        logits, out = model.apply({'params': flax.core.freeze(params)},
                                  images,
                                  train=False)
        label_indices = config.get('label_indices')
        logging.info('!!! mask %s, label_indices %s', mask, label_indices)
        if label_indices:
            logits = logits[:, label_indices]

        # Note that logits and labels are usually of the shape [batch, num_classes].
        # But for OOD data, when num_classes_ood > num_classes_ind, we need to
        # adjust labels to labels[:, :config.num_classes] to match the shape of
        # logits. That is just to avoid a shape mismatch. The output losses do not
        # have any meaning for OOD data, because OOD samples do not belong to any
        # IND class.
        losses = getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
            logits=logits,
            labels=labels[:, :(
                len(label_indices) if label_indices else config.num_classes)],
            reduction=False)
        loss = jax.lax.psum(losses * mask, axis_name='batch')

        top1_idx = jnp.argmax(logits, axis=1)
        # Extracts the label at the highest logit index for each image.
        top1_correct = jnp.take_along_axis(labels, top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch')
        n = jax.lax.psum(mask, axis_name='batch')

        metric_args = jax.lax.all_gather(
            [logits, labels, out['pre_logits'], mask], axis_name='batch')
        return ncorrect, loss, n, metric_args

    @partial(jax.pmap, axis_name='batch')
    def cifar_10h_evaluation_fn(params, images, labels, mask):
        logits, out = model.apply({'params': flax.core.freeze(params)},
                                  images,
                                  train=False)
        label_indices = config.get('label_indices')
        if label_indices:
            logits = logits[:, label_indices]

        losses = getattr(train_utils,
                         config.get('loss', 'softmax_xent'))(logits=logits,
                                                             labels=labels,
                                                             reduction=False)
        loss = jax.lax.psum(losses, axis_name='batch')

        top1_idx = jnp.argmax(logits, axis=1)
        # Extracts the label at the highest logit index for each image.
        one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)]

        top1_correct = jnp.take_along_axis(one_hot_labels,
                                           top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
        n = jax.lax.psum(one_hot_labels, axis_name='batch')

        metric_args = jax.lax.all_gather(
            [logits, labels, out['pre_logits'], mask], axis_name='batch')
        return ncorrect, loss, n, metric_args

    # Setup function for computing representation.
    @partial(jax.pmap, axis_name='batch')
    def representation_fn(params, images, labels, mask):
        _, outputs = model.apply({'params': flax.core.freeze(params)},
                                 images,
                                 train=False)
        representation = outputs[config.fewshot.representation_layer]
        representation = jax.lax.all_gather(representation, 'batch')
        labels = jax.lax.all_gather(labels, 'batch')
        mask = jax.lax.all_gather(mask, 'batch')
        return representation, labels, mask

    # Load the optimizer from flax.
    opt_name = config.get('optim_name')
    write_note(f'Initializing {opt_name} optimizer...')
    opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

    # We jit this so that the arrays are created on the same device as the
    # input, in this case the CPU. Otherwise they would end up on device[0].
    opt_cpu = jax.jit(opt_def.create)(params_cpu)

    weight_decay_rules = config.get('weight_decay', []) or []
    rescale_value = config.lr.base if config.get(
        'weight_decay_decouple') else 1.
    weight_decay_fn = train_utils.get_weight_decay_fn(
        weight_decay_rules=weight_decay_rules, rescale_value=rescale_value)

    @partial(jax.pmap, axis_name='batch', donate_argnums=(0, ))
    def update_fn(opt, lr, images, labels, rng):
        """Update step."""

        measurements = {}

        # Get device-specific loss rng.
        rng, rng_model = jax.random.split(rng, 2)
        rng_model_local = jax.random.fold_in(rng_model,
                                             jax.lax.axis_index('batch'))

        def loss_fn(params, images, labels):
            logits, _ = model.apply({'params': flax.core.freeze(params)},
                                    images,
                                    train=True,
                                    rngs={'dropout': rng_model_local})
            label_indices = config.get('label_indices')
            if label_indices:
                logits = logits[:, label_indices]
            return getattr(train_utils,
                           config.get('loss', 'sigmoid_xent'))(logits=logits,
                                                               labels=labels)

        # Implementation considerations compared and summarized at
        # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
        l, g = train_utils.accumulate_gradient(jax.value_and_grad(loss_fn),
                                               opt.target, images, labels,
                                               config.get('grad_accum_steps'))
        l, g = jax.lax.pmean((l, g), axis_name='batch')

        # Log the gradient norm only if we need to compute it anyways (clipping)
        # or if we don't use grad_accum_steps, as they interact badly.
        if config.get('grad_accum_steps',
                      1) == 1 or config.get('grad_clip_norm'):
            grads, _ = jax.tree_flatten(g)
            l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
            measurements['l2_grads'] = l2_g

        # Optionally resize the global gradient to a maximum norm. We found this
        # useful in some cases across optimizers, hence it's in the main loop.
        if config.get('grad_clip_norm'):
            g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
            g = jax.tree_util.tree_map(lambda p: g_factor * p, g)
        opt = opt.apply_gradient(g, learning_rate=lr)

        opt = opt.replace(target=weight_decay_fn(opt.target, lr))

        params, _ = jax.tree_flatten(opt.target)
        measurements['l2_params'] = jnp.sqrt(
            sum([jnp.vdot(p, p) for p in params]))

        return opt, l, rng, measurements

    rng, train_loop_rngs = jax.random.split(rng)
    reint_params = ('head/kernel', 'head/bias')
    if config.get('only_eval', False) or not config.get('reint_head', True):
        reint_params = []
    checkpoint_data = checkpoint_utils.maybe_load_checkpoint(
        train_loop_rngs=train_loop_rngs,
        save_checkpoint_path=save_checkpoint_path,
        init_optimizer=opt_cpu,
        init_params=params_cpu,
        init_fixed_model_states=None,
        default_reinit_params=reint_params,
        config=config,
    )
    train_loop_rngs = checkpoint_data.train_loop_rngs
    opt_cpu = checkpoint_data.optimizer
    accumulated_train_time = checkpoint_data.accumulated_train_time

    write_note('Adapting the checkpoint model...')
    adapted_params = checkpoint_utils.adapt_upstream_architecture(
        init_params=params_cpu, loaded_params=opt_cpu.target)
    opt_cpu = opt_cpu.replace(target=adapted_params)

    write_note('Kicking off misc stuff...')
    first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
    if first_step == 0 and jax.process_index() == 0:
        writer.write_hparams(dict(config))
    chrono = train_utils.Chrono(first_step, total_steps, batch_size,
                                accumulated_train_time)
    # Note: switch to ProfileAllHosts() if you need to profile all hosts.
    # (Xprof data become much larger and take longer to load for analysis)
    profiler = periodic_actions.Profile(
        # Create profile after every restart to analyze pre-emption related
        # problems and ensure we get similar performance in every run.
        logdir=output_dir,
        first_profile=first_step + 10)

    # Prepare the learning-rate and pre-fetch it to device to avoid delays.
    lr_fn = train_utils.create_learning_rate_schedule(total_steps,
                                                      **config.get('lr', {}))
    # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be
    # necessary for TPUs.
    lr_iter = train_utils.prefetch_scalar(map(lr_fn, range(total_steps)),
                                          config.get('prefetch_to_device', 1))

    write_note(f'Replicating...\n{chrono.note}')
    opt_repl = flax_utils.replicate(opt_cpu)

    write_note(f'Initializing few-shotters...\n{chrono.note}')
    fewshotter = None
    if 'fewshot' in config and fewshot is not None:
        fewshotter = fewshot.FewShotEvaluator(
            representation_fn, config.fewshot,
            config.fewshot.get('batch_size') or batch_size_eval)

    checkpoint_writer = None

    # Note: we return the train loss, val loss, and fewshot best l2s for use in
    # reproducibility unit tests.
    train_loss = -jnp.inf
    val_loss = {val_name: -jnp.inf for val_name, _ in val_iter_splits.items()}
    fewshot_results = {'dummy': {(0, 1): -jnp.inf}}

    write_note(f'First step compilations...\n{chrono.note}')
    logging.info('first_step = %s', first_step)
    # Advance the iterators if we are restarting from an earlier checkpoint.
    # TODO(dusenberrymw): Look into checkpointing dataset state instead.
    if first_step > 0:
        write_note('Advancing iterators after resuming from a checkpoint...')
        lr_iter = itertools.islice(lr_iter, first_step, None)
        train_iter = itertools.islice(train_iter, first_step, None)
        # NOTE: Validation eval is only run on certain steps, so determine how many
        # times it was run previously.
        num_val_runs = sum(
            map(
                lambda i: train_utils.itstime(i, config.log_eval_steps,
                                              total_steps),
                range(1, first_step + 1)))
        for val_name, (val_iter, val_steps) in val_iter_splits.items():
            val_iter = itertools.islice(val_iter, num_val_runs * val_steps,
                                        None)
            val_iter_splits[val_name] = (val_iter, val_steps)

    # Using a python integer for step here, because opt.state.step is allocated
    # on TPU during replication.
    for step, train_batch, lr_repl in zip(
            range(first_step + 1, total_steps + 1), train_iter, lr_iter):

        with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1):
            if not config.get('only_eval', False):
                opt_repl, loss_value, train_loop_rngs, extra_measurements = update_fn(
                    opt_repl,
                    lr_repl,
                    train_batch['image'],
                    train_batch['labels'],
                    rng=train_loop_rngs)

        if jax.process_index() == 0:
            profiler(step)

        # Checkpoint saving
        if not config.get('only_eval', False) and train_utils.itstime(
                step, config.get('checkpoint_steps'), total_steps, process=0):
            write_note('Checkpointing...')
            chrono.pause()
            train_utils.checkpointing_timeout(
                checkpoint_writer, config.get('checkpoint_timeout', 1))
            accumulated_train_time = chrono.accum_train_time
            # We need to transfer the weights over now or else we risk keeping them
            # alive while they'll be updated in a future step, creating hard to debug
            # memory errors (see b/160593526). Also, takes device 0's params only.
            opt_cpu = jax.tree_util.tree_map(lambda x: np.array(x[0]),
                                             opt_repl)

            # Check whether we want to keep a copy of the current checkpoint.
            copy_step = None
            if train_utils.itstime(step, config.get('keep_checkpoint_steps'),
                                   total_steps):
                write_note('Keeping a checkpoint copy...')
                copy_step = step

            # Checkpoint should be a nested dictionary or Flax dataclasses from
            # `flax.struct`. Both can be present in a checkpoint.
            checkpoint_data = checkpoint_utils.CheckpointData(
                train_loop_rngs=train_loop_rngs,
                optimizer=opt_cpu,
                accumulated_train_time=accumulated_train_time)

            checkpoint_writer = pool.apply_async(
                checkpoint_utils.checkpoint_trained_model,
                (checkpoint_data, save_checkpoint_path, copy_step))
            chrono.resume()

        # Report training progress
        if not config.get('only_eval', False) and train_utils.itstime(
                step, config.log_training_steps, total_steps, process=0):
            write_note('Reporting training progress...')
            train_loss = loss_value[
                0]  # Keep to return for reproducibility tests.
            timing_measurements, note = chrono.tick(step)
            write_note(note)
            train_measurements = {}
            train_measurements.update({
                'learning_rate': lr_repl[0],
                'training_loss': train_loss,
            })
            train_measurements.update(
                flax.jax_utils.unreplicate(extra_measurements))
            train_measurements.update(timing_measurements)
            writer.write_scalars(step, train_measurements)

        # Report validation performance
        if train_utils.itstime(step, config.log_eval_steps, total_steps):
            write_note('Evaluating on the validation set...')
            chrono.pause()
            for val_name, (val_iter, val_steps) in val_iter_splits.items():
                # Sets up evaluation metrics.
                ece_num_bins = config.get('ece_num_bins', 15)
                auc_num_bins = config.get('auc_num_bins', 1000)
                ece = rm.metrics.ExpectedCalibrationError(
                    num_bins=ece_num_bins)
                calib_auc = rm.metrics.CalibrationAUC(
                    correct_pred_as_pos_label=False)
                oc_auc_0_5 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.005, num_bins=auc_num_bins)
                oc_auc_1 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.01, num_bins=auc_num_bins)
                oc_auc_2 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.02, num_bins=auc_num_bins)
                oc_auc_5 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.05, num_bins=auc_num_bins)
                label_diversity = tf.keras.metrics.Mean()
                sample_diversity = tf.keras.metrics.Mean()
                ged = tf.keras.metrics.Mean()

                # Runs evaluation loop.
                ncorrect, loss, nseen = 0, 0, 0
                for _, batch in zip(range(val_steps), val_iter):
                    if val_name == 'cifar_10h':
                        batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                            cifar_10h_evaluation_fn(opt_repl.target,
                                                    batch['image'],
                                                    batch['labels'],
                                                    batch['mask']))
                    else:
                        batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                            evaluation_fn(opt_repl.target, batch['image'],
                                          batch['labels'], batch['mask']))
                    # All results are a replicated array shaped as follows:
                    # (local_devices, per_device_batch_size, elem_shape...)
                    # with each local device's entry being identical as they got psum'd.
                    # So let's just take the first one to the host as numpy.
                    ncorrect += np.sum(np.array(batch_ncorrect[0]))
                    loss += np.sum(np.array(batch_losses[0]))
                    nseen += np.sum(np.array(batch_n[0]))
                    if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
                        # Here we parse batch_metric_args to compute uncertainty metrics.
                        # (e.g., ECE or Calibration AUC).
                        logits, labels, _, masks = batch_metric_args
                        masks = np.array(masks[0], dtype=np.bool)
                        logits = np.array(logits[0])
                        probs = jax.nn.softmax(logits)
                        # From one-hot to integer labels, as required by ECE.
                        int_labels = np.argmax(np.array(labels[0]), axis=-1)
                        int_preds = np.argmax(logits, axis=-1)
                        confidence = np.max(probs, axis=-1)
                        for p, c, l, d, m, label in zip(
                                probs, confidence, int_labels, int_preds,
                                masks, labels[0]):
                            ece.add_batch(p[m, :], label=l[m])
                            calib_auc.add_batch(d[m],
                                                label=l[m],
                                                confidence=c[m])
                            # TODO(jereliu): Extend to support soft multi-class probabilities.
                            oc_auc_0_5.add_batch(d[m],
                                                 label=l[m],
                                                 custom_binning_score=c[m])
                            oc_auc_1.add_batch(d[m],
                                               label=l[m],
                                               custom_binning_score=c[m])
                            oc_auc_2.add_batch(d[m],
                                               label=l[m],
                                               custom_binning_score=c[m])
                            oc_auc_5.add_batch(d[m],
                                               label=l[m],
                                               custom_binning_score=c[m])

                            if val_name == 'cifar_10h' or val_name == 'imagenet_real':
                                batch_label_diversity, batch_sample_diversity, batch_ged = data_uncertainty_utils.generalized_energy_distance(
                                    label[m], p[m, :], config.num_classes)
                                label_diversity.update_state(
                                    batch_label_diversity)
                                sample_diversity.update_state(
                                    batch_sample_diversity)
                                ged.update_state(batch_ged)

                val_loss[
                    val_name] = loss / nseen  # Keep for reproducibility tests.
                val_measurements = {
                    f'{val_name}_prec@1': ncorrect / nseen,
                    f'{val_name}_loss': val_loss[val_name],
                }
                if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
                    val_measurements[f'{val_name}_ece'] = ece.result()['ece']
                    val_measurements[
                        f'{val_name}_calib_auc'] = calib_auc.result(
                        )['calibration_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_0.5%'] = oc_auc_0_5.result(
                        )['collaborative_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_1%'] = oc_auc_1.result(
                        )['collaborative_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_2%'] = oc_auc_2.result(
                        )['collaborative_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_5%'] = oc_auc_5.result(
                        )['collaborative_auc']
                writer.write_scalars(step, val_measurements)

                if val_name == 'cifar_10h' or val_name == 'imagenet_real':
                    cifar_10h_measurements = {
                        f'{val_name}_label_diversity':
                        label_diversity.result(),
                        f'{val_name}_sample_diversity':
                        sample_diversity.result(),
                        f'{val_name}_ged': ged.result(),
                    }
                    writer.write_scalars(step, cifar_10h_measurements)

            # OOD eval
            # Entries in the ood_ds dict include:
            # (ind_dataset, ood_dataset1, ood_dataset2, ...).
            # OOD metrics are computed using ind_dataset paired with each of the
            # ood_dataset. When Mahalanobis distance method is applied, train_ind_ds
            # is also included in the ood_ds.
            if ood_ds and config.ood_methods:
                ood_measurements = ood_utils.eval_ood_metrics(
                    ood_ds, ood_ds_names, config.ood_methods, evaluation_fn,
                    opt_repl)
                writer.write_scalars(step, ood_measurements)
            chrono.resume()

        if 'fewshot' in config and fewshotter is not None:
            # Compute few-shot on-the-fly evaluation.
            if train_utils.itstime(step, config.fewshot.log_steps,
                                   total_steps):
                chrono.pause()
                write_note(f'Few-shot evaluation...\n{chrono.note}')
                # Keep `results` to return for reproducibility tests.
                fewshot_results, best_l2 = fewshotter.run_all(
                    opt_repl.target, config.fewshot.datasets)

                # TODO(dusenberrymw): Remove this once fewshot.py is updated.
                def make_writer_measure_fn(step):
                    def writer_measure(name, value):
                        writer.write_scalars(step, {name: value})

                    return writer_measure

                fewshotter.walk_results(make_writer_measure_fn(step),
                                        fewshot_results, best_l2)
                chrono.resume()

        # End of step.
        if config.get('testing_failure_step'):
            # Break early to simulate infra failures in test cases.
            if config.testing_failure_step == step:
                break

    write_note(f'Done!\n{chrono.note}')
    pool.close()
    pool.join()
    writer.close()

    # Return final training loss, validation loss, and fewshot results for
    # reproducibility test cases.
    return train_loss, val_loss, fewshot_results
Example #8
def training_loop(
    *,
    module,
    rng,
    train_ds,
    eval_ds,
    loss_fn,
    optimizer,
    train_metrics_dict,
    eval_metrics_dict,
    stats_aggregators,
    config,
    workdir,
):
    """Runs a training and evaluation loop.

  Args:
    module: The module that should be trained.
    rng: A jax pseudo-random number generator key.
    train_ds: Dataset used for training.
    eval_ds: Dataset used for evaluation.
    loss_fn: Loss function to use for training.
    optimizer: Optax optimizer to use for training.
    train_metrics_dict: Collection of metrics to be collected during training.
    eval_metrics_dict: Collection of metrics to be collected during evaluation.
    stats_aggregators: Dictionary of statistics aggregator functions to be run
      on the first evaluation batch. These functions ingest the stats returned
      by the model and output a Dict[str, image/scalar] that will be logged.
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest checkpoint.

  Raises:
    RuntimeError: If a training metric is NaN or inf.

  Returns:
    Training state.
  """
    rng, model_rng = jax.random.split(rng)
    input_shape = tuple(train_ds.element_spec["image"].shape[1:])
    model, init_params, init_state = create_model(module, input_shape,
                                                  model_rng)
    parameter_overview.log_parameter_overview(model.params)

    # Load pretrained model parameters and state. Ignore the step and the
    # optimizer state in the checkpoint.
    pretrained_path = config.get("pretrained_checkpoint", "")
    if pretrained_path:
        logging.info("Load pretrained weights from '%s'", pretrained_path)
        state_dict = checkpoint.load_state_dict(pretrained_path)
        flatten_model_params = utils.flatten_dict(state_dict["model_params"],
                                                  sep="/")
        model_state = state_dict["model_state"]

        # A prefix can be used to replace only a subpart of the network (e.g. the
        # encoder). Prepend the prefix (if any) to model parameters and states.
        prefix = config.get("pretrained_prefix", "")
        if prefix:
            flatten_model_params = utils.add_prefix_to_dict_keys(
                flatten_model_params, f"{prefix}/")
            model_state = utils.add_prefix_to_dict_keys(
                model_state, f"/{prefix}")

        # Merge the params/state from the checkpoint into the initial params/state.
        flatten_init_params = utils.flatten_dict(init_params, sep="/")
        flatten_init_params, ignored_params = utils.override_dict(
            flatten_init_params, flatten_model_params)
        init_params = utils.unflatten_dict(flatten_init_params, delimiter="/")
        init_state, _ = utils.override_dict(init_state, model_state)

        if ignored_params:
            logging.warning(
                "%d/%d parameters from the pretrained checkpoint "
                "were ignored: %s", len(ignored_params),
                len(flatten_init_params), ignored_params)

    optimizer_state = optimizer.init(init_params)

    state = TrainState(step=1,
                       model_params=init_params,
                       model_state=init_state,
                       optimizer_state=optimizer_state)  # type: ignore
    # Do not keep a copy of the initial model.
    del init_params, init_state, optimizer_state

    train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
    checkpoint_dir = os.path.join(workdir, "checkpoints")

    ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
    state = ckpt.restore_or_initialize(state)
    initial_step = int(state.step)
    # Replicate our parameters.
    state = flax.jax_utils.replicate(state)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.host_id() > 0)
    step_timer = utils.StepTimer(batch_size=config.batch_size,
                                 initial_step=initial_step)

    # Write config to the summary files. This makes the hyperparameters available
    # in TensorBoard and makes comparison of runs in TensorBoard easier.
    if initial_step == 1:
        writer.write_hparams(utils.flatten_dict(config.to_dict()))

    # Generate per-device PRNG keys for the training loop.
    rng, train_rng = jax.random.split(rng)
    train_rngs = jax.random.split(train_rng, jax.local_device_count())

    # Generate per-device PRNG keys for model evaluation.
    rng, eval_rng = jax.random.split(rng)
    eval_rngs = jax.random.split(eval_rng, jax.local_device_count())

    logging.info("Starting training loop at step %d.", initial_step)
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    train_metrics = utils.Means()

    do_eval_only = config.get("do_eval_only", False)
    if do_eval_only:
        config.num_train_steps = 1

    debug_enabled = config.get("debug", False)
    previous_grads = grads = None
    previous_updates = updates = None
    previous_state = None
    for step in range(initial_step, config.num_train_steps + 1):
        is_last_step = step == config.num_train_steps
        if debug_enabled:
            previous_grads = grads
            previous_updates = updates
            previous_state = state

        # Skip training if we are only doing eval.
        if not do_eval_only:
            # Use ._numpy() to avoid a copy.
            batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))  # pylint: disable=protected-access
            state, grads, updates, metrics, training_stats, train_rngs = train_step(
                state, batch, module, loss_fn, optimizer, train_metrics_dict,
                train_rngs)
            train_metrics.append(flax.jax_utils.unreplicate(metrics))

            # Update topk temperature with linearly decreasing schedule if enabled.
            if (config.get("linear_decrease_perturbed_sigma", False) and
                    config.get("selection_method", "") == "perturbed-topk"):

                model_state = state.model_state.as_dict()

                if "/PatchNet_0" in model_state:
                    net_str = "/PatchNet_0"
                else:
                    net_str = "/"

                progress = step / config.num_train_steps
                sigma_multiplier = 1. - progress
                previous_mult = model_state[net_str]["sigma_mutiplier"]
                sigma_multiplier = sigma_multiplier + jnp.zeros_like(
                    previous_mult)
                model_state[net_str]["sigma_mutiplier"] = sigma_multiplier
                state = state.replace(model_state=nn.Collection(model_state))

            if debug_enabled:
                if utils.has_any_inf_or_nan(metrics):
                    # Save checkpoint
                    if previous_state:
                        ckpt.save(flax.jax_utils.unreplicate(previous_state))
                    ckpt.save(flax.jax_utils.unreplicate(state))

                    # Log gradients and updates.
                    if previous_grads or previous_updates:
                        write_gradient_histogram(writer,
                                                 step,
                                                 grads=previous_grads,
                                                 updates=previous_updates)
                    write_gradient_histogram(writer,
                                             step + 1,
                                             grads=grads,
                                             updates=updates)

                    raise RuntimeError(
                        "A training metric took an invalid value: "
                        f"{metrics}.")

            logging.log_first_n(logging.INFO, "Finished training step %d.", 3,
                                step)
            report_progress(step)

            if step % config.log_loss_every_steps == 0 or is_last_step:
                results = train_metrics.result()
                writer.write_scalars(step, results)
                writer.write_scalars(step, step_timer.get_and_reset(step))
                if utils.has_any_inf_or_nan(results):
                    raise ValueError(
                        "A training metric took an invalid value.")
                train_metrics.reset()

        if (step % config.checkpoint_every_steps == 0 or is_last_step):
            with step_timer.paused():
                ckpt.save(flax.jax_utils.unreplicate(state))

        # Evaluation
        if step % config.eval_every_steps == 0 or is_last_step:
            with step_timer.paused():
                eval_metrics, first_batch_stats, eval_rngs = evaluate(
                    state, module, eval_ds, eval_metrics_dict, eval_rngs)

            if jax.host_id() == 0:
                log_histograms = config.get("log_histograms", False)
                log_images = config.get("log_images", True)
                # Log the last gradients and updates histograms.
                if not do_eval_only:
                    write_stats_results(writer,
                                        step,
                                        training_stats,
                                        stats_aggregators,
                                        prefix="train/",
                                        log_images=log_images)
                    if log_histograms:
                        write_gradient_histogram(writer,
                                                 step,
                                                 grads=grads,
                                                 updates=updates)

                write_stats_results(writer,
                                    step,
                                    first_batch_stats,
                                    stats_aggregators,
                                    prefix="eval/",
                                    log_images=log_images)

                # write patch representation histograms
                if (log_histograms and first_batch_stats
                        and "patch_representations" in first_batch_stats):
                    patch_representations = first_batch_stats[
                        "patch_representations"]
                    writer.write_histograms(
                        step, {"patch_representations": patch_representations})

                if eval_metrics:
                    writer.write_scalars(step, eval_metrics)

    writer.flush()
    return state
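As a small aside, the overview table can also be captured as a string and written to the metric writer, rather than only emitted via logging as in the examples above. A hedged sketch (the helper name and step value are ours), assuming params is a pytree of arrays and writer is a CLU MetricWriter:

from clu import parameter_overview

def log_params(params, writer, step=0):
  parameter_overview.log_parameter_overview(params)             # Logs the table via absl logging.
  overview = parameter_overview.get_parameter_overview(params)  # Same table as a string.
  writer.write_texts(step, {'parameter_overview': overview})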
Example #9
def train(config, model_def, device_batch_size, eval_ds, num_steps,
          steps_per_epoch, steps_per_eval, train_ds, image_size, data_source,
          workdir):
  """Train model."""

  make_lr_fn = schedulers.get_make_lr_fn(config)
  make_temp_fn = schedulers.get_make_temp_fn(config)
  make_step_size_fn = schedulers.get_make_step_size_fn(config)
  if jax.host_count() > 1:
    raise ValueError('CIFAR10 example should not be run on '
                     'more than 1 host due to preconditioner updating.')

  initial_step = 0  # TODO(basv): load from checkpoint.
  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)

  # Write config to the summary files. This makes the hyperparameters available
  # in TensorBoard and makes comparison of runs in TensorBoard easier.
  # with writer.summary_writer.as_default():
  writer.write_hparams(dict(config))

  rng = random.PRNGKey(config.seed)
  rng, opt_rng, init_key, sampler_rng = jax.random.split(rng, 4)

  base_learning_rate = config.learning_rate

  # Create the model.
  model, state = create_model(rng, device_batch_size, image_size, model_def)
  parameter_overview.log_parameter_overview(model.params)
  state = jax_utils.replicate(state)

  train_size = data_source.TRAIN_IMAGES

  with flax.deprecated.nn.stochastic(init_key):
    optimizer = create_optimizer(config, model, base_learning_rate, train_size,
                                 sampler_rng)
  del model  # Don't keep a copy of the initial model.

  # Learning rate schedule
  learning_rate_fn = make_lr_fn(base_learning_rate, steps_per_epoch)
  temperature_fn = make_temp_fn(config.base_temp, steps_per_epoch)
  step_size_fn = make_step_size_fn(steps_per_epoch)

  p_eval_step, _, p_train_step, p_update_grad_vars = make_step_functions(
      config, config.l2_reg, learning_rate_fn, train_size, temperature_fn,
      step_size_fn)

  # Create dataset batch iterators.
  train_iter = iter(train_ds)
  eval_iter = iter(eval_ds)

  # Gather metrics.
  train_metrics = []
  epoch = 0

  # Ensemble.
  ensemble = []
  ensemble_logits = []
  ensemble_labels = []
  ensemble_probs = []

  def ensemble_add_step(step):
    if config.lr_schedule == 'cosine':
      # Add if learning rate jumps up again in the next step.
      increase = step_size_fn(step) < step_size_fn(step + 1) - 1e-8
      _, temp_end = ast.literal_eval(config.temp_ramp)
      past_burn_in = step >= steps_per_epoch * temp_end
      return increase and past_burn_in

    elif config.lr_schedule == 'constant':
      if (step + 1) % steps_per_epoch == 0:
        return True
    return False
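
  # Illustrative behaviour (hypothetical schedule values, not from the config):
  # with a cosine schedule that restarts every 50 epochs and
  # config.temp_ramp = '(0.0, 10)', ensemble_add_step(step) is True only at the
  # last step of each 50-epoch cycle once step >= 10 * steps_per_epoch.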

  logging.info('Starting training loop at step %d.', initial_step)

  for step in range(initial_step, num_steps):
    if config.optimizer in ['sym_euler'] and step % steps_per_epoch == 0:
      optimizer, rng = update_preconditioner(config, optimizer,
                                             p_update_grad_vars, rng, state,
                                             train_iter)
    # Generate a PRNG key that will be rolled into the batch
    step_key = jax.random.fold_in(rng, step)
    opt_step_rng = jax.random.fold_in(opt_rng, step)

    # Load and shard the TF batch
    batch = next(train_iter)
    batch = input_pipeline.load_and_shard_tf_batch(config, batch)
    if not config.debug_run:
      # Shard the step PRNG key
      # Don't shard the optimizer rng, as it should be equal among all machines.
      sharded_keys = common_utils.shard_prng_key(step_key)
    else:
      sharded_keys = step_key

    # Train step
    optimizer, state, metrics = p_train_step(optimizer, state, batch,
                                             sharded_keys, opt_step_rng)
    train_metrics.append(metrics)

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
    if step == initial_step:
      initial_train_metrics = get_metrics(config, train_metrics)
      train_summary = jax.tree_map(lambda x: x.mean(), initial_train_metrics)
      train_summary = {'train_' + k: v for k, v in train_summary.items()}
      logging.log(logging.INFO, 'initial metrics = %s',
                  str(train_summary.items()))

    if (step + 1) % steps_per_epoch == 0:
      # We've finished an epoch
      # Save model params/state.

      train_metrics = get_metrics(config, train_metrics)
      # Get training epoch summary for logging
      train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)

      train_summary = {'train_' + k: v for k, v in train_summary.items()}

      writer.write_scalars(epoch, train_summary)
      # Reset train metrics
      train_metrics = []

      # Evaluation
      if config.do_eval:
        eval_metrics = []
        eval_logits = []
        eval_labels = []
        for _ in range(steps_per_eval):
          eval_batch = next(eval_iter)
          # Load and shard the TF batch
          eval_batch = input_pipeline.load_and_shard_tf_batch(
              config, eval_batch)
          # Step
          logits, labels, metrics = p_eval_step(optimizer.target, state,
                                                eval_batch)
          eval_metrics.append(metrics)
          eval_logits.append(logits)
          eval_labels.append(labels)
        eval_metrics = get_metrics(config, eval_metrics)
        # Get eval epoch summary for logging
        eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
        eval_summary = {'eval_' + k: v for k, v in eval_summary.items()}
        writer.write_scalars(epoch, eval_summary)

      if config.algorithm == 'sgmcmc' and ensemble_add_step(step):
        ensemble.append((serialization.to_state_dict(optimizer.target), state))

      if config.algorithm == 'sgmcmc' and ensemble_add_step(
          step) and len(ensemble) >= 1:
        # Gather predictions for this ensemble sample.
        eval_logits = jnp.concatenate(eval_logits, axis=0)
        eval_probs = jax.nn.softmax(eval_logits, axis=-1)
        eval_labels = jnp.concatenate(eval_labels, axis=0)
        # Ensure that labels are consistent between predict runs.
        if ensemble_labels:
          assert jnp.allclose(
              eval_labels,
              ensemble_labels[0]), 'Labels inconsistent between eval runs.'

        ensemble_logits.append(eval_logits)
        ensemble_probs.append(eval_probs)
        ensemble_labels.append(eval_labels)

        # Compute ensemble predictions over last config.ensemble_size samples.
        ensemble_last_probs = jnp.mean(
            jnp.array(ensemble_probs[-config.ensemble_size:]), axis=0)
        ensemble_metrics = train_functions.compute_metrics_probs(
            ensemble_last_probs, ensemble_labels[0])
        ensemble_summary = jax.tree_map(lambda x: x.mean(), ensemble_metrics)
        ensemble_summary = {'ens_' + k: v for k, v in ensemble_summary.items()}
        ensemble_summary['ensemble_size'] = min(config.ensemble_size,
                                                len(ensemble_probs))
        writer.write_scalars(epoch, ensemble_summary)

      epoch += 1

  return ensemble, optimizer
def train_and_evaluate(config, workdir, strategy):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
    strategy: Distribution strategy to use for distributing the model.
  """
    tf.io.gfile.makedirs(workdir)

    tf_rng, data_rng = tf.random.experimental.stateless_split((config.seed, 0),
                                                              2)
    tf.random.set_seed(tf_rng.numpy()[0])

    # Input pipeline.
    ds_info, train_ds, val_ds, test_ds = input_pipeline.create_datasets(
        config, data_rng, strategy=strategy)
    train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types

    # Learning rate schedule.
    num_train_steps = config.num_train_steps
    if num_train_steps == -1:
        num_train_steps = (ds_info.splits["train"].num_examples //
                           config.global_batch_size * config.num_epochs)
    steps_per_epoch = num_train_steps // config.num_epochs
    logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
                 steps_per_epoch)

    # We treat the learning rate in the config as the learning rate for batch size
    # 256 but scale it according to our batch size.
    base_learning_rate = config.learning_rate * config.global_batch_size / 256.0
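    # For example (illustrative values only): config.learning_rate = 0.1 with
    # config.global_batch_size = 1024 gives base_learning_rate = 0.4.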

    # Initialize model.
    num_classes = ds_info.features["label"].num_classes

    if config.distill_teacher:
        do_distill = True
        teacher_file_list = (config.distill_teacher).split(",")
        teacher_models = load_teacher_models(teacher_file_list, num_classes,
                                             config, strategy)
        distill_params = {}
        distill_params["alpha"] = config.distill_alpha
        distill_params["beta"] = config.distill_fd_beta
        distill_params["teacher_model"] = TeacherModel(teacher_models,
                                                       name="teacher")
    else:
        do_distill = False
        distill_params = None

    state = create_state(config, num_classes=num_classes, strategy=strategy)

    ckpt_manager = tf.train.CheckpointManager(checkpoint=state,
                                              directory=workdir,
                                              max_to_keep=5)

    if ckpt_manager.latest_checkpoint:
        state.restore(ckpt_manager.latest_checkpoint)
        logging.info("Restored from %s", ckpt_manager.latest_checkpoint)
    else:
        logging.info("Initializing from scratch.")
    initial_step = state.global_step.numpy().item()

    learning_rate_fn = functools.partial(get_learning_rate,
                                         base_learning_rate=base_learning_rate,
                                         steps_per_epoch=steps_per_epoch,
                                         num_epochs=config.num_epochs,
                                         warmup_epochs=config.warmup_epochs)

    writer = metric_writers.create_default_writer(workdir)
    writer.write_hparams(dict(config))

    logging.info("Starting training loop at step %d.", initial_step)
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=num_train_steps, writer=writer)
    with metric_writers.ensure_flushes(writer):
        for step in range(initial_step, num_train_steps + 1):
            state.model.trainable = True

            # `step` is a Python integer. `global_step` is a TF variable on the
            # GPU/TPU devices.
            is_last_step = step == num_train_steps

            train_step(state, train_iter, config.weight_decay,
                       learning_rate_fn, do_distill, distill_params, strategy)

            state.train_metrics.update_state_lr(
                learning_rate_fn(state.global_step.numpy().item()))

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            report_progress(step)

            if step == initial_step:
                parameter_overview.log_parameter_overview(state.model)

            if step % config.log_loss_every_steps == 0 or is_last_step:
                writer.write_scalars(step, state.train_metrics.result())
                state.train_metrics.reset_states()
                state.train_metrics.reset_lr()

            if step % config.eval_every_steps == 0 or is_last_step:
                state.model.trainable = False
                if config.dataset == "imagenet-lt":
                    evaluate(state, val_ds, state.val_metrics, strategy)
                    writer.write_scalars(step, state.val_metrics.result())
                    logging.info("Num val images %d",
                                 state.val_metrics.accuracy.count.numpy())

                evaluate(state, test_ds, state.test_metrics, strategy)
                writer.write_scalars(step, state.test_metrics.result())

                logging.info("Num test images %d",
                             state.test_metrics.accuracy.count.numpy())

            if step % config.checkpoint_every_steps == 0 or is_last_step:
                checkpoint_path = ckpt_manager.save(step)
                logging.info("Saved checkpoint %s", checkpoint_path)

    logging.info("Finishing training at step %d", step)
    logging.info("Saving the final weights")
    file_path = "%s/final_weights" % workdir
    state.model.save_weights(file_path, save_format="tf")
Exemple #11
0
def main(argv):
  del argv  # unused arg

  config = FLAGS.config

  # Unpack total and warmup steps
  # TODO(nband): revert this to separate arguments.
  total_steps = config.total_and_warmup_steps[0]
  warmup_steps = config.total_and_warmup_steps[1]
  del config.total_and_warmup_steps
  config.total_steps = total_steps
  config.lr.warmup_steps = warmup_steps

  # Wandb and Checkpointing Setup
  output_dir = FLAGS.output_dir
  wandb_run, output_dir = vit_utils.maybe_setup_wandb(config)
  tf.io.gfile.makedirs(output_dir)
  logging.info('Saving checkpoints at %s', output_dir)

  # Dataset Split Flags
  dist_shift = config.distribution_shift
  print(f'Distribution Shift: {dist_shift}.')
  dataset_names, split_names = vit_utils.get_dataset_and_split_names(dist_shift)

  # LR / Optimization Flags
  batch_size = config.batch_size
  grad_clip_norm = config.grad_clip_norm
  weight_decay = config.weight_decay
  print('Standard wandb hyperparameters:')
  print({
      'batch_size': batch_size,
      'grad_clip_norm': grad_clip_norm,
      'weight_decay': weight_decay,
      'total_steps': config.total_steps,
      'lr': config.lr
  })
  print('SNGP Params:', config.gp_layer)

  # Reweighting loss for class imbalance
  # class_reweight_mode = config.class_reweight_mode
  # if class_reweight_mode == 'constant':
  #   class_weights = utils.get_diabetic_retinopathy_class_balance_weights()
  # else:
  #   class_weights = None

  # Shows the number of available devices.
  # In a CPU/GPU runtime this will be a single device.
  # In a TPU runtime this will be 8 cores.
  print('Number of Jax local devices:', jax.local_devices())

  # TODO(nband): fix sigmoid loss issues.
  assert config.get('loss', None) == 'softmax_xent'

  seed = config.seed
  rng = jax.random.PRNGKey(seed)
  tf.random.set_seed(seed)

  if config.get('data_dir'):
    logging.info('data_dir=%s', config.data_dir)
  logging.info('Output dir: %s', output_dir)

  save_checkpoint_path = None
  if config.get('checkpoint_steps'):
    tf.io.gfile.makedirs(output_dir)
    save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

  # Create an asynchronous multi-metric writer.
  writer = metric_writers.create_default_writer(
      output_dir, just_logging=jax.process_index() > 0)

  # The pool is used to perform misc operations such as logging in async way.
  pool = multiprocessing.pool.ThreadPool()

  def write_note(note):
    if jax.process_index() == 0:
      logging.info('NOTE: %s', note)

  write_note('Initializing...')

  # Verify settings to make sure no checkpoints are accidentally missed.
  if config.get('keep_checkpoint_steps'):
    assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
    assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
        f'`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should be '
        f'divisible by `checkpoint_steps` ({config.checkpoint_steps}).')

  batch_size_eval = config.get('batch_size_eval', batch_size)
  if (batch_size % jax.device_count() != 0 or
      batch_size_eval % jax.device_count() != 0):
    raise ValueError(f'Batch sizes ({batch_size} and {batch_size_eval}) must '
                     f'be divisible by device number ({jax.device_count()})')

  local_batch_size = batch_size // jax.process_count()
  local_batch_size_eval = batch_size_eval // jax.process_count()
  logging.info(
      'Global batch size %d on %d hosts results in %d local batch size. '
      'With %d dev per host (%d dev total), that is a %d per-device batch size.',
      batch_size,
      jax.process_count(), local_batch_size, jax.local_device_count(),
      jax.device_count(), local_batch_size // jax.local_device_count())
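  # Illustrative example (hypothetical sizes): batch_size = 512 on 2 hosts with
  # 8 devices each gives local_batch_size = 256 and a per-device batch size
  # of 32.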

  write_note('Initializing preprocessing function...')
  # Same preprocessing function for training and evaluation
  preproc_fn = preprocess_spec.parse(
      spec=config.pp_train, available_ops=preprocess_utils.all_ops())

  write_note('Initializing train dataset...')
  rng, train_ds_rng = jax.random.split(rng)
  train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index())
  train_base_dataset = ub.datasets.get(
      dataset_names['in_domain_dataset'],
      split=split_names['train_split'],
      data_dir=config.get('data_dir'))
  train_dataset_builder = train_base_dataset._dataset_builder  # pylint: disable=protected-access
  train_ds = input_utils.get_data(
      dataset=train_dataset_builder,
      split=split_names['train_split'],
      rng=train_ds_rng,
      process_batch_size=local_batch_size,
      preprocess_fn=preproc_fn,
      shuffle_buffer_size=config.shuffle_buffer_size,
      prefetch_size=config.get('prefetch_to_host', 2),
      data_dir=config.get('data_dir'))
  logging.info('image_size = %s', train_ds.element_spec['image'].shape[1:])

  # Start prefetching already.
  train_iter = input_utils.start_input_pipeline(
      train_ds, config.get('prefetch_to_device', 1))

  write_note('Initializing val dataset(s)...')

  # Load in-domain and OOD validation and/or test datasets.
  # Please specify the desired shift (Country Shift or Severity Shift)
  # in the config.
  eval_iter_splits = vit_utils.init_evaluation_datasets(
      use_validation=config.use_validation,
      use_test=config.use_test,
      dataset_names=dataset_names,
      split_names=split_names,
      config=config,
      preproc_fn=preproc_fn,
      batch_size_eval=batch_size_eval,
      local_batch_size_eval=local_batch_size_eval)

  ntrain_img = input_utils.get_num_examples(
      train_dataset_builder,
      split=split_names['train_split'],
      process_batch_size=local_batch_size,
      data_dir=config.get('data_dir'))
  steps_per_epoch = ntrain_img // batch_size

  if config.get('num_epochs'):
    total_steps = int(config.num_epochs * steps_per_epoch)
    assert not config.get('total_steps'), 'Set either num_epochs or total_steps'
  else:
    total_steps = config.total_steps

  logging.info('Total train data points: %d', ntrain_img)
  logging.info(
      'Running for %d steps, that means %f epochs and %d steps per epoch',
      total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)
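  # Illustrative arithmetic (hypothetical values): 50_000 training images with
  # batch_size = 512 give 97 full steps per epoch, so config.num_epochs = 10
  # corresponds to total_steps = 970.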

  write_note('Initializing model...')
  logging.info('config.model = %s', config.get('model'))

  # Specify Gaussian process layer configs.
  gp_config = config.get('gp_layer', {})
  model_dict = vit_utils.initialize_model('sngp', config)
  model, use_gp_layer = model_dict['model'], model_dict['use_gp_layer']

  # We want all parameters to be created in host RAM, not on any device; they
  # will be sent there later as needed. Without this we have already run into
  # two situations where they were allocated twice.
  @functools.partial(jax.jit, backend='cpu')
  def init(rng):
    image_size = tuple(train_ds.element_spec['image'].shape[2:])
    logging.info('image_size = %s', image_size)
    dummy_input = jnp.zeros((local_batch_size,) + image_size, jnp.float32)
    variables = model.init(rng, dummy_input, train=False)
    # Split model parameters into trainable and untrainable collections.
    states, params = variables.pop('params')
    del variables

    # Set bias in the head to a low value, such that loss is small initially.
    params = flax.core.unfreeze(params)
    if use_gp_layer:
      # Modify the head parameter in the GP head.
      params['head']['output_layer']['bias'] = jnp.full_like(
          params['head']['output_layer']['bias'],
          config.get('init_head_bias', 0))
    else:
      params['head']['bias'] = jnp.full_like(
          params['head']['bias'], config.get('init_head_bias', 0))

    return params, states

  rng, rng_init = jax.random.split(rng)
  params_cpu, states_cpu = init(rng_init)

  if jax.process_index() == 0:
    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
    parameter_overview.log_parameter_overview(params_cpu)
    writer.write_scalars(step=0, scalars={'num_params': num_params})

  @functools.partial(jax.pmap, axis_name='batch')
  def evaluation_fn(params, states, images, labels):
    variable_dict = {'params': flax.core.freeze(params), **states}
    logits, out = model.apply(
        variable_dict,
        images,
        train=False,
        mean_field_factor=gp_config.get('mean_field_factor', -1.))
    losses = getattr(train_utils, config.get('loss', 'softmax_xent'))(
        logits=logits, labels=labels, reduction=False)
    loss = jax.lax.psum(losses, axis_name='batch')
    top1_idx = jnp.argmax(logits, axis=1)

    # Extracts the label at the highest logit index for each image.
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]
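    # With one-hot labels this is 1.0 when the argmax prediction matches the
    # true class and 0.0 otherwise, so the psum below counts correct
    # predictions across the global batch.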

    ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
    n = batch_size_eval
    metric_args = jax.lax.all_gather([
        logits, labels, out['pre_logits']], axis_name='batch')
    return ncorrect, loss, n, metric_args

  # Load the optimizer from flax.
  opt_name = config.get('optim_name')
  write_note(f'Initializing {opt_name} optimizer...')
  opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

  # We jit this, such that the arrays that are created are placed on the same
  # device as the input, in this case the CPU. Else they'd be on device[0].
  opt_cpu = jax.jit(opt_def.create)(params_cpu)

  weight_decay_rules = config.get('weight_decay', []) or []
  rescale_value = config.lr.base if config.get('weight_decay_decouple') else 1.
  weight_decay_fn = train_utils.get_weight_decay_fn(
      weight_decay_rules=weight_decay_rules, rescale_value=rescale_value)

  @functools.partial(jax.pmap, axis_name='batch', donate_argnums=(0,))
  def update_fn(opt, states, lr, reset_covmat, images, labels, rng):
    """Update step."""
    measurements = {}

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index('batch'))

    def loss_fn(params, states, images, labels):
      # Specify mutable collection to update untrainable GP parameters.
      variable_dict = {'params': flax.core.freeze(params), **states}
      model_results, updated_states = model.apply(
          variable_dict,
          images,
          train=True,
          rngs={'dropout': rng_model_local},
          mutable=list(states.keys()),
          mean_field_factor=gp_config.get('mean_field_factor', -1.))

      logits, _ = model_results
      loss = getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
          logits=logits, labels=labels)
      return loss, updated_states

    # Performs an exact covariance update (i.e., resets the precision matrix
    # at the beginning of each new epoch) if covmat_momentum is a null value.
    if use_gp_layer and gp_config.get('covmat_momentum', -1.) < 0:
      # Resets the precision matrix to Identity * ridge_penalty at the
      # beginning of a new epoch. This should be done before accumulating
      # gradients.
      ridge_penalty = gp_config.get('ridge_penalty', 1.)
      prec_mat_old = states['laplace_covariance']['head']['covmat_layer'][
          'precision_matrix']
      prec_mat_new = (
          (1. - reset_covmat) * prec_mat_old +
          reset_covmat * jnp.eye(prec_mat_old.shape[0]) * ridge_penalty)
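      # E.g. (illustrative): reset_covmat = 1.0 replaces the precision matrix
      # with ridge_penalty * I, while reset_covmat = 0.0 leaves it unchanged.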

      states = flax.core.unfreeze(states)
      states['laplace_covariance']['head']['covmat_layer'][
          'precision_matrix'] = prec_mat_new
      states = flax.core.freeze(states)

    # Implementation considerations compared and summarized at
    # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
    (l, s), g = vit_utils.accumulate_gradient_with_states(
        jax.value_and_grad(loss_fn, has_aux=True), opt.target, states, images,
        labels, config.get('grad_accum_steps'))
    l, g = jax.lax.pmean((l, g), axis_name='batch')

    # Log the gradient norm only if we need to compute it anyways (clipping)
    # or if we don't use grad_accum_steps, as they interact badly.
    if config.get('grad_accum_steps', 1) == 1 or grad_clip_norm is not None:
      grads, _ = jax.tree_flatten(g)
      l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
      measurements['l2_grads'] = l2_g

    # Optionally resize the global gradient to a maximum norm. We found this
    # useful in some cases across optimizers, hence it's in the main loop.
    if grad_clip_norm is not None:
      g_factor = jnp.minimum(1.0, grad_clip_norm / l2_g)
      g = jax.tree_map(lambda p: g_factor * p, g)
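      # E.g. (illustrative numbers): grad_clip_norm = 1.0 and l2_g = 4.0 give
      # g_factor = 0.25, rescaling the global gradient back to unit norm.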
    opt = opt.apply_gradient(g, learning_rate=lr)
    opt = opt.replace(target=weight_decay_fn(opt.target, lr))

    params, _ = jax.tree_flatten(opt.target)
    measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))
    measurements['reset_covmat'] = reset_covmat

    return opt, s, l, rng, measurements

  # Set config checkpoint resume path, if provided in args.
  if config.resume_checkpoint_path is not None:
    config.resume = config.resume_checkpoint_path

  default_reinit_params = ('head/output_layer/kernel', 'head/output_layer/bias',
                           'head/kernel', 'head/bias')
  rng, train_loop_rngs = jax.random.split(rng)
  checkpoint_data = checkpoint_utils.maybe_load_checkpoint(
      train_loop_rngs=train_loop_rngs,
      save_checkpoint_path=save_checkpoint_path,
      init_optimizer=opt_cpu,
      init_params=params_cpu,
      init_fixed_model_states=states_cpu,
      default_reinit_params=default_reinit_params,
      config=config)
  train_loop_rngs = checkpoint_data.train_loop_rngs
  opt_cpu = checkpoint_data.optimizer
  states_cpu = checkpoint_data.fixed_model_states
  accumulated_train_time = checkpoint_data.accumulated_train_time

  write_note('Adapting the checkpoint model...')
  adapted_params = checkpoint_utils.adapt_upstream_architecture(
      init_params=params_cpu,
      loaded_params=opt_cpu.target)
  opt_cpu = opt_cpu.replace(target=adapted_params)

  write_note('Kicking off misc stuff...')
  first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
  if first_step == 0 and jax.process_index() == 0:
    writer.write_hparams(dict(config))
  chrono = train_utils.Chrono(first_step, total_steps, batch_size,
                              accumulated_train_time)
  # Note: switch to ProfileAllHosts() if you need to profile all hosts.
  # (Xprof data become much larger and take longer to load for analysis)
  profiler = periodic_actions.Profile(
      # Create profile after every restart to analyze pre-emption related
      # problems and assure we get similar performance in every run.
      logdir=output_dir, first_profile=first_step + 10)

  # Prepare the learning-rate and pre-fetch it to device to avoid delays.
  lr_fn = train_utils.create_learning_rate_schedule(total_steps,
                                                    **config.get('lr', {}))

  # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be
  # necessary for TPUs.
  lr_iter = train_utils.prefetch_scalar(
      map(lr_fn, range(total_steps)), config.get('prefetch_to_device', 1))

  # Prepare the precision matrix resetting schedule, and pre-fetch it to device.
  reset_covmat_fn = lambda step: float(step % steps_per_epoch == 0)
  reset_covmat_iter = train_utils.prefetch_scalar(
      map(reset_covmat_fn, range(first_step, total_steps)),
      nprefetch=config.get('prefetch_to_device', 1))
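  # Illustrative: with steps_per_epoch = 100, reset_covmat is 1.0 at steps
  # 0, 100, 200, ... and 0.0 everywhere else, so the GP precision matrix is
  # reset exactly once per epoch.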

  write_note(f'Replicating...\n{chrono.note}')
  opt_repl = flax.jax_utils.replicate(opt_cpu)
  states_repl = flax.jax_utils.replicate(states_cpu)

  checkpoint_writer = None

  # Note: we return the train loss, val loss, and fewshot best l2s for use in
  # reproducibility unit tests.
  # train_loss = -jnp.inf
  # val_loss = -jnp.inf
  # results = {'dummy': {(0, 1): -jnp.inf}}

  write_note(f'First step compilations...\n{chrono.note}')
  logging.info('first_step = %s', first_step)
  # Advance the iterators if we are restarting from an earlier checkpoint.
  # TODO(dusenberrymw): Look into checkpointing dataset state instead.

  # Makes sure log_eval_steps is same as steps_per_epoch. This is because
  # the precision matrix needs to be updated fully (at the end of each epoch)
  # when eval takes place.
  log_eval_steps = steps_per_epoch
  if first_step > 0:
    write_note('Advancing iterators after resuming from a checkpoint...')
    lr_iter = itertools.islice(lr_iter, first_step, None)
    train_iter = itertools.islice(train_iter, first_step, None)

  # Using a python integer for step here, because opt.state.step is allocated
  # on TPU during replication.
  for step, train_batch, lr_repl, reset_covmat_repl in zip(
      range(first_step + 1, total_steps + 1), train_iter, lr_iter,
      reset_covmat_iter):

    with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1):
      # TODO(jereliu): Expand to allow precision matrix resetting.
      (opt_repl, states_repl, loss_value, train_loop_rngs,
       extra_measurements) = update_fn(
           opt_repl,
           states_repl,
           lr_repl,
           reset_covmat_repl,
           train_batch['image'],
           train_batch['labels'],
           rng=train_loop_rngs)

    if jax.process_index() == 0:
      profiler(step)

    # Checkpoint saving
    if train_utils.itstime(
        step, config.get('checkpoint_steps'), total_steps, process=0):
      write_note('Checkpointing...')
      chrono.pause()
      train_utils.checkpointing_timeout(checkpoint_writer,
                                        config.get('checkpoint_timeout', 1))
      accumulated_train_time = chrono.accum_train_time
      # We need to transfer the weights over now or else we risk keeping them
      # alive while they'll be updated in a future step, creating hard to debug
      # memory errors (see b/160593526). Also, takes device 0's params only.
      # For GP layer, we will also do the same for untrainable parameters
      # (`states`). This is ok since `random features` are frozen throughout
      # pre-training, and `precision matrix` is a finetuning-specific parameter
      # that will be re-learned in the finetuning task.
      opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
      states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)

      # Check whether we want to keep a copy of the current checkpoint.
      copy_step = None
      if train_utils.itstime(step, config.get('keep_checkpoint_steps'),
                             total_steps):
        write_note('Keeping a checkpoint copy...')
        copy_step = step

      # Checkpoint should be a nested dictionary or FLAX dataclasses from
      # `flax.struct`. Both can be present in a checkpoint.
      checkpoint_data = checkpoint_utils.CheckpointData(
          optimizer=opt_cpu,
          fixed_model_states=states_cpu,
          train_loop_rngs=train_loop_rngs,
          accumulated_train_time=accumulated_train_time)
      checkpoint_writer = pool.apply_async(
          checkpoint_utils.checkpoint_trained_model,
          (checkpoint_data, save_checkpoint_path, copy_step))
      chrono.resume()

    # Report training progress
    if train_utils.itstime(
        step, config.log_training_steps, total_steps, process=0):
      write_note('Reporting training progress...')
      train_loss = loss_value[0]  # Keep to return for reproducibility tests.
      timing_measurements, note = chrono.tick(step)
      write_note(note)
      train_measurements = {}
      train_measurements.update({
          'learning_rate': lr_repl[0],
          'training_loss': train_loss,
      })
      train_measurements.update(flax.jax_utils.unreplicate(extra_measurements))
      train_measurements.update(timing_measurements)
      writer.write_scalars(step, train_measurements)

    # Report validation performance
    if train_utils.itstime(step, log_eval_steps, total_steps):
      write_note('Evaluating on the validation set...')
      chrono.pause()

      all_eval_results = {}

      for eval_name, (eval_iter, eval_steps) in eval_iter_splits.items():
        start_time = time.time()

        # Runs evaluation loop.
        results_arrs = {
            'y_true': [],
            'y_pred': [],
            'y_pred_entropy': []
        }

        for _, batch in zip(range(eval_steps), eval_iter):
          batch_ncorrect, batch_losses, batch_n, batch_metric_args = (  # pylint: disable=unused-variable
              evaluation_fn(
                  opt_repl.target, states_repl, batch['image'],
                  batch['labels']))

          # All results are a replicated array shaped as follows:
          # (local_devices, per_device_batch_size, elem_shape...)
          # with each local device's entry being identical as they got psum'd.
          # So let's just take the first one to the host as numpy.

          # Here we parse batch_metric_args to compute uncertainty metrics.
          logits, labels, _ = batch_metric_args
          logits = np.array(logits[0])
          probs = jax.nn.softmax(logits)

          # From one-hot to integer labels.
          int_labels = np.argmax(np.array(labels[0]), axis=-1)

          probs = np.reshape(probs, (probs.shape[0] * probs.shape[1], -1))
          int_labels = int_labels.flatten()
          y_pred = probs[:, 1]
          results_arrs['y_true'].append(int_labels)
          results_arrs['y_pred'].append(y_pred)

          # Entropy is computed at the per-epoch level (see below).
          results_arrs['y_pred_entropy'].append(probs)

        results_arrs['y_true'] = np.concatenate(results_arrs['y_true'],
                                                axis=0)
        results_arrs['y_pred'] = np.concatenate(
            results_arrs['y_pred'], axis=0).astype('float64')
        results_arrs['y_pred_entropy'] = vit_utils.entropy(
            np.concatenate(results_arrs['y_pred_entropy'], axis=0), axis=-1)
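        # Assuming vit_utils.entropy computes the standard predictive entropy
        # H(p) = -sum_k p_k * log(p_k) along the class axis, this yields one
        # uncertainty score per example.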

        time_elapsed = time.time() - start_time
        results_arrs['total_ms_elapsed'] = time_elapsed * 1e3
        results_arrs['dataset_size'] = eval_steps * batch_size_eval

        all_eval_results[eval_name] = results_arrs

      per_pred_results, metrics_results = vit_utils.evaluate_vit_predictions(  # pylint: disable=unused-variable
          dataset_split_to_containers=all_eval_results,
          is_deterministic=True,
          num_bins=15,
          return_per_pred_results=True
      )

      # `metrics_results` is a dict of {str: jnp.ndarray} dicts, one for each
      # dataset. Flatten this dict so we can pass to the writer and remove empty
      # entries.
      flattened_metric_results = {}
      for dic in metrics_results.values():
        for key, value in dic.items():
          if value is not None:
            flattened_metric_results[key] = value
      writer.write_scalars(step, flattened_metric_results)

      # Optionally log to wandb
      if config.use_wandb:
        wandb.log(metrics_results, step=step)

      # Save per-prediction metrics
      results_storage_utils.save_per_prediction_results(
          output_dir, step, per_pred_results, verbose=False)

      chrono.resume()

    # End of step.
    if config.get('testing_failure_step'):
      # Break early to simulate infra failures in test cases.
      if config.testing_failure_step == step:
        break

  write_note(f'Done!\n{chrono.note}')
  pool.close()
  pool.join()
  writer.close()

  if wandb_run is not None:
    wandb_run.finish()
Exemple #12
0
def train_and_evaluate(config, workdir):
    """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.

  Returns:
    The train state (which includes the `.params`).
  """
    # Seed for reproducibility.
    rng = jax.random.PRNGKey(config.rng_seed)

    # Set up logging.
    summary_writer = metric_writers.create_default_writer(workdir)
    summary_writer.write_hparams(dict(config))

    # Get datasets.
    rng, dataset_rng = jax.random.split(rng)
    dataset = input_pipeline.get_dataset(config, dataset_rng)
    graph, labels, masks = jax.tree_map(jnp.asarray, dataset)
    labels = jax.nn.one_hot(labels, config.num_classes)
    train_mask = masks['train']
    train_indices = jnp.where(train_mask)[0]
    train_labels = labels[train_indices]
    num_training_nodes = len(train_indices)

    # Get subgraphs.
    if config.differentially_private_training:
        graph = jax.tree_map(np.asarray, graph)
        subgraphs = get_subgraphs(graph, pad_to=config.pad_subgraphs_to)
        graph = jax.tree_map(jnp.asarray, graph)

        # We only need the subgraphs for training nodes.
        train_subgraphs = subgraphs[train_indices]
        del subgraphs
    else:
        train_subgraphs = None

    # Initialize privacy accountant.
    training_privacy_accountant = privacy_accountants.get_training_privacy_accountant(
        config, num_training_nodes, compute_max_terms_per_node(config))

    # Construct and initialize model.
    rng, init_rng = jax.random.split(rng)
    estimation_indices = get_estimation_indices(train_indices, config)
    state = create_train_state(init_rng, config, graph, train_labels,
                               train_subgraphs, estimation_indices)

    # Set up checkpointing of the model.
    checkpoint_dir = os.path.join(workdir, 'checkpoints')
    ckpt = checkpoint.Checkpoint(checkpoint_dir, max_to_keep=2)
    state = ckpt.restore_or_initialize(state)
    initial_step = int(state.step) + 1

    # Log overview of parameters.
    parameter_overview.log_parameter_overview(state.params)

    # Log metrics after initialization.
    logits = compute_logits(state, graph)
    metrics_after_init = compute_metrics(logits, labels, masks)
    metrics_after_init['epsilon'] = 0
    log_metrics(0, metrics_after_init, summary_writer, postfix='_after_init')

    # Train model.
    rng, train_rng = jax.random.split(rng)
    max_training_epsilon = get_max_training_epsilon(config)

    # Hooks called periodically during training.
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_training_steps, writer=summary_writer)
    profiler = periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
    hooks = [report_progress, profiler]

    for step in range(initial_step, config.num_training_steps):

        # Perform one step of training.
        with jax.profiler.StepTraceAnnotation('train', step_num=step):
            # Sample batch.
            step_rng = jax.random.fold_in(train_rng, step)
            indices = jax.random.choice(step_rng, num_training_nodes,
                                        (config.batch_size, ))

            # Compute gradients.
            if config.differentially_private_training:
                grads = compute_updates_for_dp(state, graph, train_labels,
                                               train_subgraphs, indices,
                                               config.adjacency_normalization)
            else:
                grads = compute_updates(state, graph, train_labels, indices)

            # Update parameters.
            state = update_model(state, grads)

        # Quick indication that training is happening.
        logging.log_first_n(logging.INFO, 'Finished training step %d.', 10,
                            step)
        for hook in hooks:
            hook(step)

        # Evaluate, if required.
        is_last_step = (step == config.num_training_steps - 1)
        if step % config.evaluate_every_steps == 0 or is_last_step:
            with report_progress.timed('eval'):
                # Check if privacy budget exhausted.
                training_epsilon = training_privacy_accountant(step + 1)
                if max_training_epsilon is not None and training_epsilon >= max_training_epsilon:
                    break

                # Compute metrics.
                logits = compute_logits(state, graph)
                metrics_during_training = compute_metrics(
                    logits, labels, masks)
                metrics_during_training['epsilon'] = training_epsilon
                log_metrics(step, metrics_during_training, summary_writer)

        # Checkpoint, if required.
        if step % config.checkpoint_every_steps == 0 or is_last_step:
            with report_progress.timed('checkpoint'):
                ckpt.save(state)

    return state
Exemple #13
0
def main(config, output_dir):
    # Note: switch to ProfileAllHosts() if you need to profile all hosts.
    # (Xprof data become much larger and take longer to load for analysis)
    profiler = periodic_actions.Profile(
        # Create profile after every restart to analyze pre-emption related
        # problems and assure we get similar performance in every run.
        logdir=output_dir,
        first_profile=10)

    logging.info(config)

    acquisition_method = config.get('acquisition_method')

    # Create an asynchronous multi-metric writer.
    writer = metric_writers.create_default_writer(
        output_dir, just_logging=jax.process_index() > 0)
    writer.write_hparams(dict(config))

    # The pool is used to perform misc operations such as logging in async way.
    pool = multiprocessing.pool.ThreadPool()

    def write_note(note):
        if jax.process_index() == 0:
            logging.info('NOTE: %s', note)

    write_note(f'Initializing for {acquisition_method}')

    # Download dataset
    data_builder = tfds.builder(config.dataset)
    data_builder.download_and_prepare()

    seed = config.get('seed', 0)
    rng = jax.random.PRNGKey(seed)

    batch_size = config.batch_size
    batch_size_eval = config.get('batch_size_eval', batch_size)

    local_batch_size = batch_size // jax.process_count()
    local_batch_size_eval = batch_size_eval // jax.process_count()

    val_ds = input_utils.get_data(
        dataset=config.dataset,
        split=config.val_split,
        rng=None,
        process_batch_size=local_batch_size_eval,
        preprocess_fn=preprocess_spec.parse(
            spec=config.pp_eval, available_ops=preprocess_utils.all_ops()),
        shuffle=False,
        prefetch_size=config.get('prefetch_to_host', 2),
        num_epochs=1,  # Only repeat once.
    )

    test_ds = input_utils.get_data(
        dataset=config.dataset,
        split=config.test_split,
        rng=None,
        process_batch_size=local_batch_size_eval,
        preprocess_fn=preprocess_spec.parse(
            spec=config.pp_eval, available_ops=preprocess_utils.all_ops()),
        shuffle=False,
        prefetch_size=config.get('prefetch_to_host', 2),
        num_epochs=1,  # Only repeat once.
    )

    # Init model
    if config.model_type == 'deterministic':
        model_utils = deterministic_utils
        reinit_params = config.get('model_reinit_params',
                                   ('head/kernel', 'head/bias'))
        model = ub.models.vision_transformer(num_classes=config.num_classes,
                                             **config.get('model', {}))
    elif config.model_type == 'batchensemble':
        model_utils = batchensemble_utils
        reinit_params = ('batchensemble_head/bias',
                         'batchensemble_head/kernel',
                         'batchensemble_head/fast_weight_alpha',
                         'batchensemble_head/fast_weight_gamma')
        model = ub.models.PatchTransformerBE(num_classes=config.num_classes,
                                             **config.model)
    else:
        raise ValueError('Expect config.model_type to be "deterministic" or'
                         f'"batchensemble", but received {config.model_type}.')

    init = model_utils.create_init(model, config, test_ds)

    rng, rng_init = jax.random.split(rng)
    params_cpu = init(rng_init)

    if jax.process_index() == 0:
        num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
        parameter_overview.log_parameter_overview(params_cpu)
        writer.write_scalars(step=0, scalars={'num_params': num_params})

    # Load the optimizer from flax.
    opt_name = config.get('optim_name')
    opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

    # We jit this, such that the arrays that are created are placed on the same
    # device as the input, in this case the CPU. Else they'd be on device[0].
    opt_cpu = jax.jit(opt_def.create)(params_cpu)

    loaded_params = checkpoint_utils.load_checkpoint(tree=None,
                                                     path=config.model_init)
    loaded = checkpoint_utils.restore_from_pretrained_params(
        params_cpu,
        loaded_params,
        config.model.representation_size,
        config.model.classifier,
        reinit_params,
    )

    opt_cpu = opt_cpu.replace(target=loaded)

    # TODO(joost,andreas): This shouldn't be needed but opt_cpu is being
    # donated otherwise. Ensure opt_cpu is really on the cpu this way.
    opt_cpu = jax.device_get(opt_cpu)

    update_fn = model_utils.create_update_fn(model, config)
    evaluation_fn = model_utils.create_evaluation_fn(model, config)

    # NOTE: We need this because we need an Id field of type int.
    # TODO(andreas): Rename to IdSubsetDatasetBuilder?
    pool_subset_data_builder = al_utils.SubsetDatasetBuilder(data_builder,
                                                             subset_ids=None)

    rng, pool_ds_rng = jax.random.split(rng)

    # NOTE: below line is necessary on multi host setup
    # pool_ds_rng = jax.random.fold_in(pool_ds_rng, jax.process_index())

    pool_train_ds = input_utils.get_data(
        dataset=pool_subset_data_builder,
        split=config.train_split,
        rng=pool_ds_rng,
        process_batch_size=local_batch_size,
        preprocess_fn=preprocess_spec.parse(
            spec=config.pp_eval, available_ops=preprocess_utils.all_ops()),
        shuffle=False,
        drop_remainder=False,
        prefetch_size=config.get('prefetch_to_host', 2),
        num_epochs=1,  # Don't repeat
    )

    # Potentially acquire an initial training set.
    initial_training_set_size = config.get('initial_training_set_size', 10)

    if initial_training_set_size > 0:
        current_opt_repl = flax_utils.replicate(opt_cpu)
        pool_ids, _, _, pool_masks = get_ids_logits_masks(
            model=model,
            opt_repl=current_opt_repl,
            ds=pool_train_ds,
            config=config)

        rng, initial_uniform_rng = jax.random.split(rng)
        pool_scores = get_uniform_scores(pool_masks, initial_uniform_rng)

        initial_training_set_batch_ids, _ = select_acquisition_batch_indices(
            acquisition_batch_size=initial_training_set_size,
            scores=pool_scores,
            ids=pool_ids,
            ignored_ids=set(),
        )
    else:
        initial_training_set_batch_ids = []

    # NOTE: if we could `enumerate` before `filter` in `create_dataset` of CLU
    # then this dataset creation could be simplified.
    # https://github.com/google/CommonLoopUtils/blob/main/clu/deterministic_data.py#L340
    # CLU is explicitly not accepting outside contributions at the moment.
    train_subset_data_builder = al_utils.SubsetDatasetBuilder(
        data_builder, subset_ids=set(initial_training_set_batch_ids))

    test_accuracies = []
    training_sizes = []

    rng, rng_loop = jax.random.split(rng)
    rngs_loop = flax_utils.replicate(rng_loop)
    if config.model_type == 'batchensemble':
        rngs_loop = {'dropout': rngs_loop}

    # TODO(joost,andreas): double check if below is still necessary
    # (train_split is independent of this)
    # NOTE: train_ds_rng is re-used for all train_ds creations
    rng, train_ds_rng = jax.random.split(rng)

    measurements = {}
    accumulated_steps = 0
    while True:
        current_train_ds_length = len(train_subset_data_builder.subset_ids)
        if current_train_ds_length >= config.get('max_training_set_size', 150):
            break
        write_note(f'Training set size: {current_train_ds_length}')

        current_opt_repl = flax_utils.replicate(opt_cpu)

        # Only fine-tune if there is anything to fine-tune with.
        if current_train_ds_length > 0:
            # Repeat dataset to have oversampled epochs and bootstrap more batches
            number_of_batches = current_train_ds_length / config.batch_size
            num_repeats = math.ceil(config.total_steps / number_of_batches)
            write_note(f'Repeating dataset {num_repeats} times')
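            # Illustrative (hypothetical sizes): 20 acquired examples with
            # config.batch_size = 128 give number_of_batches = 0.15625, so
            # config.total_steps = 100 implies num_repeats = 640 passes over
            # the small training set.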

            # We repeat the dataset several times, such that we can obtain batches
            # of size batch_size, even at start of training. These batches will be
            # effectively 'bootstrap' sampled, meaning they are sampled with
            # replacement from the original training set.
            repeated_train_ds = input_utils.get_data(
                dataset=train_subset_data_builder,
                split=config.train_split,
                rng=train_ds_rng,
                process_batch_size=local_batch_size,
                preprocess_fn=preprocess_spec.parse(
                    spec=config.pp_train,
                    available_ops=preprocess_utils.all_ops()),
                shuffle_buffer_size=config.shuffle_buffer_size,
                prefetch_size=config.get('prefetch_to_host', 2),
                # TODO(joost,andreas): double check if below leads to bootstrap
                # sampling.
                num_epochs=num_repeats,
            )

            # We use this dataset to evaluate how well we perform on the training set.
            # We need this to evaluate if we fit well within max_steps budget.
            train_eval_ds = input_utils.get_data(
                dataset=train_subset_data_builder,
                split=config.train_split,
                rng=train_ds_rng,
                process_batch_size=local_batch_size,
                preprocess_fn=preprocess_spec.parse(
                    spec=config.pp_eval,
                    available_ops=preprocess_utils.all_ops()),
                shuffle=False,
                drop_remainder=False,
                prefetch_size=config.get('prefetch_to_host', 2),
                num_epochs=1,
            )

            # NOTE: warmup and decay are not a good fit for the small training set
            # lr_fn = train_utils.create_learning_rate_schedule(config.total_steps,
            #                                                   **config.get('lr', {})
            #                                                   )
            lr_fn = lambda x: config.lr.base

            early_stopping_patience = config.get('early_stopping_patience', 15)
            current_opt_repl, rngs_loop, measurements = finetune(
                update_fn=update_fn,
                opt_repl=current_opt_repl,
                lr_fn=lr_fn,
                ds=repeated_train_ds,
                rngs_loop=rngs_loop,
                total_steps=config.total_steps,
                train_eval_ds=train_eval_ds,
                val_ds=val_ds,
                evaluation_fn=evaluation_fn,
                early_stopping_patience=early_stopping_patience,
                profiler=profiler)
            train_val_accuracies = measurements.pop('train_val_accuracies')
            current_steps = 0
            for step, train_acc, val_acc in train_val_accuracies:
                writer.write_scalars(accumulated_steps + step, {
                    'train_accuracy': train_acc,
                    'val_accuracy': val_acc
                })
                current_steps = step
            accumulated_steps += current_steps + 10

        test_accuracy = get_accuracy(evaluation_fn=evaluation_fn,
                                     opt_repl=current_opt_repl,
                                     ds=test_ds)

        write_note(f'Accuracy at {current_train_ds_length}: {test_accuracy}')

        test_accuracies.append(test_accuracy)
        training_sizes.append(current_train_ds_length)

        pool_ids, pool_outputs, _, pool_masks = get_ids_logits_masks(
            model=model,
            opt_repl=current_opt_repl,
            ds=pool_train_ds,
            use_pre_logits=acquisition_method == 'density',
            config=config)

        if acquisition_method == 'uniform':
            rng_loop, rng_acq = jax.random.split(rng_loop, 2)
            pool_scores = get_uniform_scores(pool_masks, rng_acq)
        elif acquisition_method == 'entropy':
            pool_scores = get_entropy_scores(pool_outputs, pool_masks)
        elif acquisition_method == 'margin':
            pool_scores = get_margin_scores(pool_outputs, pool_masks)
        elif acquisition_method == 'density':
            if current_train_ds_length > 0:
                pool_scores = get_density_scores(model=model,
                                                 opt_repl=current_opt_repl,
                                                 train_ds=train_eval_ds,
                                                 pool_pre_logits=pool_outputs,
                                                 pool_masks=pool_masks,
                                                 config=config)
            else:
                rng_loop, rng_acq = jax.random.split(rng_loop, 2)
                pool_scores = get_uniform_scores(pool_masks, rng_acq)
        else:
            raise ValueError('Acquisition method not found.')

        acquisition_batch_ids, _ = select_acquisition_batch_indices(
            acquisition_batch_size=config.get('acquisition_batch_size', 10),
            scores=pool_scores,
            ids=pool_ids,
            ignored_ids=train_subset_data_builder.subset_ids)

        train_subset_data_builder.subset_ids.update(acquisition_batch_ids)

        measurements.update({'test_accuracy': test_accuracy})
        writer.write_scalars(current_train_ds_length, measurements)

    write_note(f'Final acquired training ids: '
               f'{train_subset_data_builder.subset_ids} '
               f'Accuracies: {test_accuracies}')

    pool.close()
    pool.join()
    writer.close()
    # TODO(joost,andreas): save the final checkpoint
    return (train_subset_data_builder.subset_ids, test_accuracies)
def main(_):
    config = FLAGS.config

    # Unpack total and warmup steps
    total_steps = config.total_and_warmup_steps[0]
    warmup_steps = config.total_and_warmup_steps[1]
    del config.total_and_warmup_steps
    config.total_steps = total_steps
    config.lr.warmup_steps = warmup_steps

    # Wandb and Checkpointing Setup
    output_dir = FLAGS.output_dir
    wandb_run, output_dir = vit_utils.maybe_setup_wandb(config)
    tf.io.gfile.makedirs(output_dir)
    logging.info('Saving checkpoints at %s', output_dir)

    # Dataset Split Flags
    dist_shift = config.distribution_shift
    print(f'Distribution Shift: {dist_shift}.')
    dataset_names, split_names = vit_utils.get_dataset_and_split_names(
        dist_shift)

    # LR / Optimization Flags
    print('wandb hyperparameters:')
    print({
        'batch_size': config.batch_size,
        'grad_clip_norm': config.grad_clip_norm,
        'weight_decay': config.weight_decay,
        'total_steps': config.total_steps,
        'lr': config.lr,
        'fast_weight_lr_multiplier': config.fast_weight_lr_multiplier
    })

    # Reweighting loss for class imbalance
    # class_reweight_mode = config.class_reweight_mode
    # if class_reweight_mode == 'constant':
    #   class_weights = utils.get_diabetic_retinopathy_class_balance_weights()
    # else:
    #   class_weights = None

    # Shows the number of available devices.
    # In a CPU/GPU runtime this will be a single device.
    # In a TPU runtime this will be 8 cores.
    print('Number of Jax local devices:', jax.local_devices())

    # TODO(nband): fix sigmoid loss issues.
    assert config.get('loss', None) == 'softmax_xent'

    seed = config.get('seed', 0)
    rng = jax.random.PRNGKey(seed)
    tf.random.set_seed(seed)

    if config.get('data_dir'):
        logging.info('data_dir=%s', config.data_dir)
    logging.info('Output dir: %s', output_dir)
    tf.io.gfile.makedirs(output_dir)

    save_checkpoint_path = None
    if config.get('checkpoint_steps'):
        save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

    # Create an asynchronous multi-metric writer.
    writer = metric_writers.create_default_writer(
        output_dir, just_logging=jax.process_index() > 0)

    # The pool is used to perform misc operations such as logging in async way.
    pool = multiprocessing.pool.ThreadPool()

    def write_note(note):
        if jax.process_index() == 0:
            logging.info('NOTE: %s', note)

    write_note('Initializing...')

    # Verify settings to make sure no checkpoints are accidentally missed.
    if config.get('keep_checkpoint_steps'):
        assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
        assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
            f'`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should '
            f'be divisible by `checkpoint_steps` ({config.checkpoint_steps}).')

    batch_size = config.batch_size
    batch_size_eval = config.get('batch_size_eval', batch_size)
    if (batch_size % jax.device_count() != 0
            or batch_size_eval % jax.device_count() != 0):
        raise ValueError(
            f'Batch sizes ({batch_size} and {batch_size_eval}) must '
            f'be divisible by device number ({jax.device_count()})')

    local_batch_size = batch_size // jax.process_count()
    local_batch_size_eval = batch_size_eval // jax.process_count()
    logging.info(
        'Global batch size %d on %d hosts results in %d local batch size. '
        'With %d devices per host (%d devices total), that\'s a %d per-device '
        'batch size.', batch_size, jax.process_count(), local_batch_size,
        jax.local_device_count(), jax.device_count(),
        local_batch_size // jax.local_device_count())
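    # Hypothetical example: a global batch size of 512 on 2 hosts with 8 local
    # devices each gives a local batch size of 512 // 2 = 256 and a per-device
    # batch size of 256 // 8 = 32.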

    write_note('Initializing preprocessing function...')
    # Same preprocessing function for training and evaluation
    preproc_fn = preprocess_spec.parse(
        spec=config.pp_train, available_ops=preprocess_utils.all_ops())

    write_note('Initializing train dataset...')
    rng, train_ds_rng = jax.random.split(rng)
    train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index())
    train_base_dataset = ub.datasets.get(dataset_names['in_domain_dataset'],
                                         split=split_names['train_split'],
                                         data_dir=config.get('data_dir'))
    train_dataset_builder = train_base_dataset._dataset_builder  # pylint:disable=protected-access
    train_ds = input_utils.get_data(
        dataset=train_dataset_builder,
        split=split_names['train_split'],
        rng=train_ds_rng,
        process_batch_size=local_batch_size,
        preprocess_fn=preproc_fn,
        shuffle_buffer_size=config.shuffle_buffer_size,
        prefetch_size=config.get('prefetch_to_host', 2),
        data_dir=config.get('data_dir'))

    # Start prefetching already.
    train_iter = input_utils.start_input_pipeline(
        train_ds, config.get('prefetch_to_device', 1))

    write_note('Initializing val dataset(s)...')

    # Load in-domain and OOD validation and/or test datasets.
    # Please specify the desired shift (Country Shift or Severity Shift)
    # in the config.
    eval_iter_splits = vit_utils.init_evaluation_datasets(
        use_validation=config.use_validation,
        use_test=config.use_test,
        dataset_names=dataset_names,
        split_names=split_names,
        config=config,
        preproc_fn=preproc_fn,
        batch_size_eval=batch_size_eval,
        local_batch_size_eval=local_batch_size_eval)

    ntrain_img = input_utils.get_num_examples(
        train_dataset_builder,
        split=split_names['train_split'],
        process_batch_size=local_batch_size,
        data_dir=config.get('data_dir'))
    steps_per_epoch = ntrain_img // batch_size

    if config.get('num_epochs'):
        total_steps = int(config.num_epochs * steps_per_epoch)
        assert not config.get(
            'total_steps'), 'Set either num_epochs or total_steps'
    else:
        total_steps = config.total_steps

    logging.info('Total train data points: %d', ntrain_img)
    logging.info(
        'Running for %d steps, that means %f epochs and %d steps per epoch',
        total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)
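    # Hypothetical example: with ntrain_img = 50_000 and batch_size = 512,
    # steps_per_epoch = 50_000 // 512 = 97, and total_steps = 10_000 amounts
    # to 10_000 * 512 / 50_000 = 102.4 epochs.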

    write_note('Initializing model...')
    model_dict = vit_utils.initialize_model('batchensemble', config)
    model = model_dict['model']
    ens_size = model_dict['ens_size']

    # We want all parameters to be created in host RAM, not on any device;
    # they'll be sent there later as needed. Otherwise we have already run into
    # two situations where they end up allocated twice.
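    # Pinning `init` to the CPU backend with `jax.jit(..., backend='cpu')` is
    # what keeps the freshly initialized parameters committed to host memory;
    # they only reach the accelerators when explicitly replicated further below
    # (via `flax.jax_utils.replicate`).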
    @functools.partial(jax.jit, backend='cpu')
    def init(rng):
        image_size = tuple(train_ds.element_spec['image'].shape[2:])
        logging.info('image_size = %s', image_size)
        dummy_input = jnp.zeros((local_batch_size, ) + image_size, jnp.float32)
        params = flax.core.unfreeze(model.init(rng, dummy_input,
                                               train=False))['params']

        # Set bias in the head to a low value, such that loss is small initially.
        params['batchensemble_head']['bias'] = jnp.full_like(
            params['batchensemble_head']['bias'],
            config.get('init_head_bias', 0))

        # init head kernel to all zeros for fine-tuning
        if config.get('model_init'):
            params['batchensemble_head']['kernel'] = jnp.full_like(
                params['batchensemble_head']['kernel'], 0)

        return params

    rng, rng_init = jax.random.split(rng)
    params_cpu = init(rng_init)

    if jax.process_index() == 0:
        num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
        parameter_overview.log_parameter_overview(params_cpu)
        writer.write_scalars(step=0, scalars={'num_params': num_params})

    @functools.partial(jax.pmap, axis_name='batch')
    def evaluation_fn(params, images, labels):
        tiled_logits, out = model.apply({'params': flax.core.freeze(params)},
                                        images,
                                        train=False)

        loss_name = config.get('loss', 'sigmoid_xent')
        # TODO(dusenberrymw,zmariet): Clean up and generalize this.
        if loss_name == 'sigmoid_xent':
            ens_logits = batchensemble_utils.log_average_sigmoid_probs(
                jnp.asarray(jnp.split(tiled_logits, ens_size)))
            pre_logits = batchensemble_utils.log_average_sigmoid_probs(
                jnp.asarray(jnp.split(out['pre_logits'], ens_size)))
        else:  # softmax
            ens_logits = batchensemble_utils.log_average_softmax_probs(
                jnp.asarray(jnp.split(tiled_logits, ens_size)))
            pre_logits = batchensemble_utils.log_average_softmax_probs(
                jnp.asarray(jnp.split(out['pre_logits'], ens_size)))

        losses = getattr(train_utils,
                         loss_name)(logits=ens_logits,
                                    labels=labels[:, :config.num_classes],
                                    reduction=False)
        loss = jax.lax.psum(losses, axis_name='batch')

        top1_idx = jnp.argmax(ens_logits, axis=1)
        top1_correct = jnp.take_along_axis(labels, top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
        n = batch_size_eval

        metric_args = jax.lax.all_gather([ens_logits, labels, pre_logits],
                                         axis_name='batch')
        return ncorrect, loss, n, metric_args
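    # A minimal sketch of the log-averaging used above, assuming a stacked
    # array `member_logits` of shape [ens_size, batch, num_classes]:
    #   log_probs = jax.nn.log_softmax(member_logits, axis=-1)
    #   ens = jax.scipy.special.logsumexp(log_probs, axis=0) - jnp.log(ens_size)
    # i.e. the log of the mean of the ensemble members' softmax probabilities.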

    # Load the optimizer from flax.
    opt_name = config.get('optim_name')
    write_note(f'Initializing {opt_name} optimizer...')
    opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

    # We jit this, such that the arrays that are created are created on the same
    # device as the input is, in this case the CPU. Else they'd be on device[0].
    opt_cpu = jax.jit(opt_def.create)(params_cpu)

    weight_decay_rules = config.get('weight_decay', []) or []
    rescale_value = config.lr.base if config.get(
        'weight_decay_decouple') else 1.
    weight_decay_fn = train_utils.get_weight_decay_fn(
        weight_decay_rules=weight_decay_rules, rescale_value=rescale_value)

    def batch_loss_fn(params, images, labels, rngs):
        logits, _ = model.apply({'params': flax.core.freeze(params)},
                                images,
                                train=True,
                                rngs=rngs)
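        # The BatchEnsemble model produces `ens_size` stacked predictions per
        # example, so tile the labels along the batch axis to match the leading
        # (ens_size * batch) dimension of `logits`.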
        labels = jnp.tile(labels, (ens_size, 1))
        loss_fn = getattr(train_utils, config.get('loss', 'sigmoid_xent'))
        loss = jnp.mean(loss_fn(logits=logits, labels=labels))
        return loss, dict()

    @functools.partial(jax.pmap, axis_name='batch', donate_argnums=(0, 1))
    def update_fn(opt, rngs, lr, images, labels):
        return batchensemble_utils.update_fn_be(
            opt=opt,
            rngs=rngs,
            lr=lr,
            images=images,
            labels=labels,
            batch_loss_fn=batch_loss_fn,
            weight_decay_fn=weight_decay_fn,
            max_grad_norm_global=config.get('grad_clip_norm', None),
            fast_weight_lr_multiplier=config.get('fast_weight_lr_multiplier',
                                                 None))

    # Set config checkpoint resume path, if provided in args.
    if config.resume_checkpoint_path is not None:
        config.resume = config.resume_checkpoint_path

    reint_params = ('batchensemble_head/bias', 'batchensemble_head/kernel',
                    'batchensemble_head/fast_weight_alpha',
                    'batchensemble_head/fast_weight_gamma')
    if config.get('only_eval', False) or not config.get('reint_head', True):
        reint_params = []
    checkpoint_data = checkpoint_utils.maybe_load_checkpoint(
        train_loop_rngs=rng,
        save_checkpoint_path=save_checkpoint_path,
        init_optimizer=opt_cpu,
        init_params=params_cpu,
        init_fixed_model_states=None,
        default_reinit_params=reint_params,
        config=config)
    train_loop_rngs = {'dropout': checkpoint_data.train_loop_rngs}
    opt_cpu = checkpoint_data.optimizer
    accumulated_train_time = checkpoint_data.accumulated_train_time

    write_note('Adapting the checkpoint model...')
    adapted_params = checkpoint_utils.adapt_upstream_architecture(
        init_params=params_cpu, loaded_params=opt_cpu.target)
    opt_cpu = opt_cpu.replace(target=adapted_params)

    write_note('Kicking off misc stuff...')
    first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
    if first_step == 0 and jax.process_index() == 0:
        writer.write_hparams(dict(config))
    chrono = train_utils.Chrono(first_step, total_steps, batch_size,
                                accumulated_train_time)
    # Note: switch to ProfileAllHosts() if you need to profile all hosts.
    # (Xprof data become much larger and take longer to load for analysis)
    profiler = periodic_actions.Profile(
        # Create a profile after every restart to analyze pre-emption-related
        # problems and ensure we get similar performance in every run.
        logdir=output_dir,
        first_profile=first_step + 10)

    # Prepare the learning-rate and pre-fetch it to device to avoid delays.
    lr_fn = train_utils.create_learning_rate_schedule(total_steps,
                                                      **config.get('lr', {}))

    # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be
    # necessary for TPUs.
    lr_iter = train_utils.prefetch_scalar(map(lr_fn, range(total_steps)),
                                          config.get('prefetch_to_device', 1))

    write_note(f'Replicating...\n{chrono.note}')
    opt_repl = flax.jax_utils.replicate(opt_cpu)

    checkpoint_writer = None

    # Note: we return the train loss, val loss, and fewshot best l2s for use in
    # reproducibility unit tests.
    # train_loss = -jnp.inf
    # eval_loss = {
    #     eval_name: -jnp.inf for eval_name, _ in eval_iter_splits.items()}
    # fewshot_results = {'dummy': {(0, 1): -jnp.inf}}

    write_note(f'First step compilations...\n{chrono.note}')
    logging.info('first_step = %s', first_step)
    # Advance the iterators if we are restarting from an earlier checkpoint.
    # TODO(dusenberrymw): Look into checkpointing dataset state instead.
    if first_step > 0:
        write_note('Advancing iterators after resuming from a checkpoint...')
        lr_iter = itertools.islice(lr_iter, first_step, None)
        # TODO(zmariet): Find better way to cut down iteration advancement cost.
        if not config.get('disable_preemption_reproducibility', False):
            train_iter = itertools.islice(train_iter, first_step, None)

    # Using a python integer for step here, because opt.state.step is allocated
    # on TPU during replication.
    for step, train_batch, lr_repl in zip(
            range(first_step + 1, total_steps + 1), train_iter, lr_iter):

        with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1):
            if not config.get('only_eval', False):
                opt_repl, train_loop_rngs, extra_measurements = update_fn(
                    opt_repl, train_loop_rngs, lr_repl, train_batch['image'],
                    train_batch['labels'])

        if jax.process_index() == 0:
            profiler(step)

        # Checkpoint saving
        if not config.get('only_eval', False) and train_utils.itstime(
                step, config.get('checkpoint_steps'), total_steps, process=0):
            write_note('Checkpointing...')
            chrono.pause()
            train_utils.checkpointing_timeout(
                checkpoint_writer, config.get('checkpoint_timeout', 1))
            accumulated_train_time = chrono.accum_train_time
            # We need to transfer the weights over now, or else we risk keeping
            # them alive while they are updated in a future step, creating
            # hard-to-debug memory errors (see b/160593526). Also, this takes
            # device 0's params only.
            opt_cpu = jax.tree_util.tree_map(lambda x: np.array(x[0]),
                                             opt_repl)

            # Check whether we want to keep a copy of the current checkpoint.
            copy_step = None
            if train_utils.itstime(step, config.get('keep_checkpoint_steps'),
                                   total_steps):
                write_note('Keeping a checkpoint copy...')
                copy_step = step

            # Checkpoint should be a nested dictionary or Flax dataclasses from
            # `flax.struct`. Both can be present in a checkpoint.
            checkpoint_data = checkpoint_utils.CheckpointData(
                train_loop_rngs=train_loop_rngs,
                optimizer=opt_cpu,
                accumulated_train_time=accumulated_train_time)

            checkpoint_writer = pool.apply_async(
                checkpoint_utils.checkpoint_trained_model,
                (checkpoint_data, save_checkpoint_path, copy_step))
            chrono.resume()

        # Report training progress
        if not config.get('only_eval', False) and train_utils.itstime(
                step, config.log_training_steps, total_steps, process=0):
            write_note('Reporting training progress...')
            timing_measurements, note = chrono.tick(step)
            write_note(note)
            train_measurements = {}
            train_measurements.update(
                flax.jax_utils.unreplicate(extra_measurements))
            train_measurements.update(timing_measurements)
            writer.write_scalars(step, train_measurements)
            # Keep to return for reproducibility tests.
            # train_loss = train_measurements['training_loss']

        # Report validation performance
        if config.get('only_eval', False) or train_utils.itstime(
                step, config.log_eval_steps, total_steps):
            write_note('Evaluating on the validation sets...')
            chrono.pause()

            all_eval_results = {}

            for eval_name, (eval_iter, eval_steps) in eval_iter_splits.items():
                start_time = time.time()

                # Runs evaluation loop.
                results_arrs = {
                    'y_true': [],
                    'y_pred': [],
                    'y_pred_entropy': []
                }

                for _, batch in zip(range(eval_steps), eval_iter):
                    batch_ncorrect, batch_losses, batch_n, batch_metric_args = (  # pylint: disable=unused-variable
                        evaluation_fn(opt_repl.target, batch['image'],
                                      batch['labels']))

                    # All results are a replicated array shaped as follows:
                    # (local_devices, per_device_batch_size, elem_shape...)
                    # with each local device's entry being identical, because
                    # they were psum'd. So let's just take the first one to the
                    # host as numpy.

                    # Here we parse batch_metric_args to compute uncertainty metrics.
                    logits, labels, _ = batch_metric_args
                    logits = np.array(logits[0])
                    probs = jax.nn.softmax(logits)

                    # From one-hot to integer labels.
                    int_labels = np.argmax(np.array(labels[0]), axis=-1)

                    probs = np.reshape(probs,
                                       (probs.shape[0] * probs.shape[1], -1))
                    int_labels = int_labels.flatten()
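                    # Use the probability assigned to class 1 (the positive
                    # class in this binary task) as the prediction score.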
                    y_pred = probs[:, 1]
                    results_arrs['y_true'].append(int_labels)
                    results_arrs['y_pred'].append(y_pred)

                    # Entropy is computed at the per-epoch level (see below).
                    results_arrs['y_pred_entropy'].append(probs)

                results_arrs['y_true'] = np.concatenate(results_arrs['y_true'],
                                                        axis=0)
                results_arrs['y_pred'] = np.concatenate(
                    results_arrs['y_pred'], axis=0).astype('float64')
                results_arrs['y_pred_entropy'] = vit_utils.entropy(
                    np.concatenate(results_arrs['y_pred_entropy'], axis=0),
                    axis=-1)
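                # `y_pred_entropy` is the per-example predictive entropy,
                # -sum_c p_c * log p_c (assuming `vit_utils.entropy` is the
                # standard Shannon entropy over the last axis), computed over
                # the epoch's concatenated probability vectors.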

                time_elapsed = time.time() - start_time
                results_arrs['total_ms_elapsed'] = time_elapsed * 1e3
                results_arrs['dataset_size'] = eval_steps * batch_size_eval

                all_eval_results[eval_name] = results_arrs

            per_pred_results, metrics_results = vit_utils.evaluate_vit_predictions(  # pylint: disable=unused-variable
                dataset_split_to_containers=all_eval_results,
                is_deterministic=True,
                num_bins=15,
                return_per_pred_results=True)

            # `metrics_results` is a dict of {str: jnp.ndarray} dicts, one for
            # each dataset. Flatten this dict so we can pass it to the writer,
            # and drop empty entries.
            flattened_metric_results = {}
            for dic in metrics_results.values():
                for key, value in dic.items():
                    if value is not None:
                        flattened_metric_results[key] = value
            writer.write_scalars(step, flattened_metric_results)

            # Optionally log to wandb
            if config.use_wandb:
                wandb.log(metrics_results, step=step)

            # Save per-prediction metrics
            results_storage_utils.save_per_prediction_results(output_dir,
                                                              step,
                                                              per_pred_results,
                                                              verbose=False)
            chrono.resume()

        # End of step.
        if config.get('testing_failure_step'):
            # Break early to simulate infra failures in test cases.
            if config.testing_failure_step == step:
                break

        if config.get('only_eval', False):
            break

    write_note(f'Done!\n{chrono.note}')
    pool.close()
    pool.join()
    writer.close()

def maybe_load_checkpoint(train_loop_rngs: jnp.ndarray,
                          save_checkpoint_path: str,
                          init_optimizer: flax.optim.Optimizer,
                          init_params: Params,
                          init_fixed_model_states: Optional[Params],
                          default_reinit_params: Iterable[str],
                          config: ml_collections.ConfigDict) -> CheckpointData:
  """Loads a model from an existing checkpoint if so indicated by the config.

  Whether to resume training, initialize from a previous checkpoint, or do
  nothing is set by the `config` ConfigDict, based on the existence of fields
  `resume` (resume training) or `model_init` (initialize from pretrained
  checkpoint).

  When resuming training, both the model weights and optimizer
  state (including the training step) are restored. When initializing, only
  the model parameters are updated.

  Initialization is prioritized in the following way:
  1. Always resume from an existing checkpoint, e.g. resume a finetune job.
  2. Resume from a previous checkpoint, e.g. start a cooldown training job.
  3. Initialize the model from something, e.g. start a fine-tuning job.
  4. Do nothing (train from scratch).

  Args:
    train_loop_rngs: unreplicated jax.PRNGKey.
    save_checkpoint_path: File pointing to pretrained checkpoint stored in NumPy
      `.npz` file.
    init_optimizer: flax.Optimizer to be updated.
    init_params: Tree of (possibly randomly) initialized parameters for the
      model.
    init_fixed_model_states: Optional pytree of non-trainable parameters.
      Currently only passed when using SNGP models.
    default_reinit_params: List of parameter names to reinitialize if not
      provided by the config file.
    config: ConfigDict which contains fields indicating if, and how, to load an
      available checkpoint into the optimizer. If resuming from a previous
      checkpoint *to start a cooldown job*, the flag `resume` must be set. If
      initializing a (subset of) model parameters to start a fine-tuning job,
      fields `model_init`, `representation_size` and `classifier` must be set.

  Returns:
    A CheckpointData instance containing a new rng key, the new optimizer state,
    the new untrainable parameters (if resuming from a checkpoint), and a
    dictionary of information about the reloaded state.
  """
  optimizer = init_optimizer
  fixed_model_states = init_fixed_model_states

  accum_train_time = 0.0
  # TODO(dusenberrymw, zmariet): Directly return an unreplicated rng and the
  # cumulative training time instead of storing them in `checkpoint_extra`.
  checkpoint_extra = dict(
      accum_train_time=accum_train_time,
      rngs_loop=flax_utils.replicate(train_loop_rngs))

  # Parse config file to figure out which setting we are in.
  resume_from_checkpoint = (
      (save_checkpoint_path is not None and
       tf.io.gfile.exists(save_checkpoint_path))
      or config.get("resume") is not None)
  reinitialize_model = config.get(
      "model_init") is not None and not resume_from_checkpoint

  if resume_from_checkpoint:
    logging.info("Resume training from checkpoint...")
    # Always prioritize loading from a checkpoint from the current training job.
    if save_checkpoint_path and tf.io.gfile.exists(save_checkpoint_path):
      resume_checkpoint_path = save_checkpoint_path
    # Otherwise, we reload from a previous checkpoint provided by the config.
    else:
      resume_checkpoint_path = config.resume

    checkpoint_tree = {"opt": init_optimizer, "extra": checkpoint_extra}
    if init_fixed_model_states is not None:
      checkpoint_tree["states"] = init_fixed_model_states
    checkpoint = load_checkpoint(checkpoint_tree, resume_checkpoint_path)
    optimizer, checkpoint_extra = checkpoint["opt"], checkpoint["extra"]
    fixed_model_states = checkpoint.get("states", None)

  elif reinitialize_model:
    logging.info("Initialize model...")
    reinit_params = config.get("model_reinit_params", default_reinit_params)
    logging.info("Reinitializing these parameters: %s", reinit_params)

    loader = lambda path: load_checkpoint(tree=None, path=path)
    loaded_params = loader(config.model_init)

    loaded_params = restore_from_pretrained_params(
        init_params=init_params,
        loaded_params=loaded_params,
        model_representation_size=config.model.representation_size,
        model_classifier=config.model.classifier,
        reinit_params=reinit_params)

    optimizer = init_optimizer.replace(target=loaded_params)
    if jax.process_index() == 0:
      logging.info("Restored parameter overview:")
      parameter_overview.log_parameter_overview(loaded_params)

  else:
    logging.info("No checkpoint to recover from; using default initialization.")

  return CheckpointData(
      optimizer=optimizer,
      fixed_model_states=fixed_model_states,
      train_loop_rngs=checkpoint_extra["rngs_loop"],
      accumulated_train_time=checkpoint_extra["accum_train_time"])
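
# A minimal usage sketch of the fine-tuning case described in the docstring
# above (the config values, paths and parameter names here are hypothetical;
# `opt_cpu` and `params_cpu` are built as in the surrounding training script):
#
#   my_config = ml_collections.ConfigDict()
#   my_config.model_init = "gs://bucket/pretrained_checkpoint.npz"
#   my_config.model = ml_collections.ConfigDict(
#       dict(representation_size=None, classifier="token"))
#   checkpoint_data = maybe_load_checkpoint(
#       train_loop_rngs=jax.random.PRNGKey(0),
#       save_checkpoint_path=None,
#       init_optimizer=opt_cpu,
#       init_params=params_cpu,
#       init_fixed_model_states=None,
#       default_reinit_params=("head/kernel", "head/bias"),
#       config=my_config)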

def main(argv):
  del argv

  config = FLAGS.config
  workdir = FLAGS.workdir

  logging.info("Workdir: %s", workdir)

  save_checkpoint_path = None
  if config.get("checkpoint_steps"):
    tf.io.gfile.makedirs(workdir)
    save_checkpoint_path = os.path.join(workdir, "checkpoint.npz")

  # The pool is used to perform misc operations, such as logging, in an
  # async way.
  pool = multiprocessing.pool.ThreadPool()

  # This seed makes the Jax part of things (like model init) deterministic.
  # However, full training still won't be deterministic, for example due to the
  # tf.data pipeline not being deterministic even if we set the TF seed.
  rng = jax.random.PRNGKey(config.get("seed", 0))

  def write_note(note):
    if jax.host_id() == 0:
      logging.info("NOTE: %s", note)
  write_note("Initializing...")

  # Verify settings to make sure no checkpoints are accidentally missed.
  if config.get("keep_checkpoint_steps"):
    assert config.get("checkpoint_steps"), "Specify `checkpoint_steps`."
    assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
        f"`keep_checkpoint_steps` ({config.checkpoint_steps}) should be"
        f"divisible by `checkpoint_steps ({config.checkpoint_steps}).`")

  batch_size = config.batch_size
  batch_size_eval = config.get("batch_size_eval", batch_size)
  if (batch_size % jax.device_count() != 0 or
      batch_size_eval % jax.device_count() != 0):
    raise ValueError(f"Batch sizes ({batch_size} and {batch_size_eval}) must "
                     f"be divisible by device number ({jax.device_count()})")

  local_batch_size = batch_size // jax.host_count()
  local_batch_size_eval = batch_size_eval // jax.host_count()
  logging.info(
      "Global batch size %d on %d hosts results in %d local batch size. "
      "With %d dev per host (%d dev total), that's a %d per-device batch size.",
      batch_size, jax.host_count(), local_batch_size,
      jax.local_device_count(), jax.device_count(),
      local_batch_size // jax.local_device_count())

  write_note("Initializing train dataset...")
  train_ds = input_pipeline.get_data(
      dataset=config.dataset,
      split=config.train_split,
      data_dir=fillin(config.get("dataset_dir")),
      batch_size=local_batch_size,
      preprocess_fn=pp_builder.get_preprocess_fn(config.pp_train),
      shuffle_buffer_size=config.shuffle_buffer_size,
      prefetch=config.get("prefetch_to_host", 2),
      cache=False)

  # Start prefetching already.
  train_iter = u.start_input_pipeline(
      train_ds, config.get("prefetch_to_device", 1), pad=local_batch_size)
  # We always pad to local_batch_size_eval even when less would be enough in
  # order to minimize memory fragmentation.

  write_note("Initializing val dataset(s)...")
  def _get_val_split(dataset, split, pp_eval, data_dir=None):
    # We do ceil rounding such that we include the last incomplete batch.
    nval_img = input_pipeline.get_num_examples(
        dataset, split, data_dir=fillin(data_dir))
    val_steps = int(np.ceil(nval_img / batch_size_eval))
    logging.info("Running validation for %d steps for %s, %s", val_steps,
                 dataset, split)

    val_it = input_pipeline.get_data(
        dataset=dataset,
        split=split,
        data_dir=fillin(data_dir),
        batch_size=local_batch_size_eval,
        preprocess_fn=pp_builder.get_preprocess_fn(pp_eval),
        cache=config.get("val_cache", "batched"),
        repeat_after_batching=True,
        prefetch=0,  # Save memory since we cache.
        drop_remainder=False,
        shuffle_files=False)
    val_it = u.start_input_pipeline(
        val_it, config.get("prefetch_to_device", 1), pad=local_batch_size_eval)

    return (val_it, val_steps)

  if isinstance(config.val_split, str):
    val_ds = {"val": _get_val_split(config.dataset, config.val_split,
                                    config.pp_eval, config.get("dataset_dir"))}
  else:
    val_ds = {t[0]: _get_val_split(*t[1:]) for t in config.val_split}

  ntrain_img = input_pipeline.get_num_examples(
      config.dataset, config.train_split,
      data_dir=fillin(config.get("dataset_dir")))
  steps_per_epoch = ntrain_img / batch_size

  if config.get("num_epochs"):
    total_steps = int(config.num_epochs * steps_per_epoch)
    assert not config.get("total_steps"), "Set either num_epochs or total_steps"
  else:
    total_steps = config.total_steps

  logging.info(
      "Running for %d steps, that means %f epochs and %f steps per epoch",
      total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)
  mw = u.BigVisionMetricWriter(xm_xp.id, xm_wu.id, steps_per_epoch)

  write_note(f"Initializing {config.model_name} model...")
  model_mod = importlib.import_module(f"{BASEDIR}.models.{config.model_name}")
  model = model_mod.Model(
      num_classes=config.num_classes, **config.get("model", {}))

  # We want all parameters to be created in host RAM, not on any device, they'll
  # be sent there later as needed, otherwise we already encountered two
  # situations where we allocate them twice.
  @partial(jax.jit, backend="cpu")
  def init(rng):
    image_size = tuple(train_ds.element_spec["image"].shape[1:])
    dummy_input = jnp.zeros((local_batch_size,) + image_size, jnp.float32)
    params = flax.core.unfreeze(model.init(rng, dummy_input))["params"]

    # Set bias in the head to a low value, such that loss is small initially.
    params["head"]["bias"] = jnp.full_like(
        params["head"]["bias"], config.get("init_head_bias", 0))

    return params

  rng, rng_init = jax.random.split(rng)
  params_cpu = init(rng_init)

  if jax.host_id() == 0:
    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
    parameter_overview.log_parameter_overview(params_cpu)
    mw.measure("num_params", num_params)

  @partial(jax.pmap, axis_name="batch")
  def evaluation_fn(params, images, labels, mask):
    # Ignore the entries with all zero labels for evaluation.
    mask *= labels.max(axis=1)
    logits, _ = model.apply({"params": flax.core.freeze(params)}, images)

    losses = getattr(u, config.get("loss", "sigmoid_xent"))(
        logits=logits, labels=labels, reduction=False)
    loss = jax.lax.psum(losses * mask, axis_name="batch")

    top1_idx = jnp.argmax(logits, axis=1)
    # Extracts the label at the highest logit index for each image.
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct * mask, axis_name="batch")
    n = jax.lax.psum(mask, axis_name="batch")
    return ncorrect, loss, n
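  # Note on the masking above: padded examples have all-zero label vectors, so
  # `labels.max(axis=1)` is 0 for them and 1 for real (one-hot) examples;
  # multiplying the per-example losses and correctness by `mask` before the
  # psum excludes the padding from `loss`, `ncorrect` and `n`.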

  # Setup function for computing representation.
  @partial(jax.pmap, axis_name="batch")
  def representation_fn(params, images, labels, mask):
    _, outputs = model.apply({"params": flax.core.freeze(params)}, images)
    representation = outputs[config.fewshot.representation_layer]
    representation = jax.lax.all_gather(representation, "batch")
    labels = jax.lax.all_gather(labels, "batch")
    mask = jax.lax.all_gather(mask, "batch")
    return representation, labels, mask

  # Load the optimizer either from our folder or from flax.
  opt_name = config.get("optim_name", "momentum_hp")
  write_note(f"Initializing {opt_name} optimizer...")
  try:
    opt_mod = importlib.import_module(f"{BASEDIR}.optims.{opt_name}")
    opt_def = opt_mod.Optimizer(**config.get("optim", {}))
  except ModuleNotFoundError:
    opt_def = getattr(flax.optim, opt_name)(**config.get("optim", {}))

  # We jit this, such that the arrays that are created are created on the same
  # device as the input is, in this case the CPU. Else they'd be on device[0].
  opt_cpu = jax.jit(opt_def.create)(params_cpu)

  @partial(jax.pmap, axis_name="batch", donate_argnums=(0,))
  def update_fn(opt, lr, images, labels, rng):
    """Update step."""

    measurements = {}

    if config.get("mixup") and config.mixup.p:
      rng, (images, labels), _ = u.mixup(rng, images, labels, **config.mixup)
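      # mixup convexly combines pairs of examples and their labels,
      # x = lam * x_i + (1 - lam) * x_j with lam drawn from a Beta
      # distribution; the exact sampling is handled inside `u.mixup`.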

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index("batch"))

    def loss_fn(params, images, labels):
      logits, _ = model.apply(
          {"params": flax.core.freeze(params)}, images,
          train=True, rngs={"dropout": rng_model_local})
      return getattr(u, config.get("loss", "sigmoid_xent"))(
          logits=logits, labels=labels)

    # Implementation considerations compared and summarized at
    # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
    l, g = u.accumulate_gradient(jax.value_and_grad(loss_fn), opt.target,
                                 images, labels,
                                 config.get("grad_accum_steps"))
    l, g = jax.lax.pmean((l, g), axis_name="batch")

    # Log the gradient norm only if we need to compute it anyway (clipping),
    # or if we don't use grad_accum_steps, as the two interact badly.
    if config.get("grad_accum_steps", 1) == 1 or config.get("grad_clip_norm"):
      grads, _ = jax.tree_flatten(g)
      l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
      measurements["l2_grads"] = l2_g

    # Optionally resize the global gradient to a maximum norm. We found this
    # useful in some cases across optimizers, hence it's in the main loop.
    if config.get("grad_clip_norm"):
      g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
      g = jax.tree_map(lambda p: g_factor * p, g)
    opt = opt.apply_gradient(g, learning_rate=lr)

    decay_rules = config.get("weight_decay", []) or []
    if isinstance(decay_rules, numbers.Number):
      decay_rules = [(".*kernel.*", decay_rules)]
    sched_m = lr/config.lr.base if config.get("weight_decay_decouple") else lr
    def decay_fn(v, wd):
      return (1.0 - sched_m * wd) * v
    opt = opt.replace(target=u.tree_map_with_regex(
        decay_fn, opt.target, decay_rules, name="weight decay"))
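    # In other words, every parameter matching a decay rule is shrunk as
    # v <- (1 - sched_m * wd) * v, where sched_m is either the raw learning
    # rate or, with `weight_decay_decouple`, the schedule ratio lr / lr.base.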

    params, _ = jax.tree_flatten(opt.target)
    measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

    return opt, l, rng, measurements

  # Other things besides optimizer state to be stored.
  checkpoint_extra = dict(accum_train_time=0.0)

  # Decide how to initialize training. The order is important.
  # 1. Always resume from an existing checkpoint, e.g. resume a finetune job.
  # 2. Resume from a previous checkpoint, e.g. start a cooldown training job.
  # 3. Initialize the model from something, e.g. start a fine-tuning job.
  # 4. Train from scratch.
  resume_checkpoint_path = None
  if save_checkpoint_path and tf.io.gfile.exists(save_checkpoint_path):
    resume_checkpoint_path = save_checkpoint_path
  elif config.get("resume"):
    resume_checkpoint_path = fillin(config.resume)
  if resume_checkpoint_path:
    write_note("Resume training from checkpoint...")
    checkpoint = {"opt": opt_cpu, "extra": checkpoint_extra}
    _, checkpoint_tree = jax.tree_flatten(checkpoint)
    loaded = u.load_checkpoint(checkpoint_tree, resume_checkpoint_path)
    # bfloat16 type gets lost when data is saved to disk, so we recover it.
    checkpoint = jax.tree_map(u.recover_dtype, loaded)
    opt_cpu, checkpoint_extra = checkpoint["opt"], checkpoint["extra"]
  elif config.get("model_init"):
    write_note(f"Initialize model from {config.model_init}...")
    loaded = model_mod.load(params_cpu, config.model_init, config.get("model"))
    opt_cpu = opt_cpu.replace(target=loaded)
    if jax.host_id() == 0:
      logging.info("Restored parameter overview:")
      parameter_overview.log_parameter_overview(loaded)

  write_note("Kicking off misc stuff...")
  first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
  chrono = u.Chrono(first_step, total_steps, batch_size,
                    checkpoint_extra["accum_train_time"])
  # Note: switch to ProfileAllHosts() if you need to profile all hosts.
  # (Xprof data become much larger and take longer to load for analysis)
  profiler = periodic_actions.Profile(
      # Create a profile after every restart to analyze pre-emption-related
      # problems and ensure we get similar performance in every run.
      logdir=workdir, first_profile=first_step + 10)

  # Prepare the learning-rate and pre-fetch it to device to avoid delays.
  lr_fn = u.create_learning_rate_schedule(
      batch_size, total_steps, steps_per_epoch, **config.get("lr", {}))
  lr_iter = u.prefetch_scalar(map(lr_fn, range(first_step, total_steps)),
                              config.get("prefetch_to_device", 1))

  write_note(f"Replicating...\n{chrono.note}")
  opt_repl = flax_utils.replicate(opt_cpu)

  write_note(f"Initializing few-shotters...\n{chrono.note}")
  if "fewshot" in config:
    fewshotter = fewshot.FewShotEvaluator(
        representation_fn, config.fewshot,
        config.fewshot.get("batch_size") or batch_size_eval)

  rng, rng_loop = jax.random.split(rng, 2)
  rngs_loop = flax_utils.replicate(rng_loop)
  checkpoint_writer = None

  write_note(f"First step compilations...\n{chrono.note}")
  # Using a python integer for step here, because opt.state.step is allocated
  # on TPU during replication.
  for step, train_batch, lr_repl in zip(
      range(first_step + 1, total_steps + 1), train_iter, lr_iter):
    mw.step_start(step)

    with jax.profiler.TraceContext("train_step", step_num=step, _r=1):
      opt_repl, loss_value, rngs_loop, extra_measurements = update_fn(
          opt_repl,
          lr_repl,
          train_batch["image"],
          train_batch["labels"],
          rng=rngs_loop)

    if jax.host_id() == 0:
      profiler(step)

    # Checkpoint saving
    if u.itstime(step, config.get("checkpoint_steps"), total_steps, host=0):
      chrono.pause()
      u.checkpointing_timeout(checkpoint_writer,
                              config.get("checkpoint_timeout", 1))
      checkpoint_extra["accum_train_time"] = chrono.accum_train_time
      # We need to transfer the weights over now, or else we risk keeping them
      # alive while they are updated in a future step, creating hard-to-debug
      # memory errors (see b/160593526). Also, this takes device 0's params
      # only.
      opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)

      # Check whether we want to keep a copy of the current checkpoint.
      copy_step = None
      if u.itstime(step, config.get("keep_checkpoint_steps"), total_steps):
        copy_step = step

      # Checkpoint should be a nested dictionary or Flax dataclasses from
      # `flax.struct`. Both can be present in a checkpoint.
      checkpoint = {"opt": opt_cpu, "extra": checkpoint_extra}
      checkpoint_writer = pool.apply_async(
          u.save_checkpoint, (checkpoint, save_checkpoint_path, copy_step))
      chrono.resume()

    # Report training progress
    if u.itstime(step, config.log_training_steps, total_steps, host=0):
      mw.measure("learning_rate", lr_repl[0])
      mw.measure("training_loss", loss_value[0])
      for name, value in extra_measurements.items():
        mw.measure(name, value[0])
      chrono.tick(step, mw.measure, write_note)

    # Report validation performance
    if u.itstime(step, config.log_eval_steps, total_steps):
      chrono.pause()
      for val_name, (val_iter, val_steps) in val_ds.items():
        ncorrect, loss, nseen = 0, 0, 0
        for _, batch in zip(range(val_steps), val_iter):
          batch_ncorrect, batch_losses, batch_n = evaluation_fn(
              opt_repl.target, batch["image"], batch["labels"], batch["mask"])
          # All results are a replicated array shaped as follows:
          # (local_devices, per_device_batch_size, elem_shape...)
          # with each local device's entry being identical, because they were
          # psum'd. So let's just take the first one to the host as numpy.
          ncorrect += np.sum(np.array(batch_ncorrect[0]))
          loss += np.sum(np.array(batch_losses[0]))
          nseen += np.sum(np.array(batch_n[0]))
        mw.measure(f"{val_name}_prec@1", ncorrect / nseen)
        mw.measure(f"{val_name}_loss", loss / nseen)
      chrono.resume()

    if "fewshot" in config:
      # Compute few-shot on-the-fly evaluation.
      if u.itstime(step, config.fewshot.log_steps, total_steps):
        chrono.pause()
        write_note(f"Few-shot evaluation...\n{chrono.note}")
        r = fewshotter.run_all(opt_repl.target, config.fewshot.datasets)
        fewshotter.walk_results(mw.measure, *r)
        chrono.resume()
    mw.step_end()

  write_note(f"Done!\n{chrono.note}")
  pool.close()
  pool.join()
  mw.close()