def train_and_evaluate(config, workdir):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """
    if config.dataset.batch_size % jax.device_count() != 0:
        raise ValueError(
            "Batch size must be divisible by the number of devices.")

    tf.io.gfile.makedirs(workdir)
    # Deterministic training.
    rng = jax.random.PRNGKey(config.seed)
    # Shift the numpy random seed by process_index() to shuffle data loaded
    # by different hosts
    np.random.seed(20201473 + jax.process_index())

    #----------------------------------------------------------------------------
    # Build input pipeline.
    rng, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.process_index())
    config.dataset.data_dir = os.path.join(config.dataset.base_dir,
                                           config.dataset.scene)
    train_ds, eval_ds = datasets.create_dataset(config)
    example_batch = train_ds.peek()

    #----------------------------------------------------------------------------
    # Learning rate schedule.
    num_train_steps = config.train.max_steps
    if num_train_steps == -1:
        num_train_steps = train_ds.size()
    steps_per_epoch = num_train_steps // config.train.num_epochs
    logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
                 steps_per_epoch)

    learning_rate_fn = train_utils.create_learning_rate_fn(config)

    #----------------------------------------------------------------------------
    # Initialize model.
    rng, model_rng = jax.random.split(rng)
    model, state = models.create_train_state(
        config,
        model_rng,
        learning_rate_fn=learning_rate_fn,
        example_batch=example_batch,
    )

    #----------------------------------------------------------------------------
    # Set up checkpointing of the model and the input pipeline.
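    # If `workdir` holds no checkpoint yet, `restore_checkpoint` returns
    # `state` unchanged, so `initial_step` below is 1 and training starts
    # from scratch.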
    state = checkpoints.restore_checkpoint(workdir, state)
    initial_step = int(state.step) + 1

    #----------------------------------------------------------------------------
    # Distribute training.
    state = flax_utils.replicate(state)
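    # Each local device now holds its own copy of the train state; the
    # "batch" axis name below lets any collectives inside `train_step`
    # reduce across devices.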
    p_train_step = jax.pmap(
        functools.partial(
            train_step,
            model=model,
            learning_rate_fn=learning_rate_fn,
            weight_decay=config.train.weight_decay,
            config=config,
        ),
        axis_name="batch",
    )

    # Get distributed rendering function
    render_pfn = render_utils.get_render_function(
        model=model,
        config=config,
        randomized=False,  # No randomization for evaluation.
    )

    #----------------------------------------------------------------------------
    # Prepare Metric Writers
    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.process_index() > 0)
    if initial_step == 1:
        writer.write_hparams(dict(config))

    logging.info("Starting training loop at step %d.", initial_step)
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=num_train_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [
            report_progress,
        ]
    train_metrics = None

    # Prefetch up to 6 batches (~6 x batch_size examples) to device.
    ptrain_ds = flax.jax_utils.prefetch_to_device(train_ds, 6)
    n_local_devices = jax.local_device_count()
    rng = rng + jax.process_index()  # Make random seed separate across hosts.
    keys = jax.random.split(rng, n_local_devices)  # For pmapping RNG keys.

    with metric_writers.ensure_flushes(writer):
        for step in range(initial_step, num_train_steps + 1):
            # `step` is a Python integer. `state.step` is a JAX integer on the
            # GPU/TPU devices.
            is_last_step = step == num_train_steps
            with jax.profiler.StepTraceAnnotation("train", step_num=step):
                batch = next(ptrain_ds)
                state, metrics_update, keys = p_train_step(rng=keys,
                                                           state=state,
                                                           batch=batch)
                metric_update = flax_utils.unreplicate(metrics_update)
                train_metrics = (metric_update if train_metrics is None else
                                 train_metrics.merge(metric_update))
            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            if step % config.train.log_loss_every_steps == 0 or is_last_step:
                writer.write_scalars(step, train_metrics.compute())
                train_metrics = None

            if step % config.train.render_every_steps == 0 or is_last_step:
                test_batch = next(eval_ds)
                test_pixels = model_utils.uint2float(
                    test_batch.target_view.rgb)  # extract for evaluation
                with report_progress.timed("eval"):
                    pred_color, pred_disp, pred_acc = eval_step(
                        state, keys[0], test_batch, render_pfn, config)
                #------------------------------------------------------------------
                # Log metrics and images for host 0
                #------------------------------------------------------------------
                if jax.process_index() == 0:
                    psnr = model_utils.compute_psnr(
                        ((pred_color - test_pixels)**2).mean())
                    ssim = skmetrics.structural_similarity(
                        pred_color.astype(np.float32),
                        test_pixels.astype(np.float32),
                        win_size=11,
                        multichannel=True,
                        gaussian_weights=True)
                    writer.write_scalars(
                        step, {
                            "train_eval/test_psnr": psnr,
                            "train_eval/test_ssim": ssim,
                        })
                    writer.write_images(
                        step, {
                            "test_pred_color": pred_color[None, :],
                            "test_target": test_pixels[None, :]
                        })
                    if pred_disp is not None:
                        writer.write_images(
                            step, {"test_pred_disp": pred_disp[None, :]})
                    if pred_acc is not None:
                        writer.write_images(
                            step, {"test_pred_acc": pred_acc[None, :]})
                #------------------------------------------------------------------

            if (jax.process_index()
                    == 0) and (step % config.train.checkpoint_every_steps == 0
                               or is_last_step):
                # Write the latest training metrics to a JSON log file.
                with file_utils.open_file(
                        os.path.join(workdir, "train_logs.json"), "w") as f:
                    log_dict = metric_update.compute()
                    for k, v in log_dict.items():
                        log_dict[k] = v.item()
                    f.write(json.dumps(log_dict))
                with report_progress.timed("checkpoint"):
                    state_to_save = jax.device_get(
                        jax.tree_map(lambda x: x[0], state))
                    checkpoints.save_checkpoint(workdir,
                                                state_to_save,
                                                step,
                                                keep=100)

    logging.info("Finishing training at step %d", num_train_steps)


def main(config, output_dir):
    seed = config.get('seed', 0)
    rng = jax.random.PRNGKey(seed)
    tf.random.set_seed(seed)

    if config.get('data_dir'):
        logging.info('data_dir=%s', config.data_dir)
    logging.info('Output dir: %s', output_dir)

    save_checkpoint_path = None
    if config.get('checkpoint_steps'):
        gfile.makedirs(output_dir)
        save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

    # Create an asynchronous multi-metric writer.
    writer = metric_writers.create_default_writer(
        output_dir, just_logging=jax.process_index() > 0)

    # The pool is used to perform misc operations such as logging in async way.
    pool = multiprocessing.pool.ThreadPool()

    def write_note(note):
        if jax.host_id() == 0:
            logging.info('NOTE: %s', note)

    write_note('Initializing...')

    # Verify settings to make sure no checkpoints are accidentally missed.
    if config.get('keep_checkpoint_steps'):
        assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
        assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
            f'`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should '
            f'be divisible by `checkpoint_steps` ({config.checkpoint_steps}).')

    batch_size = config.batch_size
    batch_size_eval = config.get('batch_size_eval', batch_size)
    if (batch_size % jax.device_count() != 0
            or batch_size_eval % jax.device_count() != 0):
        raise ValueError(
            f'Batch sizes ({batch_size} and {batch_size_eval}) must '
            f'be divisible by device number ({jax.device_count()})')

    local_batch_size = batch_size // jax.host_count()
    local_batch_size_eval = batch_size_eval // jax.host_count()
    logging.info(
        'Global batch size %d on %d hosts results in %d local batch size. '
        'With %d devices per host (%d devices total), that\'s a %d per-device '
        'batch size.', batch_size, jax.host_count(), local_batch_size,
        jax.local_device_count(), jax.device_count(),
        local_batch_size // jax.local_device_count())

    write_note('Initializing train dataset...')
    rng, train_ds_rng = jax.random.split(rng)
    train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index())
    train_ds = input_utils.get_data(
        dataset=config.dataset,
        split=config.train_split,
        rng=train_ds_rng,
        process_batch_size=local_batch_size,
        preprocess_fn=preprocess_spec.parse(
            spec=config.pp_train, available_ops=preprocess_utils.all_ops()),
        shuffle_buffer_size=config.shuffle_buffer_size,
        prefetch_size=config.get('prefetch_to_host', 2),
        data_dir=config.get('data_dir'))

    # Start prefetching already.
    train_iter = input_utils.start_input_pipeline(
        train_ds, config.get('prefetch_to_device', 1))

    write_note('Initializing val dataset(s)...')

    def _get_val_split(dataset, split, pp_eval, data_dir=None):
        # We do ceil rounding such that we include the last incomplete batch.
        nval_img = input_utils.get_num_examples(
            dataset,
            split=split,
            process_batch_size=local_batch_size_eval,
            drop_remainder=False,
            data_dir=data_dir)
        val_steps = int(np.ceil(nval_img / batch_size_eval))
        logging.info('Running validation for %d steps for %s, %s', val_steps,
                     dataset, split)

        if isinstance(pp_eval, str):
            pp_eval = preprocess_spec.parse(
                spec=pp_eval, available_ops=preprocess_utils.all_ops())

        val_ds = input_utils.get_data(dataset=dataset,
                                      split=split,
                                      rng=None,
                                      process_batch_size=local_batch_size_eval,
                                      preprocess_fn=pp_eval,
                                      cache=config.get('val_cache', 'batched'),
                                      num_epochs=1,
                                      repeat_after_batching=True,
                                      shuffle=False,
                                      prefetch_size=config.get(
                                          'prefetch_to_host', 2),
                                      drop_remainder=False,
                                      data_dir=data_dir)

        return val_ds

    val_ds_splits = {
        'val':
        _get_val_split(config.dataset,
                       split=config.val_split,
                       pp_eval=config.pp_eval,
                       data_dir=config.get('data_dir'))
    }

    if config.get('test_split'):
        val_ds_splits.update({
            'test':
            _get_val_split(config.dataset,
                           split=config.test_split,
                           pp_eval=config.pp_eval,
                           data_dir=config.get('data_dir'))
        })

    if config.get('eval_on_cifar_10h'):
        cifar10_to_cifar10h_fn = data_uncertainty_utils.create_cifar10_to_cifar10h_fn(
            config.get('data_dir', None))
        preprocess_fn = preprocess_spec.parse(
            spec=config.pp_eval_cifar_10h,
            available_ops=preprocess_utils.all_ops())
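        # Compose the two: map each CIFAR-10 test example to its CIFAR-10H
        # soft label first, then apply the usual eval preprocessing.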
        pp_eval = lambda ex: preprocess_fn(cifar10_to_cifar10h_fn(ex))
        val_ds_splits['cifar_10h'] = _get_val_split(
            'cifar10',
            split=config.get('cifar_10h_split') or 'test',
            pp_eval=pp_eval,
            data_dir=config.get('data_dir'))
    elif config.get('eval_on_imagenet_real'):
        imagenet_to_real_fn = data_uncertainty_utils.create_imagenet_to_real_fn(
        )
        preprocess_fn = preprocess_spec.parse(
            spec=config.pp_eval_imagenet_real,
            available_ops=preprocess_utils.all_ops())
        pp_eval = lambda ex: preprocess_fn(imagenet_to_real_fn(ex))
        val_ds_splits['imagenet_real'] = _get_val_split(
            'imagenet2012_real',
            split=config.get('imagenet_real_split') or 'validation',
            pp_eval=pp_eval,
            data_dir=config.get('data_dir'))

    ood_ds = {}
    if config.get('ood_datasets') and config.get('ood_methods'):
        if config.get(
                'ood_methods'):  # config.ood_methods is not an empty list
            logging.info('loading OOD dataset = %s',
                         config.get('ood_datasets'))
            ood_ds, ood_ds_names = ood_utils.load_ood_datasets(
                config.dataset,
                config.ood_datasets,
                config.ood_split,
                config.pp_eval,
                config.pp_eval_ood,
                config.ood_methods,
                config.train_split,
                config.get('data_dir'),
                _get_val_split,
            )

    ntrain_img = input_utils.get_num_examples(
        config.dataset,
        split=config.train_split,
        process_batch_size=local_batch_size,
        data_dir=config.get('data_dir'))
    steps_per_epoch = int(ntrain_img / batch_size)

    if config.get('num_epochs'):
        total_steps = int(config.num_epochs * steps_per_epoch)
        assert not config.get(
            'total_steps'), 'Set either num_epochs or total_steps'
    else:
        total_steps = config.total_steps

    logging.info('Total train data points: %d', ntrain_img)
    logging.info(
        'Running for %d steps, that means %f epochs and %d steps per epoch',
        total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)

    write_note('Initializing model...')
    logging.info('config.model = %s', config.get('model'))
    model = ub.models.bit_resnet(num_classes=config.num_classes,
                                 **config.get('model', {}))

    # We want all parameters to be created in host RAM, not on any device;
    # they'll be sent to the devices later as needed. Otherwise we risk
    # allocating them twice, which we have run into before.
    @partial(jax.jit, backend='cpu')
    def init(rng):
        image_size = tuple(train_ds.element_spec['image'].shape[2:])
        logging.info('image_size = %s', image_size)
        dummy_input = jnp.zeros((local_batch_size, ) + image_size, jnp.float32)
        params = flax.core.unfreeze(model.init(rng, dummy_input,
                                               train=False))['params']

        # Set bias in the head to a low value, such that loss is small initially.
        params['head']['bias'] = jnp.full_like(params['head']['bias'],
                                               config.get('init_head_bias', 0))

        # init head kernel to all zeros for fine-tuning
        if config.get('model_init'):
            params['head']['kernel'] = jnp.full_like(params['head']['kernel'],
                                                     0)

        return params

    rng, rng_init = jax.random.split(rng)
    params_cpu = init(rng_init)

    if jax.host_id() == 0:
        num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
        parameter_overview.log_parameter_overview(params_cpu)
        writer.write_scalars(step=0, scalars={'num_params': num_params})

    @partial(jax.pmap, axis_name='batch')
    def evaluation_fn(params, images, labels, mask):
        # Ignore the entries with all zero labels for evaluation.
        mask *= labels.max(axis=1)
        logits, out = model.apply({'params': flax.core.freeze(params)},
                                  images,
                                  train=False)

        losses = getattr(train_utils,
                         config.get('loss', 'sigmoid_xent'))(logits=logits,
                                                             labels=labels,
                                                             reduction=False)
        loss = jax.lax.psum(losses * mask, axis_name='batch')

        top1_idx = jnp.argmax(logits, axis=1)
        # Extracts the label at the highest logit index for each image.
        top1_correct = jnp.take_along_axis(labels, top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch')
        n = jax.lax.psum(mask, axis_name='batch')

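        # Gather every device's logits, labels, pre-logit features and masks
        # so the host can compute calibration/uncertainty metrics over the
        # full global batch.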
        metric_args = jax.lax.all_gather(
            [logits, labels, out['pre_logits'], mask], axis_name='batch')
        return ncorrect, loss, n, metric_args

    @partial(jax.pmap, axis_name='batch')
    def cifar_10h_evaluation_fn(params, images, labels, mask):
        logits, out = model.apply({'params': flax.core.freeze(params)},
                                  images,
                                  train=False)

        losses = getattr(train_utils,
                         config.get('loss', 'softmax_xent'))(logits=logits,
                                                             labels=labels,
                                                             reduction=False)
        loss = jax.lax.psum(losses, axis_name='batch')

        top1_idx = jnp.argmax(logits, axis=1)
        # Extracts the label at the highest logit index for each image.
        one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)]

        top1_correct = jnp.take_along_axis(one_hot_labels,
                                           top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
        n = jax.lax.psum(one_hot_labels, axis_name='batch')

        metric_args = jax.lax.all_gather(
            [logits, labels, out['pre_logits'], mask], axis_name='batch')
        return ncorrect, loss, n, metric_args

    # Setup function for computing representation.
    @partial(jax.pmap, axis_name='batch')
    def representation_fn(params, images, labels, mask):
        _, outputs = model.apply({'params': flax.core.freeze(params)},
                                 images,
                                 train=False)
        representation = outputs[config.fewshot.representation_layer]
        representation = jax.lax.all_gather(representation, 'batch')
        labels = jax.lax.all_gather(labels, 'batch')
        mask = jax.lax.all_gather(mask, 'batch')
        return representation, labels, mask

    # Load the optimizer from flax.
    opt_name = config.get('optim_name')
    write_note(f'Initializing {opt_name} optimizer...')
    opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

    # We jit this so that the arrays it creates are placed on the same device
    # as the input (here the CPU); otherwise they would end up on device[0].
    opt_cpu = jax.jit(opt_def.create)(params_cpu)

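    # donate_argnums=(0,) tells XLA it may reuse the buffers of the incoming
    # optimizer for the updated one, reducing peak device memory.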
    @partial(jax.pmap, axis_name='batch', donate_argnums=(0, ))
    def update_fn(opt, lr, images, labels, rng):
        """Update step."""

        measurements = {}

        # Get device-specific loss rng.
        rng, rng_model = jax.random.split(rng, 2)
        rng_model_local = jax.random.fold_in(rng_model,
                                             jax.lax.axis_index('batch'))

        def loss_fn(params, images, labels):
            logits, _ = model.apply({'params': flax.core.freeze(params)},
                                    images,
                                    train=True,
                                    rngs={'dropout': rng_model_local})
            accuracy = jnp.mean(
                jnp.equal(jnp.argmax(logits, axis=-1),
                          jnp.argmax(labels, axis=-1)))
            return getattr(train_utils,
                           config.get('loss',
                                      'sigmoid_xent'))(logits=logits,
                                                       labels=labels), accuracy

        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (l, train_accuracy), g = grad_fn(opt.target, images, labels)
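        # Average the loss and gradients across devices so every replica
        # applies the same update.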
        l, g = jax.lax.pmean((l, g), axis_name='batch')
        measurements['accuracy'] = train_accuracy

        # Log the gradient norm only if we need to compute it anyway (for
        # clipping), or if we don't use grad_accum_steps, as they interact badly.
        if config.get('grad_accum_steps',
                      1) == 1 or config.get('grad_clip_norm'):
            grads, _ = jax.tree_flatten(g)
            l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
            measurements['l2_grads'] = l2_g

        # Optionally resize the global gradient to a maximum norm. We found this
        # useful in some cases across optimizers, hence it's in the main loop.
        if config.get('grad_clip_norm'):
            g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
            g = jax.tree_util.tree_map(lambda p: g_factor * p, g)
        opt = opt.apply_gradient(g, learning_rate=lr)

        decay_rules = config.get('weight_decay', []) or []
        if isinstance(decay_rules, numbers.Number):
            decay_rules = [('.*kernel.*', decay_rules)]
        sched_m = lr / config.lr.base if config.get(
            'weight_decay_decouple') else lr

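        # Multiplicative weight decay: shrink every matched parameter toward
        # zero. With `weight_decay_decouple`, the decay follows the schedule
        # shape (lr / base lr) rather than the absolute learning rate.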
        def decay_fn(v, wd):
            return (1.0 - sched_m * wd) * v

        opt = opt.replace(target=train_utils.tree_map_with_regex(
            decay_fn, opt.target, decay_rules))

        params, _ = jax.tree_flatten(opt.target)
        measurements['l2_params'] = jnp.sqrt(
            sum([jnp.vdot(p, p) for p in params]))

        return opt, l, rng, measurements

    # Other things besides optimizer state to be stored.
    rng, rng_loop = jax.random.split(rng, 2)
    rngs_loop = flax_utils.replicate(rng_loop)
    checkpoint_extra = dict(accum_train_time=0.0, rngs_loop=rngs_loop)

    # Decide how to initialize training. The order is important.
    # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job.
    # 2. Resume from a previous checkpoint, e.g. start a cooldown training job.
    # 3. Initialize model from something, e.g., start a fine-tuning job.
    # 4. Train from scratch.
    resume_checkpoint_path = None
    if save_checkpoint_path and gfile.exists(save_checkpoint_path):
        resume_checkpoint_path = save_checkpoint_path
    elif config.get('resume'):
        resume_checkpoint_path = config.resume
    if resume_checkpoint_path:
        write_note('Resume training from checkpoint...')
        checkpoint_tree = {'opt': opt_cpu, 'extra': checkpoint_extra}
        checkpoint = checkpoint_utils.load_checkpoint(checkpoint_tree,
                                                      resume_checkpoint_path)
        opt_cpu, checkpoint_extra = checkpoint['opt'], checkpoint['extra']
        rngs_loop = checkpoint_extra['rngs_loop']
    elif config.get('model_init'):
        write_note(f'Initialize model from {config.model_init}...')
        reinit_params = config.get('model_reinit_params',
                                   ('head/kernel', 'head/bias'))
        logging.info('Reinitializing these parameters: %s', reinit_params)
        # We only support "no head" fine-tuning for now.
        loaded_params = checkpoint_utils.load_checkpoint(
            tree=None, path=config.model_init)
        loaded = checkpoint_utils.restore_from_pretrained_params(
            params_cpu,
            loaded_params,
            model_representation_size=None,
            model_classifier=None,
            reinit_params=reinit_params)
        opt_cpu = opt_cpu.replace(target=loaded)
        if jax.host_id() == 0:
            logging.info('Restored parameter overview:')
            parameter_overview.log_parameter_overview(loaded)

    write_note('Kicking off misc stuff...')
    first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
    if first_step == 0 and jax.host_id() == 0:
        writer.write_hparams(dict(config))
    chrono = train_utils.Chrono(first_step, total_steps, batch_size,
                                checkpoint_extra['accum_train_time'])
    # Note: switch to ProfileAllHosts() if you need to profile all hosts.
    # (Xprof data become much larger and take longer to load for analysis)
    profiler = periodic_actions.Profile(
        # Create profile after every restart to analyze pre-emption related
        # problems and assure we get similar performance in every run.
        logdir=output_dir,
        first_profile=first_step + 10)

    # Prepare the learning-rate and pre-fetch it to device to avoid delays.
    lr_fn = train_utils.create_learning_rate_schedule(total_steps,
                                                      **config.get('lr', {}))
    # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be
    # necessary for TPUs.
    lr_iter = train_utils.prefetch_scalar(map(lr_fn, range(total_steps)),
                                          config.get('prefetch_to_device', 1))

    write_note(f'Replicating...\n{chrono.note}')
    opt_repl = flax_utils.replicate(opt_cpu)

    write_note(f'Initializing few-shotters...\n{chrono.note}')
    fewshotter = None
    if 'fewshot' in config and fewshot is not None:
        fewshotter = fewshot.FewShotEvaluator(
            representation_fn, config.fewshot,
            config.fewshot.get('batch_size') or batch_size_eval)

    checkpoint_writer = None

    # Note: we return the train loss, val loss, and fewshot best l2s for use in
    # reproducibility unit tests.
    train_loss = -jnp.inf
    val_loss = {val_name: -jnp.inf for val_name, _ in val_ds_splits.items()}
    fewshot_results = {'dummy': {(0, 1): -jnp.inf}}

    write_note(f'First step compilations...\n{chrono.note}')
    logging.info('first_step = %s', first_step)
    # Advance the iterators if we are restarting from an earlier checkpoint.
    # TODO(dusenberrymw): Look into checkpointing dataset state instead.
    if first_step > 0:
        write_note('Advancing iterators after resuming from a checkpoint...')
        lr_iter = itertools.islice(lr_iter, first_step, None)
        train_iter = itertools.islice(train_iter, first_step, None)

    # Using a python integer for step here, because opt.state.step is allocated
    # on TPU during replication.
    for step, train_batch, lr_repl in zip(
            range(first_step + 1, total_steps + 1), train_iter, lr_iter):

        with jax.profiler.TraceContext('train_step', step_num=step, _r=1):
            opt_repl, loss_value, rngs_loop, extra_measurements = update_fn(
                opt_repl,
                lr_repl,
                train_batch['image'],
                train_batch['labels'],
                rng=rngs_loop)

        if jax.host_id() == 0:
            profiler(step)

        # Checkpoint saving
        if train_utils.itstime(step,
                               config.get('checkpoint_steps'),
                               total_steps,
                               process=0):
            write_note('Checkpointing...')
            chrono.pause()
            train_utils.checkpointing_timeout(
                checkpoint_writer, config.get('checkpoint_timeout', 1))
            checkpoint_extra['accum_train_time'] = chrono.accum_train_time
            checkpoint_extra['rngs_loop'] = rngs_loop
            # We need to transfer the weights over now, or else we risk keeping
            # them alive while they are updated in a future step, creating
            # hard-to-debug memory errors (see b/160593526). Also, this takes
            # device 0's params only.
            opt_cpu = jax.tree_util.tree_map(lambda x: np.array(x[0]),
                                             opt_repl)

            # Check whether we want to keep a copy of the current checkpoint.
            copy_step = None
            if train_utils.itstime(step, config.get('keep_checkpoint_steps'),
                                   total_steps):
                write_note('Keeping a checkpoint copy...')
                copy_step = step

            # Checkpoint should be a nested dictionary or Flax dataclasses from
            # `flax.struct`. Both can be present in a checkpoint.
            checkpoint = {'opt': opt_cpu, 'extra': checkpoint_extra}
            checkpoint_writer = pool.apply_async(
                checkpoint_utils.save_checkpoint,
                (checkpoint, save_checkpoint_path, copy_step))
            chrono.resume()

        # Report training progress
        if train_utils.itstime(step,
                               config.log_training_steps,
                               total_steps,
                               process=0):
            write_note('Reporting training progress...')
            train_accuracy = extra_measurements['accuracy']
            train_accuracy = jnp.mean(train_accuracy)
            train_loss = loss_value[
                0]  # Keep to return for reproducibility tests.
            timing_measurements, note = chrono.tick(step)
            write_note(note)
            train_measurements = {}
            train_measurements.update({
                'learning_rate': lr_repl[0],
                'training_loss': train_loss,
                'training_accuracy': train_accuracy,
            })
            train_measurements.update(
                flax.jax_utils.unreplicate(extra_measurements))
            train_measurements.update(timing_measurements)
            writer.write_scalars(step, train_measurements)

        # Report validation performance
        if train_utils.itstime(step, config.log_eval_steps, total_steps):
            write_note('Evaluating on the validation set...')
            chrono.pause()
            for val_name, val_ds in val_ds_splits.items():
                # Sets up evaluation metrics.
                ece_num_bins = config.get('ece_num_bins', 15)
                auc_num_bins = config.get('auc_num_bins', 1000)
                ece = rm.metrics.ExpectedCalibrationError(
                    num_bins=ece_num_bins)
                calib_auc = rm.metrics.CalibrationAUC(
                    correct_pred_as_pos_label=False)
                oc_auc_0_5 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.005, num_bins=auc_num_bins)
                oc_auc_1 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.01, num_bins=auc_num_bins)
                oc_auc_2 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.02, num_bins=auc_num_bins)
                oc_auc_5 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.05, num_bins=auc_num_bins)
                label_diversity = tf.keras.metrics.Mean()
                sample_diversity = tf.keras.metrics.Mean()
                ged = tf.keras.metrics.Mean()

                # Runs evaluation loop.
                val_iter = input_utils.start_input_pipeline(
                    val_ds, config.get('prefetch_to_device', 1))
                ncorrect, loss, nseen = 0, 0, 0
                for batch in val_iter:
                    if val_name == 'cifar_10h':
                        batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                            cifar_10h_evaluation_fn(opt_repl.target,
                                                    batch['image'],
                                                    batch['labels'],
                                                    batch['mask']))
                    else:
                        batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                            evaluation_fn(opt_repl.target, batch['image'],
                                          batch['labels'], batch['mask']))
                    # All results are a replicated array shaped as follows:
                    # (local_devices, per_device_batch_size, elem_shape...)
                    # with each local device's entry being identical as they got psum'd.
                    # So let's just take the first one to the host as numpy.
                    ncorrect += np.sum(np.array(batch_ncorrect[0]))
                    loss += np.sum(np.array(batch_losses[0]))
                    nseen += np.sum(np.array(batch_n[0]))
                    if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
                        # Here we parse batch_metric_args to compute uncertainty metrics.
                        # (e.g., ECE or Calibration AUC).
                        logits, labels, _, masks = batch_metric_args
                        masks = np.array(masks[0], dtype=bool)
                        logits = np.array(logits[0])
                        probs = jax.nn.softmax(logits)
                        # From one-hot to integer labels, as required by ECE.
                        int_labels = np.argmax(np.array(labels[0]), axis=-1)
                        int_preds = np.argmax(logits, axis=-1)
                        confidence = np.max(probs, axis=-1)
                        for p, c, l, d, m, label in zip(
                                probs, confidence, int_labels, int_preds,
                                masks, labels[0]):
                            ece.add_batch(p[m, :], label=l[m])
                            calib_auc.add_batch(d[m],
                                                label=l[m],
                                                confidence=c[m])
                            # TODO(jereliu): Extend to support soft multi-class probabilities.
                            oc_auc_0_5.add_batch(d[m],
                                                 label=l[m],
                                                 custom_binning_score=c[m])
                            oc_auc_1.add_batch(d[m],
                                               label=l[m],
                                               custom_binning_score=c[m])
                            oc_auc_2.add_batch(d[m],
                                               label=l[m],
                                               custom_binning_score=c[m])
                            oc_auc_5.add_batch(d[m],
                                               label=l[m],
                                               custom_binning_score=c[m])

                            if val_name == 'cifar_10h' or val_name == 'imagenet_real':
                                batch_label_diversity, batch_sample_diversity, batch_ged = data_uncertainty_utils.generalized_energy_distance(
                                    label[m], p[m, :], config.num_classes)
                                label_diversity.update_state(
                                    batch_label_diversity)
                                sample_diversity.update_state(
                                    batch_sample_diversity)
                                ged.update_state(batch_ged)

                val_loss[
                    val_name] = loss / nseen  # Keep for reproducibility tests.
                val_measurements = {
                    f'{val_name}_prec@1': ncorrect / nseen,
                    f'{val_name}_loss': val_loss[val_name],
                }
                if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
                    val_measurements[f'{val_name}_ece'] = ece.result()['ece']
                    val_measurements[
                        f'{val_name}_calib_auc'] = calib_auc.result(
                        )['calibration_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_0.5%'] = oc_auc_0_5.result(
                        )['collaborative_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_1%'] = oc_auc_1.result(
                        )['collaborative_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_2%'] = oc_auc_2.result(
                        )['collaborative_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_5%'] = oc_auc_5.result(
                        )['collaborative_auc']
                writer.write_scalars(step, val_measurements)

                if val_name == 'cifar_10h' or val_name == 'imagenet_real':
                    cifar_10h_measurements = {
                        f'{val_name}_label_diversity':
                        label_diversity.result(),
                        f'{val_name}_sample_diversity':
                        sample_diversity.result(),
                        f'{val_name}_ged': ged.result(),
                    }
                    writer.write_scalars(step, cifar_10h_measurements)

            # OOD eval
            # Entries in the ood_ds dict include:
            # (ind_dataset, ood_dataset1, ood_dataset2, ...).
            # OOD metrics are computed using ind_dataset paired with each of the
            # ood_dataset. When Mahalanobis distance method is applied, train_ind_ds
            # is also included in the ood_ds.
            if ood_ds and config.ood_methods:
                ood_measurements = ood_utils.eval_ood_metrics(
                    ood_ds,
                    ood_ds_names,
                    config.ood_methods,
                    evaluation_fn,
                    opt_repl.target,
                    n_prefetch=config.get('prefetch_to_device', 1))
                writer.write_scalars(step, ood_measurements)
            chrono.resume()

        if 'fewshot' in config and fewshotter is not None:
            # Compute few-shot on-the-fly evaluation.
            if train_utils.itstime(step, config.fewshot.log_steps,
                                   total_steps):
                chrono.pause()
                write_note(f'Few-shot evaluation...\n{chrono.note}')
                # Keep `results` to return for reproducibility tests.
                fewshot_results, best_l2 = fewshotter.run_all(
                    opt_repl.target, config.fewshot.datasets)

                # TODO(dusenberrymw): Remove this once fewshot.py is updated.
                def make_writer_measure_fn(step):
                    def writer_measure(name, value):
                        writer.write_scalars(step, {name: value})

                    return writer_measure

                fewshotter.walk_results(make_writer_measure_fn(step),
                                        fewshot_results, best_l2)
                chrono.resume()

        # End of step.
        if config.get('testing_failure_step'):
            # Break early to simulate infra failures in test cases.
            if config.testing_failure_step == step:
                break

    write_note(f'Done!\n{chrono.note}')
    pool.close()
    pool.join()
    writer.close()

    # Return final training loss, validation loss, and fewshot results for
    # reproducibility test cases.
    return train_loss, val_loss, fewshot_results
Example #3
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs training interleaved with evaluation."""

  # Setup input pipeline
  dataset_info = input_pipeline.get_dataset_info(config.dataset, 'train')

  ds_train, ds_test = input_pipeline.get_datasets(config)
  batch = next(iter(ds_train))
  logging.info(ds_train)
  logging.info(ds_test)

  # Build VisionTransformer architecture
  model_cls = {'ViT': models.VisionTransformer,
               'Mixer': models.MlpMixer}[config.get('model_type', 'ViT')]
  model = model_cls(num_classes=dataset_info['num_classes'], **config.model)

  def init_model():
    return model.init(
        jax.random.PRNGKey(0),
        # Discard the "num_local_devices" dimension for initialization.
        jnp.ones(batch['image'].shape[1:], batch['image'].dtype.name),
        train=False)

  # Use JIT to make sure params reside in CPU memory.
  variables = jax.jit(init_model, backend='cpu')()

  model_or_filename = config.get('model_or_filename')
  if model_or_filename:
    # Load a model from the repo published with "How to train your ViT? Data,
    # Augmentation, and Regularization in Vision Transformers":
    # https://arxiv.org/abs/2106.10270
    if '-' in model_or_filename:
      filename = model_or_filename
    else:
      # Select best checkpoint from i21k pretraining by final upstream
      # validation accuracy.
      df = checkpoint.get_augreg_df(directory=config.pretrained_dir)
      sel = df.filename.apply(
          lambda filename: filename.split('-')[0] == model_or_filename)
      best = df.loc[sel].query('ds=="i21k"').sort_values('final_val').iloc[-1]
      filename = best.filename
      logging.info('Selected filename="%s" for "%s" with final_val=%.3f',
                   filename, model_or_filename, best.final_val)
  else:
    # ViT / Mixer papers
    filename = config.model.model_name

  pretrained_path = os.path.join(config.pretrained_dir, f'{filename}.npz')
  if not tf.io.gfile.exists(pretrained_path):
    raise ValueError(
        f'Could not find "{pretrained_path}" - you can download models from '
        '"gs://vit_models/imagenet21k" or directly set '
        '--config.pretrained_dir="gs://vit_models/imagenet21k".')
  params = checkpoint.load_pretrained(
      pretrained_path=pretrained_path,
      init_params=variables['params'],
      model_config=config.model)

  total_steps = config.total_steps
  lr_fn = utils.create_learning_rate_schedule(total_steps, config.base_lr,
                                              config.decay_type,
                                              config.warmup_steps)

  update_fn_repl = make_update_fn(
      apply_fn=model.apply, accum_steps=config.accum_steps, lr_fn=lr_fn)
  infer_fn_repl = jax.pmap(functools.partial(model.apply, train=False))
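  # `infer_fn_repl` runs the forward pass on all local devices in parallel;
  # its inputs are expected to carry a leading local-device axis, as the
  # prefetched batches here are assumed to.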

  # Create optimizer and replicate it over all TPUs/GPUs
  opt = momentum_clip.Optimizer(
      dtype=config.optim_dtype,
      grad_norm_clip=config.grad_norm_clip).create(params)

  initial_step = 1
  opt, initial_step = flax_checkpoints.restore_checkpoint(
      workdir, (opt, initial_step))
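  # If `workdir` contains no checkpoint, `restore_checkpoint` returns the
  # passed-in (opt, 1) unchanged and training starts from step 1.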
  logging.info('Will start/continue training at initial_step=%d', initial_step)

  opt_repl = flax.jax_utils.replicate(opt)

  # Delete references to the objects that are not needed anymore
  del opt
  del params

  # Replicate the update RNG so each local device carries its own copy.
  update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))

  # Setup metric writer & hooks.
  writer = metric_writers.create_default_writer(workdir, asynchronous=False)
  writer.write_hparams(config.to_dict())
  hooks = [
      periodic_actions.Profile(logdir=workdir),
      periodic_actions.ReportProgress(
          num_train_steps=total_steps, writer=writer),
  ]

  # Run training loop
  logging.info('Starting training loop; initial compile can take a while...')
  t0 = lt0 = time.time()
  lstep = initial_step
  for step, batch in zip(
      range(initial_step, total_steps + 1),
      input_pipeline.prefetch(ds_train, config.prefetch)):

    with jax.profiler.StepTraceAnnotation('train', step_num=step):
      opt_repl, loss_repl, update_rng_repl = update_fn_repl(
          opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)

    for hook in hooks:
      hook(step)

    if step == initial_step:
      logging.info('First step took %.1f seconds.', time.time() - t0)
      t0 = time.time()
      lt0, lstep = time.time(), step

    # Report training metrics
    if config.progress_every and step % config.progress_every == 0:
      img_sec_core_train = (config.batch * (step - lstep) /
                            (time.time() - lt0)) / jax.device_count()
      lt0, lstep = time.time(), step
      writer.write_scalars(
          step,
          dict(
              train_loss=float(flax.jax_utils.unreplicate(loss_repl)),
              img_sec_core_train=img_sec_core_train))
      done = step / total_steps
      logging.info(f'Step: {step}/{total_steps} {100*done:.1f}%, '  # pylint: disable=logging-format-interpolation
                   f'img/sec/core: {img_sec_core_train:.1f}, '
                   f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h')

    # Run evaluation
    if ((config.eval_every and step % config.eval_every == 0) or
        (step == total_steps)):

      accuracies = []
      lt0 = time.time()
      for test_batch in input_pipeline.prefetch(ds_test, config.prefetch):
        logits = infer_fn_repl(
            dict(params=opt_repl.target), test_batch['image'])
        accuracies.append(
            (np.argmax(logits,
                       axis=-1) == np.argmax(test_batch['label'],
                                             axis=-1)).mean())
      accuracy_test = np.mean(accuracies)
      img_sec_core_test = (
          config.batch_eval * ds_test.cardinality().numpy() /
          (time.time() - lt0) / jax.device_count())
      lt0 = time.time()

      lr = float(lr_fn(step))
      logging.info(f'Step: {step} '  # pylint: disable=logging-format-interpolation
                   f'Learning rate: {lr:.7f}, '
                   f'Test accuracy: {accuracy_test:0.5f}, '
                   f'img/sec/core: {img_sec_core_test:.1f}')
      writer.write_scalars(
          step,
          dict(
              accuracy_test=accuracy_test,
              lr=lr,
              img_sec_core_test=img_sec_core_test))

    # Store checkpoint.
    if ((config.checkpoint_every and step % config.checkpoint_every == 0) or
        step == total_steps):
      checkpoint_path = flax_checkpoints.save_checkpoint(
          workdir, (flax.jax_utils.unreplicate(opt_repl), step), step)
      logging.info('Stored checkpoint at step %d to "%s"', step,
                   checkpoint_path)

  return flax.jax_utils.unreplicate(opt_repl)
Example #4
def main(args):
  logdir = os.path.join(args.logdir, args.name)
  logger = logging.setup_logger(logdir)
  logger.info(args)

  logger.info(f'Available devices: {jax.devices()}')

  # Setup input pipeline
  dataset_info = input_pipeline.get_dataset_info(args.dataset, 'train')

  ds_train = input_pipeline.get_data(
      dataset=args.dataset,
      mode='train',
      repeats=None,
      mixup_alpha=args.mixup_alpha,
      batch_size=args.batch,
      shuffle_buffer=args.shuffle_buffer,
      tfds_data_dir=args.tfds_data_dir,
      tfds_manual_dir=args.tfds_manual_dir)
  batch = next(iter(ds_train))
  logger.info(ds_train)
  ds_test = input_pipeline.get_data(
      dataset=args.dataset,
      mode='test',
      repeats=1,
      batch_size=args.batch_eval,
      tfds_data_dir=args.tfds_data_dir,
      tfds_manual_dir=args.tfds_manual_dir)
  logger.info(ds_test)

  # Build VisionTransformer architecture
  model = models.KNOWN_MODELS[args.model]
  VisionTransformer = model.partial(num_classes=dataset_info['num_classes'])
  _, params = VisionTransformer.init_by_shape(
      jax.random.PRNGKey(0),
      # Discard the "num_local_devices" dimension for initialization.
      [(batch['image'].shape[1:], batch['image'].dtype.name)])

  pretrained_path = os.path.join(args.vit_pretrained_dir, f'{args.model}.npz')
  params = checkpoint.load_pretrained(
      pretrained_path=pretrained_path,
      init_params=params,
      model_config=models.CONFIGS[args.model],
      logger=logger)

  # pmap replicates the models over all TPUs/GPUs
  vit_fn_repl = jax.pmap(VisionTransformer.call)
  update_fn_repl = make_update_fn(VisionTransformer.call, args.accum_steps)

  # Create optimizer and replicate it over all TPUs/GPUs
  opt = momentum_clip.Optimizer(
      dtype=args.optim_dtype, grad_norm_clip=args.grad_norm_clip).create(params)
  opt_repl = flax_utils.replicate(opt)

  # Delete references to the objects that are not needed anymore
  del opt
  del params

  def copyfiles(paths):
    """Small helper to copy files to args.copy_to using tf.io.gfile."""
    if not args.copy_to:
      return
    for path in paths:
      to_path = os.path.join(args.copy_to, args.name, os.path.basename(path))
      tf.io.gfile.makedirs(os.path.dirname(to_path))
      tf.io.gfile.copy(path, to_path, overwrite=True)
      logger.info(f'Copied {path} to {to_path}.')

  total_steps = args.total_steps or (
      input_pipeline.DATASET_PRESETS[args.dataset]['total_steps'])

  # Prepare the learning-rate and pre-fetch it to device to avoid delays.
  lr_fn = hyper.create_learning_rate_schedule(total_steps, args.base_lr,
                                              args.decay_type,
                                              args.warmup_steps)
  lr_iter = hyper.lr_prefetch_iter(lr_fn, 0, total_steps)
  update_rngs = jax.random.split(
      jax.random.PRNGKey(0), jax.local_device_count())
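  # One PRNG key per local device so that stochastic ops (e.g. dropout) draw
  # different random numbers on each replica of the pmapped update.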

  # Run training loop
  writer = metric_writers.create_default_writer(logdir, asynchronous=False)
  writer.write_hparams({k: v for k, v in vars(args).items() if v is not None})
  logger.info('Starting training loop; initial compile can take a while...')
  t0 = time.time()

  for step, batch, lr_repl in zip(
      range(1, total_steps + 1),
      input_pipeline.prefetch(ds_train, args.prefetch), lr_iter):

    opt_repl, loss_repl, update_rngs = update_fn_repl(
        opt_repl, lr_repl, batch, update_rngs)

    if step == 1:
      logger.info(f'First step took {time.time() - t0:.1f} seconds.')
      t0 = time.time()
    if args.progress_every and step % args.progress_every == 0:
      writer.write_scalars(step, dict(train_loss=float(loss_repl[0])))
      done = step / total_steps
      logger.info(f'Step: {step}/{total_steps} {100*done:.1f}%, '
                  f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h')
      copyfiles(glob.glob(f'{logdir}/*'))

    # Run eval step
    if ((args.eval_every and step % args.eval_every == 0) or
        (step == total_steps)):

      accuracy_test = np.mean([
          c for batch in input_pipeline.prefetch(ds_test, args.prefetch)
          for c in (
              np.argmax(vit_fn_repl(opt_repl.target, batch['image']),
                        axis=2) == np.argmax(batch['label'], axis=2)).ravel()
      ])

      lr = float(lr_repl[0])
      logger.info(f'Step: {step} '
                  f'Learning rate: {lr:.7f}, '
                  f'Test accuracy: {accuracy_test:0.5f}')
      writer.write_scalars(step, dict(accuracy_test=accuracy_test, lr=lr))
      copyfiles(glob.glob(f'{logdir}/*'))

  if args.output:
    checkpoint.save(flax_utils.unreplicate(opt_repl.target), args.output)
    logger.info(f'Stored fine tuned checkpoint to {args.output}')
    copyfiles([args.output])
Example #5
def train(config, model_def, device_batch_size, eval_ds, num_steps,
          steps_per_epoch, steps_per_eval, train_ds, image_size, data_source,
          workdir):
  """Train model."""

  make_lr_fn = schedulers.get_make_lr_fn(config)
  make_temp_fn = schedulers.get_make_temp_fn(config)
  make_step_size_fn = schedulers.get_make_step_size_fn(config)
  if jax.host_count() > 1:
    raise ValueError('CIFAR10 example should not be run on '
                     'more than 1 host due to preconditioner updating.')

  initial_step = 0  # TODO(basv): load from checkpoint.
  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)

  # Write config to the summary files. This makes the hyperparameters available
  # in TensorBoard and makes comparison of runs in TensorBoard easier.
  # with writer.summary_writer.as_default():
  writer.write_hparams(dict(config))

  rng = random.PRNGKey(config.seed)
  rng, opt_rng, init_key, sampler_rng = jax.random.split(rng, 4)

  base_learning_rate = config.learning_rate

  # Create the model.
  model, state = create_model(rng, device_batch_size, image_size, model_def)
  parameter_overview.log_parameter_overview(model.params)
  state = jax_utils.replicate(state)
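  # Replicate the mutable model state (e.g. batch statistics) so every local
  # device holds its own copy for the pmapped train/eval steps.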

  train_size = data_source.TRAIN_IMAGES

  with flax.deprecated.nn.stochastic(init_key):
    optimizer = create_optimizer(config, model, base_learning_rate, train_size,
                                 sampler_rng)
  del model  # Don't keep a copy of the initial model.

  # Learning rate schedule
  learning_rate_fn = make_lr_fn(base_learning_rate, steps_per_epoch)
  temperature_fn = make_temp_fn(config.base_temp, steps_per_epoch)
  step_size_fn = make_step_size_fn(steps_per_epoch)

  p_eval_step, _, p_train_step, p_update_grad_vars = make_step_functions(
      config, config.l2_reg, learning_rate_fn, train_size, temperature_fn,
      step_size_fn)

  # Create dataset batch iterators.
  train_iter = iter(train_ds)
  eval_iter = iter(eval_ds)

  # Gather metrics.
  train_metrics = []
  epoch = 0

  # Ensemble.
  ensemble = []
  ensemble_logits = []
  ensemble_labels = []
  ensemble_probs = []

  def ensemble_add_step(step):
    if config.lr_schedule == 'cosine':
      # Add if learning rate jumps up again in the next step.
      increase = step_size_fn(step) < step_size_fn(step + 1) - 1e-8
      _, temp_end = ast.literal_eval(config.temp_ramp)
      past_burn_in = step >= steps_per_epoch * temp_end
      return increase and past_burn_in

    elif config.lr_schedule == 'constant':
      if (step + 1) % steps_per_epoch == 0:
        return True
    return False

  logging.info('Starting training loop at step %d.', initial_step)

  for step in range(initial_step, num_steps):
    if config.optimizer in ['sym_euler'] and step % steps_per_epoch == 0:
      optimizer, rng = update_preconditioner(config, optimizer,
                                             p_update_grad_vars, rng, state,
                                             train_iter)
    # Generate a PRNG key that will be rolled into the batch
    step_key = jax.random.fold_in(rng, step)
    opt_step_rng = jax.random.fold_in(opt_rng, step)

    # Load and shard the TF batch
    batch = next(train_iter)
    batch = input_pipeline.load_and_shard_tf_batch(config, batch)
    if not config.debug_run:
      # Shard the step PRNG key
      # Don't shard the optimizer rng, as it should be equal among all machines.
      sharded_keys = common_utils.shard_prng_key(step_key)
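      # shard_prng_key splits step_key into one sub-key per local device so
      # each replica gets independent randomness.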
    else:
      sharded_keys = step_key

    # Train step
    optimizer, state, metrics = p_train_step(optimizer, state, batch,
                                             sharded_keys, opt_step_rng)
    train_metrics.append(metrics)

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
    if step == initial_step:
      initial_train_metrics = get_metrics(config, train_metrics)
      train_summary = jax.tree_map(lambda x: x.mean(), initial_train_metrics)
      train_summary = {'train_' + k: v for k, v in train_summary.items()}
      logging.log(logging.INFO, 'initial metrics = %s',
                  str(train_summary.items()))

    if (step + 1) % steps_per_epoch == 0:
      # We've finished an epoch
      # Save model params/state.

      train_metrics = get_metrics(config, train_metrics)
      # Get training epoch summary for logging
      train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)

      train_summary = {'train_' + k: v for k, v in train_summary.items()}

      writer.write_scalars(epoch, train_summary)
      # Reset train metrics
      train_metrics = []

      # Evaluation
      if config.do_eval:
        eval_metrics = []
        eval_logits = []
        eval_labels = []
        for _ in range(steps_per_eval):
          eval_batch = next(eval_iter)
          # Load and shard the TF batch
          eval_batch = input_pipeline.load_and_shard_tf_batch(
              config, eval_batch)
          # Step
          logits, labels, metrics = p_eval_step(optimizer.target, state,
                                                eval_batch)
          eval_metrics.append(metrics)
          eval_logits.append(logits)
          eval_labels.append(labels)
        eval_metrics = get_metrics(config, eval_metrics)
        # Get eval epoch summary for logging
        eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
        eval_summary = {'eval_' + k: v for k, v in eval_summary.items()}
        writer.write_scalars(epoch, eval_summary)

      if config.algorithm == 'sgmcmc' and ensemble_add_step(step):
        ensemble.append((serialization.to_state_dict(optimizer.target), state))

      if config.algorithm == 'sgmcmc' and ensemble_add_step(
          step) and len(ensemble) >= 1:
        # Gather predictions for this ensemble sample.
        eval_logits = jnp.concatenate(eval_logits, axis=0)
        eval_probs = jax.nn.softmax(eval_logits, axis=-1)
        eval_labels = jnp.concatenate(eval_labels, axis=0)
        # Ensure that labels are consistent between predict runs.
        if ensemble_labels:
          assert jnp.allclose(
              eval_labels,
              ensemble_labels[0]), 'Labels unordered between eval runs.'

        ensemble_logits.append(eval_logits)
        ensemble_probs.append(eval_probs)
        ensemble_labels.append(eval_labels)

        # Compute ensemble predictions over last config.ensemble_size samples.
        ensemble_last_probs = jnp.mean(
            jnp.array(ensemble_probs[-config.ensemble_size:]), axis=0)
        ensemble_metrics = train_functions.compute_metrics_probs(
            ensemble_last_probs, ensemble_labels[0])
        ensemble_summary = jax.tree_map(lambda x: x.mean(), ensemble_metrics)
        ensemble_summary = {'ens_' + k: v for k, v in ensemble_summary.items()}
        ensemble_summary['ensemble_size'] = min(config.ensemble_size,
                                                len(ensemble_probs))
        writer.write_scalars(epoch, ensemble_summary)

      epoch += 1

  return ensemble, optimizer
Example #6
def main(argv):
  del argv  # unused arg

  config = FLAGS.config

  # Unpack total and warmup steps
  # TODO(nband): revert this to separate arguments.
  total_steps = config.total_and_warmup_steps[0]
  warmup_steps = config.total_and_warmup_steps[1]
  del config.total_and_warmup_steps
  config.total_steps = total_steps
  config.lr.warmup_steps = warmup_steps

  # Wandb and Checkpointing Setup
  output_dir = FLAGS.output_dir
  wandb_run, output_dir = vit_utils.maybe_setup_wandb(config)
  tf.io.gfile.makedirs(output_dir)
  logging.info('Saving checkpoints at %s', output_dir)

  # Dataset Split Flags
  dist_shift = config.distribution_shift
  print(f'Distribution Shift: {dist_shift}.')
  dataset_names, split_names = vit_utils.get_dataset_and_split_names(dist_shift)

  # LR / Optimization Flags
  batch_size = config.batch_size
  grad_clip_norm = config.grad_clip_norm
  weight_decay = config.weight_decay
  print('Standard wandb hyperparameters:')
  print({
      'batch_size': batch_size,
      'grad_clip_norm': grad_clip_norm,
      'weight_decay': weight_decay,
      'total_steps': config.total_steps,
      'lr': config.lr
  })
  print('SNGP Params:', config.gp_layer)

  # Reweighting loss for class imbalance
  # class_reweight_mode = config.class_reweight_mode
  # if class_reweight_mode == 'constant':
  #   class_weights = utils.get_diabetic_retinopathy_class_balance_weights()
  # else:
  #   class_weights = None

  # Shows the number of available devices.
  # In a CPU/GPU runtime this will be a single device.
  # In a TPU runtime this will be 8 cores.
  print('Number of Jax local devices:', jax.local_devices())

  # TODO(nband): fix sigmoid loss issues.
  assert config.get('loss', None) == 'softmax_xent'

  seed = config.seed
  rng = jax.random.PRNGKey(seed)
  tf.random.set_seed(seed)

  if config.get('data_dir'):
    logging.info('data_dir=%s', config.data_dir)
  logging.info('Output dir: %s', output_dir)

  save_checkpoint_path = None
  if config.get('checkpoint_steps'):
    tf.io.gfile.makedirs(output_dir)
    save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

  # Create an asynchronous multi-metric writer.
  writer = metric_writers.create_default_writer(
      output_dir, just_logging=jax.process_index() > 0)

  # The pool is used to perform misc operations such as logging in an async way.
  pool = multiprocessing.pool.ThreadPool()

  def write_note(note):
    if jax.process_index() == 0:
      logging.info('NOTE: %s', note)

  write_note('Initializing...')

  # Verify settings to make sure no checkpoints are accidentally missed.
  if config.get('keep_checkpoint_steps'):
    assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
    assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
        f'`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should be '
        f'divisible by `checkpoint_steps` ({config.checkpoint_steps}).')
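    # For example (hypothetical values): keep_checkpoint_steps = 5000 with
    # checkpoint_steps = 1000 passes this check, while 2500 with 1000 would not,
    # since kept checkpoints must land exactly on regular checkpointing steps.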

  batch_size_eval = config.get('batch_size_eval', batch_size)
  if (batch_size % jax.device_count() != 0 or
      batch_size_eval % jax.device_count() != 0):
    raise ValueError(f'Batch sizes ({batch_size} and {batch_size_eval}) must '
                     f'be divisible by device number ({jax.device_count()})')

  local_batch_size = batch_size // jax.process_count()
  local_batch_size_eval = batch_size_eval // jax.process_count()
  logging.info(
      'Global batch size %d on %d hosts results in %d local batch size. '
      'With %d dev per host (%d dev total), that is a %d per-device batch size.',
      batch_size,
      jax.process_count(), local_batch_size, jax.local_device_count(),
      jax.device_count(), local_batch_size // jax.local_device_count())
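  # For example (hypothetical numbers): batch_size = 512 on 4 hosts with 8
  # devices each gives local_batch_size = 512 // 4 = 128 and a per-device batch
  # of 128 // 8 = 16.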

  write_note('Initializing preprocessing function...')
  # Same preprocessing function for training and evaluation
  preproc_fn = preprocess_spec.parse(
      spec=config.pp_train, available_ops=preprocess_utils.all_ops())

  write_note('Initializing train dataset...')
  rng, train_ds_rng = jax.random.split(rng)
  train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index())
  train_base_dataset = ub.datasets.get(
      dataset_names['in_domain_dataset'],
      split=split_names['train_split'],
      data_dir=config.get('data_dir'))
  train_dataset_builder = train_base_dataset._dataset_builder  # pylint: disable=protected-access
  train_ds = input_utils.get_data(
      dataset=train_dataset_builder,
      split=split_names['train_split'],
      rng=train_ds_rng,
      process_batch_size=local_batch_size,
      preprocess_fn=preproc_fn,
      shuffle_buffer_size=config.shuffle_buffer_size,
      prefetch_size=config.get('prefetch_to_host', 2),
      data_dir=config.get('data_dir'))
  logging.info('image_size = %s', train_ds.element_spec['image'].shape[1:])

  # Start prefetching already.
  train_iter = input_utils.start_input_pipeline(
      train_ds, config.get('prefetch_to_device', 1))

  write_note('Initializing val dataset(s)...')

  # Load in-domain and OOD validation and/or test datasets.
  # Please specify the desired shift (Country Shift or Severity Shift)
  # in the config.
  eval_iter_splits = vit_utils.init_evaluation_datasets(
      use_validation=config.use_validation,
      use_test=config.use_test,
      dataset_names=dataset_names,
      split_names=split_names,
      config=config,
      preproc_fn=preproc_fn,
      batch_size_eval=batch_size_eval,
      local_batch_size_eval=local_batch_size_eval)

  ntrain_img = input_utils.get_num_examples(
      train_dataset_builder,
      split=split_names['train_split'],
      process_batch_size=local_batch_size,
      data_dir=config.get('data_dir'))
  steps_per_epoch = ntrain_img // batch_size

  if config.get('num_epochs'):
    total_steps = int(config.num_epochs * steps_per_epoch)
    assert not config.get('total_steps'), 'Set either num_epochs or total_steps'
  else:
    total_steps = config.total_steps

  logging.info('Total train data points: %d', ntrain_img)
  logging.info(
      'Running for %d steps, that means %f epochs and %d steps per epoch',
      total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)

  write_note('Initializing model...')
  logging.info('config.model = %s', config.get('model'))

  # Specify Gaussian process layer configs.
  gp_config = config.get('gp_layer', {})
  model_dict = vit_utils.initialize_model('sngp', config)
  model, use_gp_layer = model_dict['model'], model_dict['use_gp_layer']

  # We want all parameters to be created in host RAM, not on any device; they
  # will be sent to the devices later as needed. Otherwise we have already run
  # into situations where they end up allocated twice.
  @functools.partial(jax.jit, backend='cpu')
  def init(rng):
    image_size = tuple(train_ds.element_spec['image'].shape[2:])
    logging.info('image_size = %s', image_size)
    dummy_input = jnp.zeros((local_batch_size,) + image_size, jnp.float32)
    variables = model.init(rng, dummy_input, train=False)
    # Split model parameters into trainable and untrainable collections.
    states, params = variables.pop('params')
    del variables

    # Set bias in the head to a low value, such that loss is small initially.
    params = flax.core.unfreeze(params)
    if use_gp_layer:
      # Modify the head parameter in the GP head.
      params['head']['output_layer']['bias'] = jnp.full_like(
          params['head']['output_layer']['bias'],
          config.get('init_head_bias', 0))
    else:
      params['head']['bias'] = jnp.full_like(
          params['head']['bias'], config.get('init_head_bias', 0))

    return params, states

  rng, rng_init = jax.random.split(rng)
  params_cpu, states_cpu = init(rng_init)

  if jax.process_index() == 0:
    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
    parameter_overview.log_parameter_overview(params_cpu)
    writer.write_scalars(step=0, scalars={'num_params': num_params})

  @functools.partial(jax.pmap, axis_name='batch')
  def evaluation_fn(params, states, images, labels):
    variable_dict = {'params': flax.core.freeze(params), **states}
    logits, out = model.apply(
        variable_dict,
        images,
        train=False,
        mean_field_factor=gp_config.get('mean_field_factor', -1.))
    losses = getattr(train_utils, config.get('loss', 'softmax_xent'))(
        logits=logits, labels=labels, reduction=False)
    loss = jax.lax.psum(losses, axis_name='batch')
    top1_idx = jnp.argmax(logits, axis=1)

    # Extracts the label at the highest logit index for each image.
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]

    ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
    n = batch_size_eval
    metric_args = jax.lax.all_gather([
        logits, labels, out['pre_logits']], axis_name='batch')
    return ncorrect, loss, n, metric_args

  # Load the optimizer from flax.
  opt_name = config.get('optim_name')
  write_note(f'Initializing {opt_name} optimizer...')
  opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

  # We jit this, such that the arrays that are created are created on the same
  # device as the input is, in this case the CPU. Else they'd be on device[0].
  opt_cpu = jax.jit(opt_def.create)(params_cpu)

  weight_decay_rules = config.get('weight_decay', []) or []
  rescale_value = config.lr.base if config.get('weight_decay_decouple') else 1.
  weight_decay_fn = train_utils.get_weight_decay_fn(
      weight_decay_rules=weight_decay_rules, rescale_value=rescale_value)

  @functools.partial(jax.pmap, axis_name='batch', donate_argnums=(0,))
  def update_fn(opt, states, lr, reset_covmat, images, labels, rng):
    """Update step."""
    measurements = {}

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index('batch'))

    def loss_fn(params, states, images, labels):
      # Specify mutable collection to update untrainable GP parameters.
      variable_dict = {'params': flax.core.freeze(params), **states}
      model_results, updated_states = model.apply(
          variable_dict,
          images,
          train=True,
          rngs={'dropout': rng_model_local},
          mutable=list(states.keys()),
          mean_field_factor=gp_config.get('mean_field_factor', -1.))

      logits, _ = model_results
      loss = getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
          logits=logits, labels=labels)
      return loss, updated_states

    # Performs an exact covariance update (i.e., resets the precision matrix at
    # the beginning of each new epoch) if covmat_momentum is a null value.
    if use_gp_layer and gp_config.get('covmat_momentum', -1.) < 0:
      # Resets the precision matrix to identity * ridge_penalty at the beginning
      # of a new epoch. This should be done before accumulating gradients.
      ridge_penalty = gp_config.get('ridge_penalty', 1.)
      prec_mat_old = states['laplace_covariance']['head']['covmat_layer'][
          'precision_matrix']
      prec_mat_new = (
          (1. - reset_covmat) * prec_mat_old +
          reset_covmat * jnp.eye(prec_mat_old.shape[0]) * ridge_penalty)

      states = flax.core.unfreeze(states)
      states['laplace_covariance']['head']['covmat_layer'][
          'precision_matrix'] = prec_mat_new
      states = flax.core.freeze(states)

    # Implementation considerations compared and summarized at
    # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
    (l, s), g = vit_utils.accumulate_gradient_with_states(
        jax.value_and_grad(loss_fn, has_aux=True), opt.target, states, images,
        labels, config.get('grad_accum_steps'))
    l, g = jax.lax.pmean((l, g), axis_name='batch')

    # Log the gradient norm only if we need to compute it anyway (for clipping)
    # or if we don't use grad_accum_steps, as the two interact badly.
    if config.get('grad_accum_steps', 1) == 1 or grad_clip_norm is not None:
      grads, _ = jax.tree_flatten(g)
      l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
      measurements['l2_grads'] = l2_g

    # Optionally resize the global gradient to a maximum norm. We found this
    # useful in some cases across optimizers, hence it's in the main loop.
    if grad_clip_norm is not None:
      g_factor = jnp.minimum(1.0, grad_clip_norm / l2_g)
      g = jax.tree_map(lambda p: g_factor * p, g)
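      # For example (hypothetical numbers): grad_clip_norm = 1.0 with l2_g = 4.0
      # gives g_factor = 0.25, scaling the global gradient down to L2 norm 1.0;
      # when l2_g <= grad_clip_norm, g_factor is 1.0 and the gradient is unchanged.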
    opt = opt.apply_gradient(g, learning_rate=lr)
    opt = opt.replace(target=weight_decay_fn(opt.target, lr))

    params, _ = jax.tree_flatten(opt.target)
    measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))
    measurements['reset_covmat'] = reset_covmat

    return opt, s, l, rng, measurements

  # Set config checkpoint resume path, if provided in args.
  if config.resume_checkpoint_path is not None:
    config.resume = config.resume_checkpoint_path

  default_reinit_params = ('head/output_layer/kernel', 'head/output_layer/bias',
                           'head/kernel', 'head/bias')
  rng, train_loop_rngs = jax.random.split(rng)
  checkpoint_data = checkpoint_utils.maybe_load_checkpoint(
      train_loop_rngs=train_loop_rngs,
      save_checkpoint_path=save_checkpoint_path,
      init_optimizer=opt_cpu,
      init_params=params_cpu,
      init_fixed_model_states=states_cpu,
      default_reinit_params=default_reinit_params,
      config=config)
  train_loop_rngs = checkpoint_data.train_loop_rngs
  opt_cpu = checkpoint_data.optimizer
  states_cpu = checkpoint_data.fixed_model_states
  accumulated_train_time = checkpoint_data.accumulated_train_time

  write_note('Adapting the checkpoint model...')
  adapted_params = checkpoint_utils.adapt_upstream_architecture(
      init_params=params_cpu,
      loaded_params=opt_cpu.target)
  opt_cpu = opt_cpu.replace(target=adapted_params)

  write_note('Kicking off misc stuff...')
  first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
  if first_step == 0 and jax.process_index() == 0:
    writer.write_hparams(dict(config))
  chrono = train_utils.Chrono(first_step, total_steps, batch_size,
                              accumulated_train_time)
  # Note: switch to ProfileAllHosts() if you need to profile all hosts.
  # (Xprof data become much larger and take longer to load for analysis)
  profiler = periodic_actions.Profile(
      # Create profile after every restart to analyze pre-emption related
      # problems and assure we get similar performance in every run.
      logdir=output_dir, first_profile=first_step + 10)

  # Prepare the learning-rate and pre-fetch it to device to avoid delays.
  lr_fn = train_utils.create_learning_rate_schedule(total_steps,
                                                    **config.get('lr', {}))

  # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be
  # necessary for TPUs.
  lr_iter = train_utils.prefetch_scalar(
      map(lr_fn, range(total_steps)), config.get('prefetch_to_device', 1))

  # Prepare the precision matrix resetting schedule, and pre-fetch it to device.
  reset_covmat_fn = lambda step: float(step % steps_per_epoch == 0)
  reset_covmat_iter = train_utils.prefetch_scalar(
      map(reset_covmat_fn, range(first_step, total_steps)),
      nprefetch=config.get('prefetch_to_device', 1))
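  # reset_covmat_fn emits 1.0 on the first step of each epoch and 0.0 otherwise,
  # so (when covmat_momentum is a null value) the GP precision matrix is reset
  # to ridge_penalty * identity once per epoch inside update_fn above.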

  write_note(f'Replicating...\n{chrono.note}')
  opt_repl = flax.jax_utils.replicate(opt_cpu)
  states_repl = flax.jax_utils.replicate(states_cpu)

  checkpoint_writer = None

  # Note: we return the train loss, val loss, and fewshot best l2s for use in
  # reproducibility unit tests.
  # train_loss = -jnp.inf
  # val_loss = -jnp.inf
  # results = {'dummy': {(0, 1): -jnp.inf}}

  write_note(f'First step compilations...\n{chrono.note}')
  logging.info('first_step = %s', first_step)
  # Advance the iterators if we are restarting from an earlier checkpoint.
  # TODO(dusenberrymw): Look into checkpointing dataset state instead.

  # Make sure log_eval_steps is the same as steps_per_epoch, because the
  # precision matrix needs to be fully updated (at the end of each epoch) when
  # eval takes place.
  log_eval_steps = steps_per_epoch
  if first_step > 0:
    write_note('Advancing iterators after resuming from a checkpoint...')
    lr_iter = itertools.islice(lr_iter, first_step, None)
    train_iter = itertools.islice(train_iter, first_step, None)

  # Using a python integer for step here, because opt.state.step is allocated
  # on TPU during replication.
  for step, train_batch, lr_repl, reset_covmat_repl in zip(
      range(first_step + 1, total_steps + 1), train_iter, lr_iter,
      reset_covmat_iter):

    with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1):
      # TODO(jereliu): Expand to allow precision matrix resetting.
      (opt_repl, states_repl, loss_value, train_loop_rngs,
       extra_measurements) = update_fn(
           opt_repl,
           states_repl,
           lr_repl,
           reset_covmat_repl,
           train_batch['image'],
           train_batch['labels'],
           rng=train_loop_rngs)

    if jax.process_index() == 0:
      profiler(step)

    # Checkpoint saving
    if train_utils.itstime(
        step, config.get('checkpoint_steps'), total_steps, process=0):
      write_note('Checkpointing...')
      chrono.pause()
      train_utils.checkpointing_timeout(checkpoint_writer,
                                        config.get('checkpoint_timeout', 1))
      accumulated_train_time = chrono.accum_train_time
      # We need to transfer the weights over now or else we risk keeping them
      # alive while they'll be updated in a future step, creating hard to debug
      # memory errors (see b/160593526). Also, takes device 0's params only.
      # For GP layer, we will also do the same for untrainable parameters
      # (`states`). This is ok since `random features` are frozen throughout
      # pre-training, and `precision matrix` is a finetuning-specific parameter
      # that will be re-learned in the finetuning task.
      opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
      states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)

      # Check whether we want to keep a copy of the current checkpoint.
      copy_step = None
      if train_utils.itstime(step, config.get('keep_checkpoint_steps'),
                             total_steps):
        write_note('Keeping a checkpoint copy...')
        copy_step = step

      # Checkpoint should be a nested dictionary or FLAX dataclasses from
      # `flax.struct`. Both can be present in a checkpoint.
      checkpoint_data = checkpoint_utils.CheckpointData(
          optimizer=opt_cpu,
          fixed_model_states=states_cpu,
          train_loop_rngs=train_loop_rngs,
          accumulated_train_time=accumulated_train_time)
      checkpoint_writer = pool.apply_async(
          checkpoint_utils.checkpoint_trained_model,
          (checkpoint_data, save_checkpoint_path, copy_step))
      chrono.resume()

    # Report training progress
    if train_utils.itstime(
        step, config.log_training_steps, total_steps, process=0):
      write_note('Reporting training progress...')
      train_loss = loss_value[0]  # Keep to return for reproducibility tests.
      timing_measurements, note = chrono.tick(step)
      write_note(note)
      train_measurements = {}
      train_measurements.update({
          'learning_rate': lr_repl[0],
          'training_loss': train_loss,
      })
      train_measurements.update(flax.jax_utils.unreplicate(extra_measurements))
      train_measurements.update(timing_measurements)
      writer.write_scalars(step, train_measurements)

    # Report validation performance
    if train_utils.itstime(step, log_eval_steps, total_steps):
      write_note('Evaluating on the validation set...')
      chrono.pause()

      all_eval_results = {}

      for eval_name, (eval_iter, eval_steps) in eval_iter_splits.items():
        start_time = time.time()

        # Runs evaluation loop.
        results_arrs = {
            'y_true': [],
            'y_pred': [],
            'y_pred_entropy': []
        }

        for _, batch in zip(range(eval_steps), eval_iter):
          batch_ncorrect, batch_losses, batch_n, batch_metric_args = (  # pylint: disable=unused-variable
              evaluation_fn(
                  opt_repl.target, states_repl, batch['image'],
                  batch['labels']))

          # All results are a replicated array shaped as follows:
          # (local_devices, per_device_batch_size, elem_shape...)
          # with each local device's entry being identical as they got psum'd.
          # So let's just take the first one to the host as numpy.

          # Here we parse batch_metric_args to compute uncertainty metrics.
          logits, labels, _ = batch_metric_args
          logits = np.array(logits[0])
          probs = jax.nn.softmax(logits)

          # From one-hot to integer labels.
          int_labels = np.argmax(np.array(labels[0]), axis=-1)

          probs = np.reshape(probs, (probs.shape[0] * probs.shape[1], -1))
          int_labels = int_labels.flatten()
          y_pred = probs[:, 1]
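          # Assuming a binary classification task here: probs[:, 1] is taken as
          # the predicted probability of the positive class and used as y_pred.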
          results_arrs['y_true'].append(int_labels)
          results_arrs['y_pred'].append(y_pred)

          # Entropy is computed at the per-epoch level (see below).
          results_arrs['y_pred_entropy'].append(probs)

        results_arrs['y_true'] = np.concatenate(results_arrs['y_true'],
                                                axis=0)
        results_arrs['y_pred'] = np.concatenate(
            results_arrs['y_pred'], axis=0).astype('float64')
        results_arrs['y_pred_entropy'] = vit_utils.entropy(
            np.concatenate(results_arrs['y_pred_entropy'], axis=0), axis=-1)

        time_elapsed = time.time() - start_time
        results_arrs['total_ms_elapsed'] = time_elapsed * 1e3
        results_arrs['dataset_size'] = eval_steps * batch_size_eval

        all_eval_results[eval_name] = results_arrs

      per_pred_results, metrics_results = vit_utils.evaluate_vit_predictions(  # pylint: disable=unused-variable
          dataset_split_to_containers=all_eval_results,
          is_deterministic=True,
          num_bins=15,
          return_per_pred_results=True
      )

      # `metrics_results` is a dict of {str: jnp.ndarray} dicts, one for each
      # dataset. Flatten this dict so we can pass to the writer and remove empty
      # entries.
      flattened_metric_results = {}
      for dic in metrics_results.values():
        for key, value in dic.items():
          if value is not None:
            flattened_metric_results[key] = value
      writer.write_scalars(step, flattened_metric_results)

      # Optionally log to wandb
      if config.use_wandb:
        wandb.log(metrics_results, step=step)

      # Save per-prediction metrics
      results_storage_utils.save_per_prediction_results(
          output_dir, step, per_pred_results, verbose=False)

      chrono.resume()

    # End of step.
    if config.get('testing_failure_step'):
      # Break early to simulate infra failures in test cases.
      if config.testing_failure_step == step:
        break

  write_note(f'Done!\n{chrono.note}')
  pool.close()
  pool.join()
  writer.close()

  if wandb_run is not None:
    wandb_run.finish()
Example #7
def train(*,
          workdir,
          compute_phi,
          compute_psi,
          params,
          optimal_subspace,
          num_epochs,
          learning_rate,
          key,
          method,
          lissa_kappa,
          optimizer,
          covariance_batch_size,
          main_batch_size,
          weight_batch_size,
          d,
          num_tasks,
          compute_feature_norm_on_oracle_states,
          sample_states,
          eval_states,
          use_tabular_gradient=True):
    """Training function.

  For lissa, the total number of samples is
  2 x covariance_batch_size + main_batch_size + 2 x weight_batch_size.

  Args:
    workdir: Work directory, where we'll save logs.
    compute_phi: A function that takes params and states and returns
      a matrix of phis.
    compute_psi: A function that takes an array of states and an array
      of tasks and returns Psi[states, tasks].
    params: Parameters used as the first argument for compute_phi.
    optimal_subspace: Top-d left singular vectors of Psi.
    num_epochs: How many gradient steps to perform. (Not really epochs)
    learning_rate: The step size parameter for sgd.
    key: The jax prng key.
    method: 'naive', 'lissa', 'oracle', or 'explicit'.
    lissa_kappa: The parameter of the lissa method, if used.
    optimizer: Which optimizer to use. Only 'sgd' and 'adam' are supported.
    covariance_batch_size: the 'J' parameter. For the naive method, this is how
      many states we sample to construct the inverse. For the lissa method,
      ditto -- these are also "iterations".
    main_batch_size: How many states to update at once.
    weight_batch_size: How many states to construct the weight vector.
    d: The dimension of the representation.
    num_tasks: The total number of tasks.
    compute_feature_norm_on_oracle_states: If True, computes the feature norm
      using the oracle states (all the states in synthetic experiments).
      Otherwise, computes the norm using the sampled batch.
      Only applies to LISSA.
    sample_states: A function that takes an rng key and a number of states
      to sample, and returns a tuple containing
      (a vector of sampled states, an updated rng key).
    eval_states: An array of states to use to compute metrics on.
      This will be used to compute Phi = compute_phi(params, eval_states).
    use_tabular_gradient: If true, the train step will calculate the
      gradient using the tabular calculation. Otherwise, it will use a
      jax.vjp to backpropagate the gradient.
  """
    # Create an explicit weight vector (needed for explicit method only).
    if method == 'explicit':
        key, weight_key = jax.random.split(key)
        explicit_weight_matrix = jax.random.normal(weight_key, (d, num_tasks),
                                                   dtype=jnp.float32)
        params['explicit_weight_matrix'] = explicit_weight_matrix

    if optimizer == 'sgd':
        optimizer = optax.sgd(learning_rate)
    elif optimizer == 'adam':
        optimizer = optax.adam(learning_rate)
    else:
        raise ValueError(f'Unknown optimizer {optimizer}.')
    optimizer_state = optimizer.init(params)

    chkpt_manager = checkpoint.Checkpoint(base_directory=_WORKDIR.value)
    initial_step, params, optimizer_state = chkpt_manager.restore_or_initialize(
        (0, params, optimizer_state))

    writer = metric_writers.create_default_writer(logdir=str(workdir))

    # Checkpointing and logging too much can use a lot of disk space.
    # Therefore, we don't want to checkpoint more than 10 times an experiment,
    # or keep more than 1k Phis per experiment.
    checkpoint_period = max(num_epochs // 10, 100_000)
    log_period = max(1_000, num_epochs // 1_000)
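    # For example, assuming num_epochs = 2_000_000: checkpoint_period =
    # max(200_000, 100_000) = 200_000 (at most 10 checkpoints), and log_period =
    # max(1_000, 2_000) = 2_000 (about 1k logged evaluations).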

    def _checkpoint_callback(step, t, params, optimizer_state):
        del t  # Unused.
        chkpt_manager.save((step, params, optimizer_state))

    hooks = [
        periodic_actions.PeriodicCallback(every_steps=checkpoint_period,
                                          callback_fn=_checkpoint_callback)
    ]

    fixed_train_kwargs = {
        'compute_phi': compute_phi,
        'compute_psi': compute_psi,
        'optimizer': optimizer,
        'method': method,
        # In the tabular case, the eval_states are all the states.
        'oracle_states': eval_states,
        'lissa_kappa': lissa_kappa,
        'main_batch_size': main_batch_size,
        'covariance_batch_size': covariance_batch_size,
        'weight_batch_size': weight_batch_size,
        'd': d,
        'num_tasks': num_tasks,
        'compute_feature_norm_on_oracle_states':
            compute_feature_norm_on_oracle_states,
        'sample_states': sample_states,
        'use_tabular_gradient': use_tabular_gradient,
    }
    variable_kwargs = {
        'params': params,
        'optimizer_state': optimizer_state,
        'key': key,
    }

    @jax.jit
    def _eval_step(phi_params):
        eval_phi = compute_phi(phi_params, eval_states)
        eval_psi = compute_psi(eval_states)  # pytype: disable=wrong-arg-count

        metrics = compute_metrics(eval_phi, optimal_subspace)
        metrics |= {'frob_norm': utils.outer_objective_mc(eval_phi, eval_psi)}
        return metrics

    # Perform num_epochs gradient steps.
    with metric_writers.ensure_flushes(writer):
        for step in etqdm.tqdm(range(initial_step + 1, num_epochs + 1),
                               initial=initial_step,
                               total=num_epochs):

            variable_kwargs = _train_step(**fixed_train_kwargs,
                                          **variable_kwargs)

            if step % log_period == 0:
                metrics = _eval_step(variable_kwargs['params']['phi_params'])
                writer.write_scalars(step, metrics)

            for hook in hooks:
                hook(step,
                     params=variable_kwargs['params'],
                     optimizer_state=variable_kwargs['optimizer_state'])

    writer.flush()
Example #8
def main(unused_argv):
    logging.info("Loading %s", FLAGS.game_name)
    game = pyspiel.load_game(FLAGS.game_name,
                             GAME_SETTINGS.get(FLAGS.game_name, {}))
    uniform_policy = policy.UniformRandomPolicy(game)
    mfg_dist = distribution.DistributionPolicy(game, uniform_policy)

    envs = [
        rl_environment.Environment(game,
                                   mfg_distribution=mfg_dist,
                                   mfg_population=p)
        for p in range(game.num_players())
    ]
    info_state_size = envs[0].observation_spec()["info_state"][0]
    num_actions = envs[0].action_spec()["num_actions"]

    hidden_layers_sizes = [int(l) for l in FLAGS.hidden_layers_sizes]
    kwargs = {
        "replay_buffer_capacity": FLAGS.replay_buffer_capacity,
        "min_buffer_size_to_learn": FLAGS.min_buffer_size_to_learn,
        "batch_size": FLAGS.batch_size,
        "learn_every": FLAGS.learn_every,
        "learning_rate": FLAGS.rl_learning_rate,
        "optimizer_str": FLAGS.optimizer_str,
        "loss_str": FLAGS.loss_str,
        "update_target_network_every": FLAGS.update_target_network_every,
        "discount_factor": FLAGS.discount_factor,
        "epsilon_decay_duration": FLAGS.epsilon_decay_duration,
        "epsilon_start": FLAGS.epsilon_start,
        "epsilon_end": FLAGS.epsilon_end,
    }

    # pylint: disable=g-complex-comprehension
    agents = [
        dqn.DQN(idx, info_state_size, num_actions, hidden_layers_sizes,
                **kwargs) for idx in range(game.num_players())
    ]
    joint_avg_policy = rl_agent_policy.JointRLAgentPolicy(
        game, {idx: agent
               for idx, agent in enumerate(agents)}, envs[0].use_observation)
    if FLAGS.use_checkpoints:
        for agent in agents:
            if agent.has_checkpoint(FLAGS.checkpoint_dir):
                agent.restore(FLAGS.checkpoint_dir)

    # Metrics writer will also log the metrics to stderr.
    just_logging = FLAGS.logdir is None or jax.host_id() > 0
    writer = metric_writers.create_default_writer(FLAGS.logdir,
                                                  just_logging=just_logging)

    # Save the parameters.
    writer.write_hparams(kwargs)

    for ep in range(1, FLAGS.num_train_episodes + 1):
        if ep % FLAGS.eval_every == 0:
            writer.write_scalars(
                ep, {
                    f"agent{i}/loss": float(agent.loss)
                    for i, agent in enumerate(agents)
                })

            initial_states = game.new_initial_states()

            # Exact best response to uniform.
            nash_conv_obj = nash_conv.NashConv(game, uniform_policy)
            writer.write_scalars(
                ep, {
                    f"exact_br/{state}": value
                    for state, value in zip(initial_states,
                                            nash_conv_obj.br_values())
                })

            # DQN best response to uniform.
            pi_value = policy_value.PolicyValue(game, mfg_dist,
                                                joint_avg_policy)
            writer.write_scalars(
                ep, {
                    f"dqn_br/{state}": pi_value.eval_state(state)
                    for state in initial_states
                })

            if FLAGS.use_checkpoints:
                for agent in agents:
                    agent.save(FLAGS.checkpoint_dir)

        for p in range(game.num_players()):
            time_step = envs[p].reset()
            while not time_step.last():
                agent_output = agents[p].step(time_step)
                action_list = [agent_output.action]
                time_step = envs[p].step(action_list)

            # Episode is over, step all agents with final info state.
            agents[p].step(time_step)

    # Make sure all values were written.
    writer.flush()
Example #9
def create_default_writer():
    return metric_writers.create_default_writer()  # pylint: disable=unreachable
Example #10
    def run_train(self,
                  experiment_dir,
                  work_unit_dir,
                  rng,
                  yield_results=False):
        """Train a Dream Field and save results to work_unit_dir."""
        t_start = time.time()
        config = self.config

        logging.info('Local devices: %s', jax.local_devices())
        logging.info('All devices: %s', jax.devices())

        ## Load CLIP
        encode_image, encode_text, preprocess_image, tokenize_fn = (
            helpers.load_image_text_model(config.loss_model))

        ## Pick a prompt
        template = config.get('query_template', '{query}')
        query = template.format(query=config.query)
        z_clip = encode_text(tokenize_fn(query))

        ## Encode retrieval set
        if config.queries_r:
            if config.retrieve_models[0] == config.loss_model:
                # Reuse loss model.
                encode_image_r, preprocess_image_r = encode_image, preprocess_image
                encode_text_r, tokenize_fn_r = encode_text, tokenize_fn
            else:
                # Load new model.
                encode_image_r, encode_text_r, preprocess_image_r, tokenize_fn_r = (
                    helpers.load_image_text_model(config.retrieve_models[0]))

            if config.query not in config.queries_r:
                config.queries_r.append(config.query)
            z_clip_r = encode_text_r(tokenize_fn_r(config.queries_r))
            true_idx_r = config.queries_r.index(config.query)
            assert true_idx_r >= 0  # Input query must be in the set of retrieval queries.

            del encode_text_r, tokenize_fn_r  # Clean up retrieval text encoder.

        del encode_text, tokenize_fn  # Clean up text encoder.

        ## Scene origin manually tracked
        scene_origin = scene.EMA(np.zeros(3, dtype=np.float64), decay=0.999)

        def train_step(state, rays, key, *multistep_constants):
            """Perform a training iteration, optionally composed of multiple substeps.

      Using multiple substeps slightly reduces training time, but only one
      substep per training iteration is used in experiments.

      Args:
        state: Optimizer state.
        rays: Camera rays for rendering, shared across all substeps.
        key: PRNGKey for random number generation (e.g. for augmentations).
        *multistep_constants: Training constants that can vary across substeps.
          7 arrays of constants of length config.substeps are expected:
            (1) lrs: learning rates
            (2) scs: scale factor for integrated positional encoding. Larger
                scales lead to a blurrier appearance. A constant sc=1 is the
                standard mip-NeRF IPE, and used by Dream Fields.
            (3) sns: standard deviation of pre-activation noise for NeRF
                density. Dream Fields use sn=0.
                  density(x) = softplus(s(x) + eps), eps ~ N(0, sn^2)
            (4) mrs: norm of radiance mask, defining scene bounds.
            (5) betas: scale of beta prior loss. Dream Fields use beta=0.
            (6) acct: transmittance loss hyperparameter, defining the target
                average opacity. This is 1 - tau (target transmittance).
            (7) acclam: weight of transmittance loss.

      Returns:
        state: Updated optimizer state.
        last_augs: Augmented views of renderings from the last substep.
        mean_losses: Dictionary of losses averaged over replicas and substeps.
        scene_origin: Updated origin of the scene, based on the center of mass.
      """
            # NOTE(jainajay): rays are shared across all substeps
            pmean = functools.partial(jax.lax.pmean, axis_name='batch')
            psum = functools.partial(jax.lax.psum, axis_name='batch')

            def loss_fn(params, key, sc, sn, mr, beta, acct, acclam):
                render_key, aug_key, key = random.split(key, 3)

                # Render from nerf
                (rgb_est_flat, _,
                 acc_est_flat), aux = render_rays(rays=rays,
                                                  variables=params,
                                                  rng=render_key,
                                                  config=config,
                                                  sc=sc,
                                                  sigma_noise_std=sn,
                                                  mask_rad=mr,
                                                  origin=scene_origin.value,
                                                  train=True)
                rgb_est = scene.gather_and_reshape(rgb_est_flat,
                                                   config.render_width, 3)
                acc_est = scene.gather_and_reshape(acc_est_flat,
                                                   config.render_width, 1)
                # Make augmentations process specific
                aug_key = random.fold_in(aug_key, pid)
                # Perform augmentations and resize to clip_width
                augs = augment.augment_rendering(config, rgb_est, acc_est,
                                                 aug_key)

                # Run through CLIP
                z_est = encode_image(preprocess_image(augs))
                clip_loss = -(z_est * z_clip).sum(-1).mean()
                total_loss = clip_loss

                transparency_loss = config.get('transparency_loss', None)
                acc_mean = np.mean(acc_est)
                aux['losses']['acc_mean'] = acc_mean
                if transparency_loss == 'neg_lam_transmittance_clipped':
                    # Compute the Dream Fields transmittance loss for scene sparsity.
                    trans_mean = 1 - acc_mean
                    trans_mean_clipped = np.minimum(1 - acct, trans_mean)
                    reg = acclam * trans_mean_clipped
                    total_loss -= reg

                    aux['losses']['trans_mean_clipped'] = trans_mean_clipped
                    aux['losses']['acc_reg_additive'] = reg
                else:
                    assert transparency_loss is None

                # Compute a sparsity loss by placing a bimodal beta prior on the
                # per-pixel transmittance. This prior was proposed by Lombardi et al
                # in "Neural Volumes: Learning Dynamic Renderable Volumes from Images"
                # and is used only in ablations.
                beta_loss = np.mean(
                    np.log(np.maximum(1e-6, acc_est_flat)) +
                    np.log(np.maximum(1e-6, 1. - acc_est_flat)))
                total_loss += beta_loss * beta

                # Compute a weighted mean of each replica's estimated scene origin,
                # since replicas get a different subset of rays
                total_sigma = psum(aux['scene_origin_sigma'])
                aux['scene_origin'] = psum(aux['scene_origin'] *
                                           aux['scene_origin_sigma'] /
                                           total_sigma)
                # Compute loss that pushes scene content to 0 origin. We set the loss
                # weight zero_origin_lam = 0 in experiments so the loss is just for
                # logging how far the origin has drifted.
                origin_loss = np.sum(np.square(aux['scene_origin']))
                if config.get('zero_origin_lam', 0.):
                    total_loss += config.zero_origin_lam * origin_loss

                aux['losses'].update({
                    'clip_loss': clip_loss,
                    'beta_loss': beta_loss,
                    'origin_loss': origin_loss,
                    'loss': total_loss,
                })
                aux['augs'] = augs
                return total_loss, aux

            grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

            # Scan over substeps
            def body_fn(state, step_constants):
                lr, step_constants = step_constants[0], step_constants[1:]
                grad_fn_key, _ = random.split(key, 2)
                (_, aux), grad = grad_fn(state.target, grad_fn_key,
                                         *step_constants)
                grad = pmean(grad)  # all-reduce grad
                aux['losses'] = pmean(aux['losses'])
                aux['losses']['grad_norm'] = helpers.tree_norm(grad)
                state = state.apply_gradient(grad, learning_rate=lr)
                return state, aux

            assert len(multistep_constants) == 7
            multistep_constants = np.array(multistep_constants).T

            if config.substeps == 1:
                state, aux = body_fn(state, np.squeeze(multistep_constants))
                last_augs = aux['augs']
            else:
                state, aux = jax.lax.scan(body_fn, state, multistep_constants)
                # Augmentations from last substep.
                # Shape: [n_local_aug, clip_width, clip_width, 3]
                last_augs = aux['augs'][-1]

            # Average each type of loss over substeps
            mean_losses = jax.tree_map(np.mean, aux['losses'])
            return state, last_augs, mean_losses, aux['scene_origin']

        train_pstep = jax.pmap(train_step,
                               axis_name='batch',
                               in_axes=(0, 0, 0, None, None, None, None, None,
                                        None, None))

        onp.random.seed(config.seed)

        n_device = jax.local_device_count()
        pid = jax.process_index()
        logging.info('n_device %d', n_device)
        ## Modified NeRF architecture, with swish, softplus, skips.
        variables, render_rays = helpers.init_nerf_model(
            rng.advance(1), config)
        state = flax.optim.Adam(config.lr0,
                                eps=config.adam_eps).create(variables)

        ## Try to restore a checkpoint.
        restore_dir = config.get('restore_dir', experiment_dir)
        restore_dir = os.path.join(restore_dir,
                                   os.path.basename(work_unit_dir))
        if checkpoints.latest_checkpoint(restore_dir):
            restored = checkpoints.restore_checkpoint(restore_dir,
                                                      target={
                                                          'origin':
                                                          np.zeros(3),
                                                          'state': state,
                                                          'vars': variables
                                                      })
            scene_origin.value = onp.array(restored['origin'])
            state = restored['state']
            variables = restored['vars']
            logging.info('restored checkpoint from step %d', state.state.step)
        else:
            logging.info('did not find checkpoint in %s', restore_dir)

        ## Replicate state.
        step_init = state.state.step
        helpers.defragment()
        state = flax.jax_utils.replicate(state, jax.devices())
        helpers.defragment()

        ## pmap'd rendering for test time evaluation.
        kwargs_test = dict(rng=None, sigma_noise_std=0.)
        config_test = ml_collections.ConfigDict(config)
        config_test.update(config.test)
        config_test_hq = ml_collections.ConfigDict(config_test)
        config_test_hq.update(config.test_hq)

        @functools.partial(jax.pmap, in_axes=(0, None, None, None))
        def render_test_p(rays, variables, sc=1., mr=1.):
            return render_rays(rays=rays,
                               variables=variables,
                               sc=sc,
                               mask_rad=mr,
                               origin=scene_origin.value,
                               config=config_test,
                               **kwargs_test)[0]

        @functools.partial(jax.pmap, in_axes=(0, None, None, None))
        def render_test_hq_p(rays, variables, sc=1., mr=1.):
            return render_rays(rays=rays,
                               variables=variables,
                               config=config_test_hq,
                               sc=sc,
                               mask_rad=mr,
                               origin=scene_origin.value,
                               **kwargs_test)[0]

        def render_test(rays, variables, sc=1., mr=1., hq=False):
            sh = rays[0].shape
            rays = [
                x.reshape((jax.device_count(), -1) + x.shape[1:]) for x in rays
            ]
            if hq:
                out = render_test_hq_p(rays, variables, sc, mr)
            else:
                out = render_test_p(rays, variables, sc, mr)
            out = [x.reshape(sh[:-1] + (-1, )) for x in out]
            return out

        def render_loop(rays, variables, sc=1., mr=1., chunk=2**13, hq=False):
            sh = list(rays[0].shape[:-1])
            rays = [x.reshape((-1, ) + x.shape[-1:]) for x in rays]
            outs = [
                render_test([x[i:i + chunk] for x in rays],
                            variables,
                            sc,
                            mr,
                            hq=hq) for i in range(0, rays[0].shape[0], chunk)
            ]
            outs = [
                np.reshape(np.concatenate([z[i] for z in outs]), sh + [-1])
                for i in range(3)
            ]
            return outs
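        # render_loop flattens the ray batch, renders it in chunks of `chunk`
        # rays (2**13 = 8192 by default) to bound device memory, and reassembles
        # the rgb, depth and acc outputs back to the original image shape.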

        ## Training loop
        t_total = 0.
        logging.info('Experiment dir %s', experiment_dir)
        logging.info('Work unit dir %s', work_unit_dir)
        gfile.makedirs(work_unit_dir)

        # Set up metric writer
        writer = metric_writers.create_default_writer(
            work_unit_dir,
            asynchronous=True,
            just_logging=jax.process_index() > 0)
        if jax.process_index() == 0:
            train_config = config.copy_and_resolve_references()
            log.write_config_json(train_config, work_unit_dir)

        # Scale intrinsics to different resolutions.
        hwf_clip_r = scene.scale_intrinsics(config.retrieve_widths[0])
        hwf_base = scene.scale_intrinsics(config.render_width)
        hwf_video = scene.scale_intrinsics(300.)
        hwf_video_hq = scene.scale_intrinsics(400.)

        # JIT compile ray generation
        @jax.jit
        def camera_ray_batch_base(p, focal_mult):
            return scene.camera_ray_batch(p, *hwf_base[:2],
                                          hwf_base[2] * focal_mult)

        @jax.jit
        def sample_pose_focal(key):
            return scene.sample_camera(key, config.th_range, config.phi_range,
                                       config.rad_range,
                                       config.focal_mult_range)

        shard_rays_jit = jax.jit(functools.partial(scene.shard_rays))

        def sample_iter_data(key, step):
            # Sample pose, focal length multiplier.
            pose, rad, focal_mult = sample_pose_focal(key)

            # Generate rays, shaped for pmap over devices.
            rays = camera_ray_batch_base(pose, focal_mult)
            rays_in = shard_rays_jit(rays)
            # Select rays for this process
            rays_in = jax.tree_map(lambda x: x[pid], rays_in)

            substeps = np.arange(start=step,
                                 stop=step + config.substeps,
                                 step=1)

            # mip-NeRF scale annealing.
            decays = config.mipnerf.decay_start * (
                1 - substeps / config.mipnerf.decay_iters)
            scs = np.maximum(1., 2**decays)

            # Sigma noise annealing.
            sns = schedule.sigma_noise_std_fn(substeps,
                                              i_split=config.sn_i_split,
                                              sn0=config.sn0,
                                              sn1=config.sn1)

            # Scene bounds annealing.
            mrs = schedule.mask_rad_fn(substeps,
                                       i_split=config.mr_i_split,
                                       mr0=config.mr0,
                                       mr1=config.mr1)

            # Anneal target opacity (1 - transmittance).
            accts = schedule.anneal_exponentially(substeps,
                                                  config.acc_target_i_split,
                                                  config.acc_target0,
                                                  config.acc_target1)
            # The area of an object on the image plane grows with the focal length
            # and shrinks with increasing camera radius. Scale target opacity
            # proportionally with the squared focal multiplier and inversely
            # proportionally with the squared camera radius. For consistency with
            # early experiments that did not use this scaling, we also scale by a
            # constant, 1 / (4^2 * 1.2).
            acct_scaling = focal_mult**2 / ((rad / 4.)**2) / 1.2
            accts = np.minimum(1., acct_scaling * accts)
            acclams = np.where(substeps < config.acc_lam_after, 0.,
                               config.acc_lam)

            # Beta prior encourages either 0 or 1 opacity for rays
            betas = np.where(substeps < config.beta_after, .0,
                             config.get('beta_lam', .001))

            # Learning rate schedule.
            # NOTE: vectorized calculation of lrs doesn't work with multiple substeps
            lrs = schedule.lr_fn(substeps,
                                 i_split=config.lr_i_split,
                                 i_end=config.iters,
                                 lr0=config.lr0,
                                 lr1=config.lr1,
                                 lr2=config.lr2,
                                 cosine_decay=config.lr_cosine_decay)

            return substeps, rays_in, lrs, scs, sns, mrs, betas, accts, acclams

        pbar = tqdm.trange(step_init,
                           config.iters + config.substeps,
                           config.substeps,
                           desc='training')
        for i in pbar:
            t = time.time()

            substeps, rays_in, lrs, scs, sns, mrs, betas, accts, acclams = (
                sample_iter_data(rng.advance(1), i))
            l = substeps[-1]

            keys_pstep = rng.split(n_device)
            # NOTE: loss is averaged across substeps.
            new_state, augs, mean_losses, new_scene_origin = train_pstep(
                state, rays_in, keys_pstep, lrs, scs, sns, mrs, betas, accts,
                acclams)

            # Reduce across devices
            mean_losses = jax.tree_map(np.mean, mean_losses)

            # Gradient skipping if nan.
            if (helpers.all_finite_tree(mean_losses)
                    and helpers.all_finite_tree(new_state)):
                state = new_state
            else:
                logging.warn(
                    'Skipping update on step %d. non-finite loss or state', i)
                continue

            # Update scene origin.
            if config.get('ema_scene_origin', False):
                if helpers.all_finite(new_scene_origin):
                    scene_origin.update(new_scene_origin[0])
                else:
                    logging.warn(
                        'Skipping origin update on step %d. '
                        'non-finite origin. old: %s skipped update: %s', i,
                        scene_origin.value, new_scene_origin)

            ## Yield results, for display in colab.
            augs = augs.reshape(
                -1, *augs.shape[2:])  # devices, n_localaug, HWC->BHWC
            if yield_results:
                yield mean_losses, augs, scene_origin.value
            else:
                yield None
            pbar.set_description(f'Loss: {mean_losses["loss"]:.4f}')

            ## Logging.
            if i == 0:
                continue

            t_total += time.time() - t

            if i % config.log_scalars_every == 0:
                scalars = {
                    f'losses/{key}': value
                    for key, value in mean_losses.items()
                }
                scalars.update({
                    'schedule/mipnerf_scale': scs[-1],
                    'schedule/lr': lrs[-1],
                    'schedule/mask_rad': mrs[-1],
                    'schedule/sigma_noise_std': sns[-1],
                    'schedule/beta': betas[-1],
                    'schedule/acc_target': accts[-1],
                    'schedule/acc_lam': acclams[-1],
                    'origin/x': scene_origin.value[0],
                    'origin/y': scene_origin.value[1],
                    'origin/z': scene_origin.value[2],
                    'origin/norm': np.linalg.norm(scene_origin.value),
                })

                secs_per_iter = t_total / (l - step_init)
                iters_per_sec = (l - step_init) / t_total
                wall = time.time() - t_start
                scalars.update({
                    'system/wall': wall,
                    'system/secs_per_iter': secs_per_iter,
                    'system/iters_per_sec': iters_per_sec,
                })

            if i % config.render_every == 0:
                variables = helpers.state_to_variables(state)
                cam2world = scene.pose_spherical(30., -45., 4.)
                rays = scene.camera_ray_batch(cam2world, *hwf_clip_r)

                # Render with no scale manipulation.
                outs = render_loop(rays, variables, sc=1., mr=mrs[-1], hq=True)
                outs = [np.squeeze(x) for x in outs]
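                # render_loop returns (rgb, depth, acc); add batch (and
                # channel) axes so the metric writer treats them as image
                # batches.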
                step_images = {
                    'render/rgb': outs[0][None],
                    'render/depth': outs[1][None, Ellipsis, None],
                    'render/acc': outs[2][None, Ellipsis, None],
                }

                # Compute retrieval metric.
                if config.queries_r:
                    z_est = encode_image_r(preprocess_image_r(outs[0][None]))
                    cosine_sim = (z_est * z_clip_r).sum(
                        -1)  # 1d, num retrieval queries
                    log_prob = nn.log_softmax(cosine_sim)
                    prefix = f'val/{config.retrieve_models[0]}/retrieve_'
                    scalars.update({
                        f'{prefix}cosine_sim': cosine_sim[true_idx_r],
                        f'{prefix}loss': -log_prob[true_idx_r],
                        f'{prefix}acc':
                            (np.argmax(cosine_sim) == true_idx_r).astype(float),
                    })

                augs_tiled = log.make_image_grid(augs[:8])
                step_images['render/augmentations'] = augs_tiled

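                # Disparity is inverse depth; clamp at the near plane to avoid
                # blowing up at zero depth.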
                fig = plt.figure()
                plt.imshow(1. / np.maximum(config.near, outs[1]))
                plt.colorbar()
                plt.title('disparity')
                disparity = log.plot_to_image(fig)
                step_images['render/disparity'] = disparity

                writer.write_images(step=l, images=step_images)

                if config.render_lq_video and config.video_every and (
                        i % config.video_every == 0 or i + 1 == config.iters):

                    def rays_theta(th):
                        cam2world = scene.pose_spherical(th, -30., 4.)
                        return scene.camera_ray_batch(cam2world, *hwf_video)

                    th_range = np.linspace(0, 360, 60, endpoint=False)
                    frames_all = [
                        render_loop(rays_theta(th),
                                    variables,
                                    scs[-1],
                                    mrs[-1],
                                    hq=False)
                        for th in tqdm.tqdm(th_range, desc='render video')
                    ]

                    videos = [[np.squeeze(f[j]) for f in frames_all]
                              for j in range(3)]
                    for video, label in zip(videos, 'rgb depth acc'.split()):
                        scale = (label == 'depth')
                        log.log_video(None,
                                      video,
                                      'frames',
                                      label,
                                      l,
                                      work_unit_dir,
                                      scale=scale)

            if i % config.log_scalars_every == 0:
                writer.write_scalars(step=l, scalars=scalars)

            if i % config.flush_every == 0:
                writer.flush()

            defrag_every = config.get('defragment_every', default=0)
            if defrag_every and i % defrag_every == 0:
                helpers.defragment()

            if config.get(
                    'checkpoint_every') and i % config.checkpoint_every == 0:
                saved_path = checkpoints.save_checkpoint(
                    ckpt_dir=work_unit_dir,
                    target={
                        'state': flax.jax_utils.unreplicate(state),
                        'vars': helpers.state_to_variables(state),
                        'origin': np.array(scene_origin.value),
                    },
                    step=l,
                    keep=1,
                    overwrite=True,
                    keep_every_n_steps=config.get('keep_every_n_steps', None))
                logging.info('saved checkpoint to %s', saved_path)

            # Make a higher res, higher frame rate video.
            if config.render_hq_video and (
                    (config.get('hq_video_every', None)
                     and i % config.hq_video_every == 0)
                    or i == config.iters):

                my_rays = lambda c2w: scene.camera_ray_batch(
                    c2w, *hwf_video_hq)
                th_range = np.linspace(0, 360, 240, endpoint=False)
                poses = [scene.pose_spherical(th, -30., 4.) for th in th_range]
                variables = helpers.state_to_variables(state)
                frames_all = [
                    render_loop(my_rays(pose),
                                variables,
                                1.,
                                config.mr1,
                                hq=True)
                    for pose in tqdm.tqdm(poses, 'render hq video')
                ]

                videos = [[onp.array(np.squeeze(f[j])) for f in frames_all]
                          for j in range(3)]
                meta_path = os.path.join(work_unit_dir, 'meta_hq.npy')
                with gfile.GFile(meta_path, 'wb') as f:
                    logging.info(
                        'saving metadata for rendered hq frames to %s',
                        meta_path)
                    onp.save(
                        f,
                        dict(poses=onp.array(poses),
                             hwf=onp.array(hwf_video_hq)))
                for video, label in zip(videos, 'rgb depth acc'.split()):
                    scale = (label == 'depth')
                    log.log_video(None,
                                  video,
                                  'frames_hq',
                                  label,
                                  i,
                                  work_unit_dir,
                                  scale=scale)

        writer.flush()
        writer.close()
        logging.info('%f sec elapsed total', time.time() - t_start)
Example #11
def train(config: ml_collections.ConfigDict):
  """Run training."""

  # Establish host information
  local_device_count = jax.local_device_count()
  host_count = jax.process_count()
  host_id = jax.process_index()

  task = task_registry.get_registered_task(config.task_name)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)

  model_config = ml_collections.FrozenConfigDict(config.model_config)
  model = task.build_model(model_config)

  # Initialization needs to be pmapped because models use collective ops.
  # Create dummy input
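  # (jnp.tile with a reps tuple of length value.ndim + 1 prepends a leading
  # axis of size local_device_count, so each local device sees one copy of the
  # dummy batch under pmap.)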
  dummy_input = {
      key: jnp.tile(value, (local_device_count,) + (1,) * value.ndim)
      for key, value in task.dummy_input(config).items()
  }

  rng, init_rng = jax.random.split(rng)
  init_rng = jax.random.split(init_rng, local_device_count)

  logging.info('Initializing model.')
  initial_variables = jax.pmap(
      model.init, 'batch', static_broadcasted_argnums=2)(init_rng, dummy_input,
                                                         True)
  logging.info('Finished initializing model.')
  initial_variables = initial_variables.unfreeze()

  if config.load_weights is not None:
    logging.info('Loading model weights from file')
    loaded_variables = task.load_weights(config)
    unexpected, missing = checkpoint_utils.merge_nested_dicts(
        initial_variables, loaded_variables)
    logging.info('*** Unexpected features: ***')
    for feature_name in unexpected:
      logging.info('\t%s', feature_name)
    logging.info('*** Missing features: ***')
    for feature_name in missing:
      logging.info('\t%s', feature_name)

  model_vars = {
      key: value for key, value in initial_variables.items() if key != 'params'
  }

  learning_rate_fn = optim_utils.create_learning_rate_scheduler(
      learning_rate=config.learning_rate,
      warmup=config.warmup,
      warmup_steps=config.get('warmup_steps', None),
      linear_decay=config.linear_decay,
      max_steps=config.num_train_steps,
      decay_minimum_factor=config.get('decay_minimum_factor', None),
  )

  if config.weight_decay_exclude is not None:
    decay_mask = optim_utils.create_dict_mask(initial_variables['params'],
                                              config.weight_decay_exclude)
  else:
    decay_mask = None
  tx = optax.adamw(
      learning_rate=learning_rate_fn,
      weight_decay=config.weight_decay,
      b1=0.9,
      b2=0.999,
      eps=1e-6,
      mask=decay_mask)
  if config.grad_clip is not None:
    tx = optax.chain(tx, optax.clip_by_global_norm(config.grad_clip))

  ignore_k_nans = config.get('ignore_k_nans')
  if ignore_k_nans is not None:
    tx = optax.apply_if_finite(tx, max_consecutive_errors=ignore_k_nans)

  loss_fn = task.make_loss_fn(config)
  train_state = ts.TrainState.create(
      apply_fn=loss_fn,
      params=jax_utils.unreplicate(initial_variables['params']),
      tx=tx,
  )

  # We access model params only from train state.
  del initial_variables

  # Restore unreplicated train state from last checkpoint
  train_state = checkpoints.restore_checkpoint(config.model_dir, train_state)
  # Grab last step.
  start_step = int(train_state.step)

  writer = metric_writers.create_default_writer(
      config.model_dir, just_logging=jax.process_index() > 0)
  if start_step == 0:
    writer.write_hparams(config.to_dict())

  dropout_rngs = jax.random.split(rng, local_device_count)

  del rng

  # Load datasets
  logging.info('Loading dataset.')

  # Make sure we don't re-use same data if we load weights or checkpoint
  seed = config.seed + start_step
  if config.load_weights:
    seed = seed + hash(config.load_weights)

  name_to_features = task.get_name_to_features(config)
  preprocess_fn = task.make_preprocess_fn(config)
  collater_fn = task.make_collater_fn(config)

  train_data = data_utils.load_multi_dataset(
      datasets_config=config.train_data,
      name_to_features=name_to_features,
      preprocess_fn=preprocess_fn,
      collater_fn=collater_fn,
      is_training=True,
      per_device_batch_size=config.per_device_batch_size,
      local_device_count=local_device_count,
      host_count=host_count,
      host_id=host_id,
      seed=seed,
  )
  train_iter = iter(train_data)

  pad_eval = config.get('pad_eval', False)
  if pad_eval:
    logging.info('Eval data is padded such that none of samples are dropped.')
  else:
    logging.warning('Eval data is NOT padded -- some samples might be dropped.')

  eval_data = data_utils.load_multi_dataset(
      datasets_config=config.eval_data,
      name_to_features=name_to_features,
      preprocess_fn=preprocess_fn,
      collater_fn=collater_fn,
      is_training=False,
      per_device_batch_size=config.per_device_batch_size,
      local_device_count=local_device_count,
      host_count=host_count,
      host_id=host_id,
      seed=config.seed,
      pad_eval=pad_eval,
  )
  eval_data = list(eval_data)
  logging.info('Loaded %d samples for evaluation.', len(eval_data))

  # Setup postprocessing_fn for saving samples occasionally.
  if config.get('save_samples_every_steps') is not None:
    if config.get('save_samples_every_steps') % config.eval_every_steps != 0:
      raise ValueError(
          '`eval_every_steps` must divide `save_samples_every_steps`.')
    postprocessing_fn = task.make_output_postprocess_fn(config)

  # Training loop
  logging.info('Starting training.')

  # Replicate train state.
  train_state = jax_utils.replicate(train_state)

  # compile multidevice versions of train/eval/predict step
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          model_config=model_config,
      ),
      axis_name='batch',
      donate_argnums=(0,),
  )  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          model_config=model_config,
      ),
      axis_name='batch')

  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)

  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=config.model_dir, num_profile_steps=5)
    ]
  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and perform a training step.
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        batch = jax.tree_map(jnp.asarray, train_iter.get_next())
        train_state, metrics = p_train_step(
            train_state,
            model_vars,
            batch,
            dropout_rngs,
        )
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step % config.eval_every_steps == 0 or is_last_step:
        with report_progress.timed('training_metrics'):
          logging.info('Gathering training metrics.')
          train_metrics = common_utils.get_metrics(train_metrics)
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          summary = metric_utils.process_metrics(metrics_sums, prefix='train')
          summary['learning_rate'] = learning_rate_fn(step)

          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed('eval'):
          eval_results, eval_auxiliary = evaluate(
              eval_step_fn=p_eval_step,
              train_state=train_state,
              model_vars=model_vars,
              eval_data=eval_data,
          )
          writer.write_scalars(step, eval_results)

          if config.get('save_samples_every_steps') is not None:
            with report_progress.timed('save_samples'):
              if config.get('save_first_batch_only', True):
                postprocessing_input = [eval_auxiliary[0]]
              else:
                postprocessing_input = eval_auxiliary
              eval_processed = [
                  postprocessing_fn(batch, auxiliary_output)
                  for batch, auxiliary_output in postprocessing_input
              ]
              data_utils.save_samples_to_json(eval_processed, config, step)

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = (
          step % config.checkpoint_every_steps == 0 or is_last_step)
      if (config.save_checkpoints and save_checkpoint and
          jax.process_index() == 0):
        with report_progress.timed('checkpoint'):
          logging.info('Saving checkpoint at step %s', step)
          checkpoints.save_checkpoint(
              config.model_dir,
              jax_utils.unreplicate(train_state),
              step,
              keep=config.get('keep_checkpoints', 1),
              keep_every_n_steps=config.get('keep_checkpoint_every_steps'),
          )

      save_model = (
          config.save_every_steps and
          (step % config.save_every_steps == 0 or is_last_step) and step != 0)
      if (save_model and jax.process_index() == 0):
        with report_progress.timed('checkpoint'):
          logging.info('Saving weights at step %s', step)
          save_path = os.path.join(config.model_dir, 'weights',
                                   'step' + str(step))
          # By default, save only encoder weights
          weights = jax_utils.unreplicate(train_state).params['encoder']
          checkpoint_utils.save_weights(save_path, weights)
Example #12
  def run_train(self, experiment_dir, work_unit_dir,
                rng):
    """Training loop with fixed number of steps and checkpoint every steps."""
    del experiment_dir  # unused
    tf.io.gfile.makedirs(work_unit_dir)

    config = self.config

    total_bs = config.train.batch_size
    assert total_bs % jax.device_count() == 0, (
        f'num total devices {jax.device_count()} must divide the batch size '
        f'{total_bs}')
    device_bs = total_bs // jax.device_count()
    logging.info('total_bs=%d device_bs=%d', total_bs, device_bs)

    # Logging setup
    writer = metric_writers.create_default_writer(
        work_unit_dir, just_logging=jax.host_id() > 0)
    if jax.host_id() == 0:
      utils.write_config_json(config, os.path.join(work_unit_dir,
                                                   'config.json'))

    # Build input pipeline
    logging.info('Substeps per training step: %d', config.train.substeps)
    train_ds = self.dataset.get_tf_dataset(
        split='train',
        batch_shape=(
            jax.local_device_count(),  # for pmap
            config.train.substeps,  # for lax.scan over multiple substeps
            device_bs,  # batch size per device
        ),
        global_rng=jax.random.PRNGKey(config.seed),
        repeat=True,
        shuffle=True,
        augment=True,
        shard_id=jax.host_id(),
        num_shards=jax.host_count())
    train_iter = utils.numpy_iter(train_ds)
    eval_ds = self.dataset.get_tf_dataset(
        split='eval',
        batch_shape=(jax.local_device_count(), device_bs),
        global_rng=jax.random.PRNGKey(config.seed),
        repeat=True,
        shuffle=True,
        augment=False,
        shard_id=jax.host_id(),
        num_shards=jax.host_count())
    eval_iter = utils.numpy_iter(eval_ds)

    samples_shape = (device_bs, *self.dataset.data_shape)

    self.p_gen_samples = utils.dist(
        functools.partial(self._gen_samples, samples_shape=samples_shape),
        accumulate='concat',
        axis_name='batch')

    # Set up model and training state
    state = jax.device_get(self.make_init_state())
    checkpoint_dir = os.path.join(work_unit_dir, 'checkpoints')
    state = checkpoints.restore_checkpoint(checkpoint_dir, state)
    initial_step = int(state.step)
    state = flax.jax_utils.replicate(state)

    # Training step
    train_step = functools.partial(self.step_fn, next(rng), True)
    train_step = functools.partial(jax.lax.scan, train_step)  # for substeps
    train_step = jax.pmap(train_step, axis_name='batch', donate_argnums=(0,))
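    # Each pmapped call thus consumes a [substeps, device_bs, ...] batch per
    # device and applies config.train.substeps updates in one XLA computation,
    # which is why state.step advances by `substeps` per call (see the assert
    # in the loop below).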

    # Eval step (does not modify parameters; no substeps)
    eval_base_rng = next(rng)

    # Training loop
    logging.info('Entering training loop at step %i', initial_step)
    utils.assert_synced(state)
    last_log_time = last_ckpt_time = time.time()
    prev_step = initial_step

    with metric_writers.ensure_flushes(writer):
      for batch in train_iter:

        state, metrics = train_step(state, batch)
        new_step = int(state.step[0])
        assert new_step == prev_step + config.train.substeps

        # Quick indication that training is happening.
        logging.log_first_n(logging.INFO, 'Finished training step %d', 5,
                            new_step)
        # Log metrics
        if new_step % config.train.log_loss_every_steps == 0:
          # Unreplicate metrics, average over substeps, and cast to python float
          metrics = jax.device_get(flax.jax_utils.unreplicate(metrics))

          def avg_over_substeps(x):
            assert x.shape[0] == config.train.substeps
            return float(x.mean(axis=0))

          metrics = jax.tree_map(avg_over_substeps, metrics)
          metrics['train/steps_per_sec'] = float(
              config.train.log_loss_every_steps / (time.time() - last_log_time))
          writer.write_scalars(new_step, metrics)
          last_log_time = time.time()

        # Eval
        should_eval = new_step % config.train.eval_every_steps == 0
        if prev_step == 0 or should_eval:
          # Samples

          samples_to_log = {
              'eval/samples':
                  self.get_model_samples(
                      params=state.ema_params, rng=next(rng))
          }

          if samples_to_log:
            assert all(v.shape == (total_bs, *self.dataset.data_shape)
                       for v in samples_to_log.values())
            # tf.summary.image asks for a batch, so insert a new axis
            writer.write_images(
                new_step, {
                    k: utils.np_tile_imgs(v.astype('uint8'))[None, :, :, :]
                    for k, v in samples_to_log.items()
                })

          # Eval metrics
          if config.train.get('calc_eval_metrics', True):
            eval_metrics = self._calc_eval_metrics(
                state=state,
                eval_iter=eval_iter,
                eval_steps=config.train.get('eval_number_steps',
                                            self.dataset.num_eval // total_bs),
                eval_base_rng=eval_base_rng,
                total_bs=total_bs)
            if eval_metrics is not None:
              writer.write_scalars(new_step, eval_metrics)

        # Checkpointing: only if checkpoint_every_secs is not None.
        if config.train.checkpoint_every_secs is not None:
          should_ckpt = (
              time.time() - last_ckpt_time >=
              config.train.checkpoint_every_secs)
          should_ckpt = (
              prev_step == 0 or new_step == config.train.num_train_steps or
              should_ckpt)
        else:
          should_ckpt = False

        if should_ckpt and jax.host_id() == 0:
          checkpoints.save_checkpoint(
              checkpoint_dir,
              flax.jax_utils.unreplicate(state),
              step=new_step,
              keep=3)
          last_ckpt_time = time.time()

        # Keep extra checkpoints without removal. Training does not resume
        # from these checkpoints.
        if (('retain_checkpoint_every_steps' in config.train) and
            ((new_step % config.train.retain_checkpoint_every_steps == 0) or
             (new_step == config.train.num_train_steps)) and
            (jax.host_id() == 0)):
          # Below, overwrite=True because training might resume from a
          # checkpoint from an earlier step than the latest retained checkpoint,
          # causing the latest retained checkpoint to be overwritten.
          checkpoints.save_checkpoint(
              os.path.join(work_unit_dir, 'retained_checkpoints'),
              flax.jax_utils.unreplicate(state),
              step=new_step,
              keep=int(1e10),
              overwrite=True)

        prev_step = new_step
        if new_step == config.train.num_train_steps:
          logging.info('Finished training for %d iterations.', new_step)
          break
Example #13
def train(env, agent, loss_func, horizon, config, workdir=None):
  """Main training loop.

  Expected config fields:
    - num_episodes
    - episodes_per_eval
    - training_env_batch_size
    - eval_env_batch_size = 32

    - optimizer
    - learning_rate

    - seed = 1
  """
  print(config)
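  # A minimal config sketch with hypothetical values for the fields listed in
  # the docstring above (not from the original code):
  #   config = ml_collections.ConfigDict(dict(
  #       num_episodes=1000, episodes_per_eval=10,
  #       training_env_batch_size=8, eval_env_batch_size=32,
  #       optimizer='SGD', learning_rate=1e-3, seed=1))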
  if workdir is not None:
    writer = metric_writers.create_default_writer(
        logdir=workdir, just_logging=jax.process_index() != 0)
    writer.write_hparams(dict(config))

  key = jax.random.PRNGKey(config.seed)
  key_train_agent, key_eval_agent, key_train_env, key_eval_env, key_train, key = jax.random.split(key, 6)
  key_train_envs = jax.random.split(key_train_env, config.training_env_batch_size)
  key_train_agents = jax.random.split(key_train_agent, config.training_env_batch_size)
  key_eval_envs = jax.random.split(key_eval_env, config.eval_env_batch_size)
  key_eval_agents = jax.random.split(key_eval_agent, config.eval_env_batch_size)

  #TODO(danielsuo): The following vmap code does not work.
  train_env_start_states, train_env_init_obs = jax.vmap(env.init)(key_train_envs)
  eval_env_start_states, eval_env_init_obs = jax.vmap(env.init)(key_eval_envs)
  print(train_env_start_states)
  print(train_env_init_obs)

  # qtrain_init_list = list(map(env.init, key_train_envs))
  # qtrain_env_start_states = [a for (a,_) in qtrain_init_list]
  # qtrain_env_init_obs = [b for (_,b) in qtrain_init_list]
  # print(qtrain_env_start_states)
  # print(qtrain_env_init_obs)
  # eval_init_list = list(map(env.init, key_eval_envs))
  # eval_env_start_states = [a for (a,_) in eval_init_list]
  # eval_env_init_obs = [b for (_,b) in eval_init_list]

  train_agent_start_states = jax.vmap(agent.init)(key_train_agents)
  eval_agent_start_states = jax.vmap(agent.init)(key_eval_agents)

  if config.optimizer == "Adam":
    optim = optax.adam(learning_rate=config.learning_rate)
  else:
    # default is SGD
    optim = optax.sgd(learning_rate=config.learning_rate)
  optim_state = optim.init(agent)

  for episode in range(0, config.num_episodes, config.episodes_per_eval):
    # Eval Step
    tt = time.time()
    eval_rollouts = apg_parallel_rollouts(env, eval_env_start_states,
                                          eval_env_init_obs, agent,
                                          eval_agent_start_states, horizon,
                                          loss_func)
    test_score = eval_rollouts.losses.mean()
    print(f"TESTING episode {episode} - score:{test_score} - time:{time.time()-tt}")

    # Training Step
    tt = time.time()
    agent, optim_state, losses = train_chunk(config.episodes_per_eval, optim,
                                             optim_state, agent, env,
                                             train_env_start_states,
                                             train_env_init_obs,
                                             train_agent_start_states, horizon,
                                             loss_func)
    done_eps = episode + config.episodes_per_eval - 1
    print(f"TRAINING: episode {done_eps} - score:{losses[0]} - time {time.time() - tt}")

    if workdir is not None:
      for (i, loss) in enumerate(reversed(losses)):
        writer.write_scalars(episode+i, {"train_score": loss})
      writer.write_scalars(episode, {"test_score": test_score})
  return optim_state, agent
def main(_):
    config = FLAGS.config

    # Unpack total and warmup steps
    total_steps = config.total_and_warmup_steps[0]
    warmup_steps = config.total_and_warmup_steps[1]
    del config.total_and_warmup_steps
    config.total_steps = total_steps
    config.lr.warmup_steps = warmup_steps

    # Wandb and Checkpointing Setup
    output_dir = FLAGS.output_dir
    wandb_run, output_dir = vit_utils.maybe_setup_wandb(config)
    tf.io.gfile.makedirs(output_dir)
    logging.info('Saving checkpoints at %s', output_dir)

    # Dataset Split Flags
    dist_shift = config.distribution_shift
    print(f'Distribution Shift: {dist_shift}.')
    dataset_names, split_names = vit_utils.get_dataset_and_split_names(
        dist_shift)

    # LR / Optimization Flags
    print('wandb hyperparameters:')
    print({
        'batch_size': config.batch_size,
        'grad_clip_norm': config.grad_clip_norm,
        'weight_decay': config.weight_decay,
        'total_steps': config.total_steps,
        'lr': config.lr,
        'fast_weight_lr_multiplier': config.fast_weight_lr_multiplier
    })

    # Reweighting loss for class imbalance
    # class_reweight_mode = config.class_reweight_mode
    # if class_reweight_mode == 'constant':
    #   class_weights = utils.get_diabetic_retinopathy_class_balance_weights()
    # else:
    #   class_weights = None

    # Shows the number of available devices.
    # In a CPU/GPU runtime this will be a single device.
    # In a TPU runtime this will be 8 cores.
    print('Number of Jax local devices:', jax.local_devices())

    # TODO(nband): fix sigmoid loss issues.
    assert config.get('loss', None) == 'softmax_xent'

    seed = config.get('seed', 0)
    rng = jax.random.PRNGKey(seed)
    tf.random.set_seed(seed)

    if config.get('data_dir'):
        logging.info('data_dir=%s', config.data_dir)
    logging.info('Output dir: %s', output_dir)
    tf.io.gfile.makedirs(output_dir)

    save_checkpoint_path = None
    if config.get('checkpoint_steps'):
        save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

    # Create an asynchronous multi-metric writer.
    writer = metric_writers.create_default_writer(
        output_dir, just_logging=jax.process_index() > 0)

    # The pool is used to perform misc operations such as logging in async way.
    pool = multiprocessing.pool.ThreadPool()

    def write_note(note):
        if jax.process_index() == 0:
            logging.info('NOTE: %s', note)

    write_note('Initializing...')

    # Verify settings to make sure no checkpoints are accidentally missed.
    if config.get('keep_checkpoint_steps'):
        assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
        assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
            f'`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should '
            f'be divisible by `checkpoint_steps` ({config.checkpoint_steps}).')

    batch_size = config.batch_size
    batch_size_eval = config.get('batch_size_eval', batch_size)
    if (batch_size % jax.device_count() != 0
            or batch_size_eval % jax.device_count() != 0):
        raise ValueError(
            f'Batch sizes ({batch_size} and {batch_size_eval}) must '
            f'be divisible by device number ({jax.device_count()})')

    local_batch_size = batch_size // jax.process_count()
    local_batch_size_eval = batch_size_eval // jax.process_count()
    logging.info(
        'Global batch size %d on %d hosts results in %d local batch size. '
        'With %d devices per host (%d devices total), that\'s a %d per-device '
        'batch size.', batch_size, jax.process_count(), local_batch_size,
        jax.local_device_count(), jax.device_count(),
        local_batch_size // jax.local_device_count())

    write_note('Initializing preprocessing function...')
    # Same preprocessing function for training and evaluation
    preproc_fn = preprocess_spec.parse(
        spec=config.pp_train, available_ops=preprocess_utils.all_ops())

    write_note('Initializing train dataset...')
    rng, train_ds_rng = jax.random.split(rng)
    train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index())
    train_base_dataset = ub.datasets.get(dataset_names['in_domain_dataset'],
                                         split=split_names['train_split'],
                                         data_dir=config.get('data_dir'))
    train_dataset_builder = train_base_dataset._dataset_builder  # pylint:disable=protected-access
    train_ds = input_utils.get_data(
        dataset=train_dataset_builder,
        split=split_names['train_split'],
        rng=train_ds_rng,
        process_batch_size=local_batch_size,
        preprocess_fn=preproc_fn,
        shuffle_buffer_size=config.shuffle_buffer_size,
        prefetch_size=config.get('prefetch_to_host', 2),
        data_dir=config.get('data_dir'))

    # Start prefetching already.
    train_iter = input_utils.start_input_pipeline(
        train_ds, config.get('prefetch_to_device', 1))

    write_note('Initializing val dataset(s)...')

    # Load in-domain and OOD validation and/or test datasets.
    # Please specify the desired shift (Country Shift or Severity Shift)
    # in the config.
    eval_iter_splits = vit_utils.init_evaluation_datasets(
        use_validation=config.use_validation,
        use_test=config.use_test,
        dataset_names=dataset_names,
        split_names=split_names,
        config=config,
        preproc_fn=preproc_fn,
        batch_size_eval=batch_size_eval,
        local_batch_size_eval=local_batch_size_eval)

    ntrain_img = input_utils.get_num_examples(
        train_dataset_builder,
        split=split_names['train_split'],
        process_batch_size=local_batch_size,
        data_dir=config.get('data_dir'))
    steps_per_epoch = ntrain_img // batch_size

    if config.get('num_epochs'):
        total_steps = int(config.num_epochs * steps_per_epoch)
        assert not config.get(
            'total_steps'), 'Set either num_epochs or total_steps'
    else:
        total_steps = config.total_steps

    logging.info('Total train data points: %d', ntrain_img)
    logging.info(
        'Running for %d steps, that means %f epochs and %d steps per epoch',
        total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)

    write_note('Initializing model...')
    model_dict = vit_utils.initialize_model('batchensemble', config)
    model = model_dict['model']
    ens_size = model_dict['ens_size']

    # We want all parameters to be created in host RAM, not on any device;
    # they'll be sent to devices later as needed. Otherwise we risk allocating
    # them twice (we have already hit two such situations).
    @functools.partial(jax.jit, backend='cpu')
    def init(rng):
        image_size = tuple(train_ds.element_spec['image'].shape[2:])
        logging.info('image_size = %s', image_size)
        dummy_input = jnp.zeros((local_batch_size, ) + image_size, jnp.float32)
        params = flax.core.unfreeze(model.init(rng, dummy_input,
                                               train=False))['params']

        # Set bias in the head to a low value, such that loss is small initially.
        params['batchensemble_head']['bias'] = jnp.full_like(
            params['batchensemble_head']['bias'],
            config.get('init_head_bias', 0))

        # init head kernel to all zeros for fine-tuning
        if config.get('model_init'):
            params['batchensemble_head']['kernel'] = jnp.full_like(
                params['batchensemble_head']['kernel'], 0)

        return params

    rng, rng_init = jax.random.split(rng)
    params_cpu = init(rng_init)

    if jax.process_index() == 0:
        num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
        parameter_overview.log_parameter_overview(params_cpu)
        writer.write_scalars(step=0, scalars={'num_params': num_params})

    @functools.partial(jax.pmap, axis_name='batch')
    def evaluation_fn(params, images, labels):
        tiled_logits, out = model.apply({'params': flax.core.freeze(params)},
                                        images,
                                        train=False)

        loss_name = config.get('loss', 'sigmoid_xent')
        # TODO(dusenberrymw,zmariet): Clean up and generalize this.
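        # The model emits logits for all ens_size ensemble members tiled along
        # the batch axis; split them back out and average the members in
        # probability space to form the ensemble prediction.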
        if loss_name == 'sigmoid_xent':
            ens_logits = batchensemble_utils.log_average_sigmoid_probs(
                jnp.asarray(jnp.split(tiled_logits, ens_size)))
            pre_logits = batchensemble_utils.log_average_sigmoid_probs(
                jnp.asarray(jnp.split(out['pre_logits'], ens_size)))
        else:  # softmax
            ens_logits = batchensemble_utils.log_average_softmax_probs(
                jnp.asarray(jnp.split(tiled_logits, ens_size)))
            pre_logits = batchensemble_utils.log_average_softmax_probs(
                jnp.asarray(jnp.split(out['pre_logits'], ens_size)))

        losses = getattr(train_utils,
                         loss_name)(logits=ens_logits,
                                    labels=labels[:, :config.num_classes],
                                    reduction=False)
        loss = jax.lax.psum(losses, axis_name='batch')

        top1_idx = jnp.argmax(ens_logits, axis=1)
        top1_correct = jnp.take_along_axis(labels, top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
        n = batch_size_eval

        metric_args = jax.lax.all_gather([ens_logits, labels, pre_logits],
                                         axis_name='batch')
        return ncorrect, loss, n, metric_args

    # Load the optimizer from flax.
    opt_name = config.get('optim_name')
    write_note(f'Initializing {opt_name} optimizer...')
    opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

    # We jit this so that the created arrays live on the same device as the
    # input, in this case the CPU. Otherwise they'd end up on device[0].
    opt_cpu = jax.jit(opt_def.create)(params_cpu)

    weight_decay_rules = config.get('weight_decay', []) or []
    rescale_value = config.lr.base if config.get(
        'weight_decay_decouple') else 1.
    weight_decay_fn = train_utils.get_weight_decay_fn(
        weight_decay_rules=weight_decay_rules, rescale_value=rescale_value)

    def batch_loss_fn(params, images, labels, rngs):
        logits, _ = model.apply({'params': flax.core.freeze(params)},
                                images,
                                train=True,
                                rngs=rngs)
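        # BatchEnsemble produces ens_size predictions per example stacked
        # along the batch axis, so tile the labels to match.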
        labels = jnp.tile(labels, (ens_size, 1))
        loss_fn = getattr(train_utils, config.get('loss', 'sigmoid_xent'))
        loss = jnp.mean(loss_fn(logits=logits, labels=labels))
        return loss, dict()

    @functools.partial(jax.pmap, axis_name='batch', donate_argnums=(0, 1))
    def update_fn(opt, rngs, lr, images, labels):
        return batchensemble_utils.update_fn_be(
            opt=opt,
            rngs=rngs,
            lr=lr,
            images=images,
            labels=labels,
            batch_loss_fn=batch_loss_fn,
            weight_decay_fn=weight_decay_fn,
            max_grad_norm_global=config.get('grad_clip_norm', None),
            fast_weight_lr_multiplier=config.get('fast_weight_lr_multiplier',
                                                 None))

    # Set config checkpoint resume path, if provided in args.
    if config.resume_checkpoint_path is not None:
        config.resume = config.resume_checkpoint_path

    reint_params = ('batchensemble_head/bias', 'batchensemble_head/kernel',
                    'batchensemble_head/fast_weight_alpha',
                    'batchensemble_head/fast_weight_gamma')
    if config.get('only_eval', False) or not config.get('reint_head', True):
        reint_params = []
    checkpoint_data = checkpoint_utils.maybe_load_checkpoint(
        train_loop_rngs=rng,
        save_checkpoint_path=save_checkpoint_path,
        init_optimizer=opt_cpu,
        init_params=params_cpu,
        init_fixed_model_states=None,
        default_reinit_params=reint_params,
        config=config)
    train_loop_rngs = {'dropout': checkpoint_data.train_loop_rngs}
    opt_cpu = checkpoint_data.optimizer
    accumulated_train_time = checkpoint_data.accumulated_train_time

    write_note('Adapting the checkpoint model...')
    adapted_params = checkpoint_utils.adapt_upstream_architecture(
        init_params=params_cpu, loaded_params=opt_cpu.target)
    opt_cpu = opt_cpu.replace(target=adapted_params)

    write_note('Kicking off misc stuff...')
    first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
    if first_step == 0 and jax.process_index() == 0:
        writer.write_hparams(dict(config))
    chrono = train_utils.Chrono(first_step, total_steps, batch_size,
                                accumulated_train_time)
    # Note: switch to ProfileAllHosts() if you need to profile all hosts.
    # (Xprof data become much larger and take longer to load for analysis)
    profiler = periodic_actions.Profile(
        # Create profile after every restart to analyze pre-emption related
        # problems and assure we get similar performance in every run.
        logdir=output_dir,
        first_profile=first_step + 10)

    # Prepare the learning-rate and pre-fetch it to device to avoid delays.
    lr_fn = train_utils.create_learning_rate_schedule(total_steps,
                                                      **config.get('lr', {}))

    # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be
    # necessary for TPUs.
    lr_iter = train_utils.prefetch_scalar(map(lr_fn, range(total_steps)),
                                          config.get('prefetch_to_device', 1))

    write_note(f'Replicating...\n{chrono.note}')
    opt_repl = flax.jax_utils.replicate(opt_cpu)

    checkpoint_writer = None

    # Note: we return the train loss, val loss, and fewshot best l2s for use in
    # reproducibility unit tests.
    # train_loss = -jnp.inf
    # eval_loss = {
    #     eval_name: -jnp.inf for eval_name, _ in eval_iter_splits.items()}
    # fewshot_results = {'dummy': {(0, 1): -jnp.inf}}

    write_note(f'First step compilations...\n{chrono.note}')
    logging.info('first_step = %s', first_step)
    # Advance the iterators if we are restarting from an earlier checkpoint.
    # TODO(dusenberrymw): Look into checkpointing dataset state instead.
    if first_step > 0:
        write_note('Advancing iterators after resuming from a checkpoint...')
        lr_iter = itertools.islice(lr_iter, first_step, None)
        # TODO(zmariet): Find better way to cut down iteration advancement cost.
        if not config.get('disable_preemption_reproducibility', False):
            train_iter = itertools.islice(train_iter, first_step, None)

    # Using a python integer for step here, because opt.state.step is allocated
    # on TPU during replication.
    for step, train_batch, lr_repl in zip(
            range(first_step + 1, total_steps + 1), train_iter, lr_iter):

        with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1):
            if not config.get('only_eval', False):
                opt_repl, train_loop_rngs, extra_measurements = update_fn(
                    opt_repl, train_loop_rngs, lr_repl, train_batch['image'],
                    train_batch['labels'])

        if jax.process_index() == 0:
            profiler(step)

        # Checkpoint saving
        if not config.get('only_eval', False) and train_utils.itstime(
                step, config.get('checkpoint_steps'), total_steps, process=0):
            write_note('Checkpointing...')
            chrono.pause()
            train_utils.checkpointing_timeout(
                checkpoint_writer, config.get('checkpoint_timeout', 1))
            accumulated_train_time = chrono.accum_train_time
            # We need to transfer the weights over now or else we risk keeping them
            # alive while they'll be updated in a future step, creating hard to debug
            # memory errors (see b/160593526). Also, takes device 0's params only.
            opt_cpu = jax.tree_util.tree_map(lambda x: np.array(x[0]),
                                             opt_repl)

            # Check whether we want to keep a copy of the current checkpoint.
            copy_step = None
            if train_utils.itstime(step, config.get('keep_checkpoint_steps'),
                                   total_steps):
                write_note('Keeping a checkpoint copy...')
                copy_step = step

            # Checkpoint should be a nested dictionary or Flax dataclasses from
            # `flax.struct`. Both can be present in a checkpoint.
            checkpoint_data = checkpoint_utils.CheckpointData(
                train_loop_rngs=train_loop_rngs,
                optimizer=opt_cpu,
                accumulated_train_time=accumulated_train_time)

            checkpoint_writer = pool.apply_async(
                checkpoint_utils.checkpoint_trained_model,
                (checkpoint_data, save_checkpoint_path, copy_step))
            chrono.resume()

        # Report training progress
        if not config.get('only_eval', False) and train_utils.itstime(
                step, config.log_training_steps, total_steps, process=0):
            write_note('Reporting training progress...')
            timing_measurements, note = chrono.tick(step)
            write_note(note)
            train_measurements = {}
            train_measurements.update(
                flax.jax_utils.unreplicate(extra_measurements))
            train_measurements.update(timing_measurements)
            writer.write_scalars(step, train_measurements)
            # Keep to return for reproducibility tests.
            # train_loss = train_measurements['training_loss']

        # Report validation performance
        if config.get('only_eval', False) or train_utils.itstime(
                step, config.log_eval_steps, total_steps):
            write_note('Evaluating on the validation sets...')
            chrono.pause()

            all_eval_results = {}

            for eval_name, (eval_iter, eval_steps) in eval_iter_splits.items():
                start_time = time.time()

                # Runs evaluation loop.
                results_arrs = {
                    'y_true': [],
                    'y_pred': [],
                    'y_pred_entropy': []
                }

                for _, batch in zip(range(eval_steps), eval_iter):
                    batch_ncorrect, batch_losses, batch_n, batch_metric_args = (  # pylint: disable=unused-variable
                        evaluation_fn(opt_repl.target, batch['image'],
                                      batch['labels']))

                    # All results are a replicated array shaped as follows:
                    # (local_devices, per_device_batch_size, elem_shape...)
                    # with each local device's entry being identical as they got psum'd.
                    # So let's just take the first one to the host as numpy.

                    # Here we parse batch_metric_args to compute uncertainty metrics.
                    logits, labels, _ = batch_metric_args
                    logits = np.array(logits[0])
                    probs = jax.nn.softmax(logits)

                    # From one-hot to integer labels.
                    int_labels = np.argmax(np.array(labels[0]), axis=-1)

                    probs = np.reshape(probs,
                                       (probs.shape[0] * probs.shape[1], -1))
                    int_labels = int_labels.flatten()
                    y_pred = probs[:, 1]
                    results_arrs['y_true'].append(int_labels)
                    results_arrs['y_pred'].append(y_pred)

                    # Entropy is computed at the per-epoch level (see below).
                    results_arrs['y_pred_entropy'].append(probs)

                results_arrs['y_true'] = np.concatenate(results_arrs['y_true'],
                                                        axis=0)
                results_arrs['y_pred'] = np.concatenate(
                    results_arrs['y_pred'], axis=0).astype('float64')
                results_arrs['y_pred_entropy'] = vit_utils.entropy(
                    np.concatenate(results_arrs['y_pred_entropy'], axis=0),
                    axis=-1)

                time_elapsed = time.time() - start_time
                results_arrs['total_ms_elapsed'] = time_elapsed * 1e3
                results_arrs['dataset_size'] = eval_steps * batch_size_eval

                all_eval_results[eval_name] = results_arrs

            per_pred_results, metrics_results = vit_utils.evaluate_vit_predictions(  # pylint: disable=unused-variable
                dataset_split_to_containers=all_eval_results,
                is_deterministic=True,
                num_bins=15,
                return_per_pred_results=True)

            # `metrics_results` is a dict of {str: jnp.ndarray} dicts, one for each
            # dataset. Flatten this dict so we can pass to the writer and remove empty
            # entries.
            flattened_metric_results = {}
            for dic in metrics_results.values():
                for key, value in dic.items():
                    if value is not None:
                        flattened_metric_results[key] = value
            writer.write_scalars(step, flattened_metric_results)

            # Optionally log to wandb
            if config.use_wandb:
                wandb.log(metrics_results, step=step)

            # Save per-prediction metrics
            results_storage_utils.save_per_prediction_results(output_dir,
                                                              step,
                                                              per_pred_results,
                                                              verbose=False)
            chrono.resume()

        # End of step.
        if config.get('testing_failure_step'):
            # Break early to simulate infra failures in test cases.
            if config.testing_failure_step == step:
                break

        if config.get('only_eval', False):
            break

    write_note(f'Done!\n{chrono.note}')
    pool.close()
    pool.join()
    writer.close()
Example #15
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
    tf.io.gfile.makedirs(workdir)

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if config.batch_size % n_devices:
        raise ValueError(
            "Batch size must be divisible by the number of devices")

    vocab_path = config.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(workdir, "sentencepiece_model")
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info("Initializing dataset.")
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        n_devices=n_devices,
        dataset_name=config.dataset_name,
        eval_dataset_name=config.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        vocab_path=vocab_path,
        target_vocab_size=config.vocab_size,
        batch_size=config.batch_size,
        max_corpus_chars=config.max_corpus_chars,
        max_length=config.max_target_length,
        max_eval_length=config.max_eval_target_length)
    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
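        # np.argmax finds the first EOS token; keep everything up to and
        # including it (if no EOS is present this keeps only the first token).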
        valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode("utf-8")

    if config.num_predict_steps > 0:
        predict_ds = predict_ds.take(config.num_predict_steps)

    logging.info("Initializing model, optimizer, and step functions.")

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        share_embeddings=config.share_embeddings,
        logits_via_embedding=config.logits_via_embedding,
        dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
        emb_dim=config.emb_dim,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        qkv_dim=config.qkv_dim,
        mlp_dim=config.mlp_dim,
        max_len=max(config.max_target_length, config.max_eval_target_length),
        dropout_rate=config.dropout_rate,
        attention_dropout_rate=config.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = random.PRNGKey(config.seed)
    rng, init_rng = random.split(rng)
    input_shape = (config.batch_size, config.max_target_length)
    target_shape = (config.batch_size, config.max_target_length)

    m = models.Transformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32),
                                        jnp.ones(target_shape, jnp.float32))

    # apply an optimizer to this tree
    optimizer_def = optim.Adam(config.learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=config.weight_decay)
    optimizer = optimizer_def.create(initial_variables["params"])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if config.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.host_id() > 0)
    if start_step == 0:
        writer.write_hparams(dict(config))

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=config.learning_rate,
        warmup_steps=config.warmup_steps)

    # compile multidevice versions of train/eval/predict step and cache init fn.
    p_train_step = jax.pmap(functools.partial(
        train_step,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        label_smoothing=config.label_smoothing),
                            axis_name="batch",
                            donate_argnums=(0, ))  # pytype: disable=wrong-arg-types
    p_eval_step = jax.pmap(functools.partial(
        eval_step, config=eval_config, label_smoothing=config.label_smoothing),
                           axis_name="batch")
    p_init_cache = jax.pmap(functools.partial(
        initialize_cache,
        max_decode_len=config.max_predict_length,
        config=predict_config),
                            axis_name="batch")
    p_pred_step = jax.pmap(
        functools.partial(predict_step,
                          config=predict_config,
                          beam_size=config.beam_size),
        axis_name="batch",
        static_broadcasted_argnums=(3,
                                    4))  # eos token, max_length are constant

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, n_devices)

    logging.info("Starting training loop.")
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    if jax.host_id() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(num_profile_steps=5)
        ]
    metrics_all = []
    with metric_writers.ensure_flushes(writer):
        for step, batch in zip(range(start_step, config.num_train_steps),
                               train_iter):
            # Shard data to devices and do a training step.
            batch = common_utils.shard(
                jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
            optimizer, metrics, dropout_rngs = p_train_step(
                optimizer, batch, dropout_rng=dropout_rngs)
            metrics_all.append(metrics)

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            # Save a checkpoint on one host after every checkpoint_freq steps.
            if (config.save_checkpoints and step % config.checkpoint_freq == 0
                    and step > 0 and jax.host_id() == 0):
                checkpoints.save_checkpoint(workdir,
                                            jax_utils.unreplicate(optimizer),
                                            step)

            # Periodic metric handling.
            if step % config.eval_frequency != 0 and step > 0:
                continue

            # Training Metrics
            logging.info("Gathering training metrics.")
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop("learning_rate").mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop("denominator")
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary["learning_rate"] = lr
            summary = {"train_" + k: v for k, v in summary.items()}
            writer.write_scalars(step, summary)
            metrics_all = []

            # Eval Metrics
            logging.info("Gathering evaluation metrics.")
            eval_metrics = []
            eval_iter = iter(eval_ds)
            for _, eval_batch in zip(range(config.num_eval_steps), eval_iter):
                eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
                eval_batch = common_utils.shard(eval_batch)
                metrics = p_eval_step(optimizer.target, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
            eval_denominator = eval_metrics_sums.pop("denominator")
            eval_summary = jax.tree_map(
                lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
                eval_metrics_sums)
            eval_summary = {"eval_" + k: v for k, v in eval_summary.items()}
            writer.write_scalars(step, eval_summary)

            # Translation and BLEU Score.
            logging.info("Translating evaluation dataset.")
            t_inference_start = time.time()
            sources, references, predictions = [], [], []
            for pred_batch in predict_ds:
                pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
                # Handle final odd-sized batch by padding instead of dropping it.
                cur_pred_batch_size = pred_batch["inputs"].shape[0]
                if cur_pred_batch_size % n_devices:
                    padded_size = int(
                        np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                    pred_batch = jax.tree_map(
                        lambda x: pad_examples(x, padded_size),  # pylint: disable=cell-var-from-loop
                        pred_batch)
                pred_batch = common_utils.shard(pred_batch)
                cache = p_init_cache(pred_batch["inputs"])
                predicted = p_pred_step(pred_batch["inputs"], optimizer.target,
                                        cache, eos_id,
                                        config.max_predict_length)
                predicted = tohost(predicted)
                inputs = tohost(pred_batch["inputs"])
                targets = tohost(pred_batch["targets"])
                # Iterate through non-padding examples of batch.
                for i, s in enumerate(predicted[:cur_pred_batch_size]):
                    sources.append(decode_tokens(inputs[i]))
                    references.append(decode_tokens(targets[i]))
                    predictions.append(decode_tokens(s))
            logging.info(
                "Translation: %d predictions %d references %d sources.",
                len(predictions), len(references), len(sources))
            logging.info("Translation time: %.4f s step %d.",
                         time.time() - t_inference_start, step)

            # Calculate BLEU score for translated eval corpus against reference.
            bleu_matches = bleu.bleu_partial(references, predictions)
            all_bleu_matches = per_host_sum_pmap(bleu_matches)
            bleu_score = bleu.complete_bleu(*all_bleu_matches)
            # Save translation samples for tensorboard.
            exemplars = ""
            for n in np.random.choice(np.arange(len(predictions)), 8):
                exemplars += f"{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n"
            writer.write_scalars(step, {"bleu": bleu_score})
            writer.write_texts(step, {"samples": exemplars})
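# The loop above calls a few helpers that are not shown in this excerpt
# (`tohost`, `pad_examples`, `per_host_sum_pmap`, `decode_tokens`). Minimal
# module-level sketches of the first two, assuming 2-D [batch, length] token
# arrays and behavior similar to the Flax WMT example (names and exact
# signatures are assumptions, not taken from this listing):

def tohost(x):
    """Collects a replicated [n_device, n_batch, ...] array onto the host."""
    n_device, n_batch, *remaining_dims = x.shape
    return np.array(x).reshape((n_device * n_batch,) + tuple(remaining_dims))


def pad_examples(x, desired_batch_size):
    """Pads a batch up to `desired_batch_size` by repeating the last example."""
    batch_pad = desired_batch_size - x.shape[0]
    return np.concatenate([x, np.tile(x[-1], (batch_pad, 1))], axis=0)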
Example #16
def main(argv):
    del argv

    config = FLAGS.config
    output_dir = FLAGS.output_dir

    seed = config.get('seed', 0)
    rng = jax.random.PRNGKey(seed)
    tf.random.set_seed(seed)

    if config.get('dataset_dir'):
        logging.info('data_dir=%s', config.dataset_dir)
    logging.info('Output dir: %s', output_dir)

    save_checkpoint_path = None
    if config.get('checkpoint_steps'):
        gfile.makedirs(output_dir)
        save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

    # Create an asynchronous multi-metric writer.
    writer = metric_writers.create_default_writer(
        output_dir, just_logging=jax.process_index() > 0)

    # The pool is used to perform misc operations such as logging in async way.
    pool = multiprocessing.pool.ThreadPool()

    def write_note(note):
        if jax.process_index() == 0:
            logging.info('NOTE: %s', note)

    write_note('Initializing...')

    fillin = lambda *_: None
    # Verify settings to make sure no checkpoints are accidentally missed.
    if config.get('keep_checkpoint_steps'):
        assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
        assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
            f'`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should '
            f'be divisible by `checkpoint_steps` ({config.checkpoint_steps}).')

    batch_size = config.batch_size
    batch_size_eval = config.get('batch_size_eval', batch_size)
    if (batch_size % jax.device_count() != 0
            or batch_size_eval % jax.device_count() != 0):
        raise ValueError(
            f'Batch sizes ({batch_size} and {batch_size_eval}) must '
            f'be divisible by device number ({jax.device_count()})')

    local_batch_size = batch_size // jax.process_count()
    local_batch_size_eval = batch_size_eval // jax.process_count()
    logging.info(
        'Global batch size %d on %d hosts results in %d local batch size. '
        'With %d dev per host (%d dev total), that is a %d per-device batch size.',
        batch_size, jax.process_count(), local_batch_size,
        jax.local_device_count(), jax.device_count(),
        local_batch_size // jax.local_device_count())
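    # Illustrative arithmetic (hypothetical numbers, not from the config): a
    # global batch_size of 512 on 2 hosts with 8 devices each gives
    # local_batch_size = 256 and a per-device batch size of 256 // 8 = 32.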

    write_note('Initializing train dataset...')
    rng, train_ds_rng = jax.random.split(rng)
    train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index())
    train_ds = input_utils.get_data(
        dataset=config.dataset,
        split=config.train_split,
        rng=train_ds_rng,
        process_batch_size=local_batch_size,
        preprocess_fn=preprocess_spec.parse(
            spec=config.pp_train, available_ops=preprocess_utils.all_ops()),
        shuffle_buffer_size=config.shuffle_buffer_size,
        prefetch_size=config.get('prefetch_to_host', 2),
        data_dir=fillin(config.get('data_dir')))
    logging.info('image_size = %s', train_ds.element_spec['image'].shape[1:])

    # Start prefetching already.
    train_iter = input_utils.start_input_pipeline(
        train_ds, config.get('prefetch_to_device', 1))

    write_note('Initializing val dataset(s)...')

    def _get_val_split(dataset, split, pp_eval, data_dir=None):
        # We do ceil rounding such that we include the last incomplete batch.
        nval_img = input_utils.get_num_examples(
            dataset,
            split=split,
            process_batch_size=local_batch_size_eval,
            drop_remainder=False,
            data_dir=fillin(data_dir))
        val_steps = int(np.ceil(nval_img / batch_size_eval))
        logging.info('Running validation for %d steps for %s, %s', val_steps,
                     dataset, split)

        if isinstance(pp_eval, str):
            pp_eval = preprocess_spec.parse(
                spec=pp_eval, available_ops=preprocess_utils.all_ops())

        val_ds = input_utils.get_data(dataset=dataset,
                                      split=split,
                                      rng=None,
                                      process_batch_size=local_batch_size_eval,
                                      preprocess_fn=pp_eval,
                                      cache=config.get('val_cache', 'batched'),
                                      num_epochs=1,
                                      repeat_after_batching=True,
                                      shuffle=False,
                                      prefetch_size=config.get(
                                          'prefetch_to_host', 2),
                                      drop_remainder=False,
                                      data_dir=fillin(data_dir))

        return val_ds

    val_ds_splits = {
        'val':
        _get_val_split(config.dataset, config.val_split, config.pp_eval,
                       config.get('dataset_dir'))
    }

    if config.get('test_split'):
        val_ds_splits.update({
            'test':
            _get_val_split(config.dataset,
                           split=config.test_split,
                           pp_eval=config.pp_eval,
                           data_dir=config.get('data_dir'))
        })

    if config.get('eval_on_cifar_10h'):
        cifar10_to_cifar10h_fn = data_uncertainty_utils.create_cifar10_to_cifar10h_fn(
            config.get('data_dir', None))
        preprocess_fn = preprocess_spec.parse(
            spec=config.pp_eval_cifar_10h,
            available_ops=preprocess_utils.all_ops())
        pp_eval = lambda ex: preprocess_fn(cifar10_to_cifar10h_fn(ex))
        val_ds_splits['cifar_10h'] = _get_val_split(
            'cifar10',
            split=config.get('cifar_10h_split') or 'test',
            pp_eval=pp_eval,
            data_dir=config.get('data_dir'))
    elif config.get('eval_on_imagenet_real'):
        imagenet_to_real_fn = data_uncertainty_utils.create_imagenet_to_real_fn(
        )
        preprocess_fn = preprocess_spec.parse(
            spec=config.pp_eval_imagenet_real,
            available_ops=preprocess_utils.all_ops())
        pp_eval = lambda ex: preprocess_fn(imagenet_to_real_fn(ex))
        val_ds_splits['imagenet_real'] = _get_val_split(
            'imagenet2012_real',
            split=config.get('imagenet_real_split') or 'validation',
            pp_eval=pp_eval,
            data_dir=config.get('data_dir'))

    ood_ds = {}
    if config.get('ood_datasets') and config.get('ood_methods'):
        if config.get(
                'ood_methods'):  # config.ood_methods is not an empty list.
            logging.info('loading OOD dataset = %s',
                         config.get('ood_datasets'))
            ood_ds, ood_ds_names = ood_utils.load_ood_datasets(
                config.dataset,
                config.ood_datasets,
                config.ood_split,
                config.pp_eval,
                config.pp_eval_ood,
                config.ood_methods,
                config.train_split,
                config.get('data_dir'),
                _get_val_split,
            )

    ntrain_img = input_utils.get_num_examples(
        config.dataset,
        split=config.train_split,
        process_batch_size=local_batch_size,
        data_dir=fillin(config.get('data_dir')))
    steps_per_epoch = ntrain_img / batch_size

    if config.get('num_epochs'):
        total_steps = int(config.num_epochs * steps_per_epoch)
        assert not config.get(
            'total_steps'), 'Set either num_epochs or total_steps'
    else:
        total_steps = config.total_steps

    logging.info(
        'Running for %d steps, that means %f epochs and %f steps per epoch',
        total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)

    write_note('Initializing model...')
    logging.info('config.model = %s', config.get('model'))

    # Specify Gaussian process layer configs.
    use_gp_layer = True
    gp_config = config.get('gp_layer', {})
    gp_layer_kwargs = get_gp_kwargs(gp_config)

    # Process ViT backbone model configs.
    vit_kwargs = config.get('model')

    het_kwargs = config.get('het')

    model = ub.models.vision_transformer_hetgp(
        num_classes=config.num_classes,
        use_gp_layer=use_gp_layer,
        vit_kwargs=vit_kwargs,
        gp_layer_kwargs=gp_layer_kwargs,
        multiclass=het_kwargs.multiclass,
        temperature=het_kwargs.temperature,
        mc_samples=het_kwargs.mc_samples,
        num_factors=het_kwargs.num_factors,
        param_efficient=het_kwargs.param_efficient)

    # We want all parameters to be created in host RAM, not on any device;
    # they'll be sent to the devices later as needed. Otherwise we risk
    # allocating them twice (we have run into that situation before).
    @partial(jax.jit, backend='cpu')
    def init(rng):
        image_size = tuple(train_ds.element_spec['image'].shape[2:])
        logging.info('image_size = %s', image_size)
        dummy_input = jnp.zeros((local_batch_size, ) + image_size, jnp.float32)

        rng, diag_noise_rng, standard_noise_rng = jax.random.split(rng, num=3)
        init_rngs = {
            'params': rng,
            'diag_noise_samples': diag_noise_rng,
            'standard_norm_noise_samples': standard_noise_rng
        }
        variables = model.init(init_rngs, dummy_input, train=False)
        # Split model parameters into trainable and untrainable collections.
        states, params = variables.pop('params')
        del variables

        # Set bias in the head to a low value, such that loss is small initially.
        params = flax.core.unfreeze(params)
        if use_gp_layer:
            # Modify the head parameter in the GP head.
            params['head']['loc_layer']['output_layer'][
                'bias'] = jnp.full_like(
                    params['head']['loc_layer']['output_layer']['bias'],
                    config.get('init_head_bias', 0))
        else:
            params['vit_backbone']['head']['bias'] = jnp.full_like(
                params['vit_backbone']['head']['bias'],
                config.get('init_head_bias', 0))

        return params, states

    (rng, rng_init, rng_dropout, diag_noise_rng,
     standard_noise_rng) = jax.random.split(rng, num=5)
    params_cpu, states_cpu = init(rng_init)

    if jax.process_index() == 0:
        num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
        parameter_overview.log_parameter_overview(params_cpu)
        writer.write_scalars(step=0, scalars={'num_params': num_params})

    @partial(jax.pmap, axis_name='batch')
    def evaluation_fn(params, states, images, labels, mask):
        # Ignore the entries with all zero labels for evaluation.
        mask *= labels.max(axis=1)
        variable_dict = {'params': flax.core.freeze(params), **states}
        logits, out = model.apply(variable_dict,
                                  images,
                                  train=False,
                                  rngs={
                                      'dropout':
                                      rng_dropout,
                                      'diag_noise_samples':
                                      diag_noise_rng,
                                      'standard_norm_noise_samples':
                                      standard_noise_rng
                                  })

        # Note that logits and labels are usually of shape [batch, num_classes].
        # But for OOD data, when num_classes_ood > num_classes_ind, we need to
        # adjust labels to labels[:, :config.num_classes] to match the shape of
        # logits. This only avoids a shape mismatch; the resulting losses have
        # no meaning for OOD data, since OOD examples belong to no IND class.
        losses = getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
            logits=logits,
            labels=labels[:, :config.num_classes],
            reduction=False)
        loss = jax.lax.psum(losses * mask, axis_name='batch')

        top1_idx = jnp.argmax(logits, axis=1)
        # Extracts the label at the highest logit index for each image.
        top1_correct = jnp.take_along_axis(labels, top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch')
        n = jax.lax.psum(mask, axis_name='batch')

        metric_args = jax.lax.all_gather(
            [logits, labels, out['pre_logits'], mask], axis_name='batch')
        return ncorrect, loss, n, metric_args

    @partial(jax.pmap, axis_name='batch')
    def cifar_10h_evaluation_fn(params, states, images, labels, mask):
        variable_dict = {'params': flax.core.freeze(params), **states}
        logits, out = model.apply(variable_dict,
                                  images,
                                  train=False,
                                  rngs={
                                      'dropout':
                                      rng_dropout,
                                      'diag_noise_samples':
                                      diag_noise_rng,
                                      'standard_norm_noise_samples':
                                      standard_noise_rng
                                  })

        losses = getattr(train_utils,
                         config.get('loss', 'softmax_xent'))(logits=logits,
                                                             labels=labels,
                                                             reduction=False)
        loss = jax.lax.psum(losses, axis_name='batch')

        top1_idx = jnp.argmax(logits, axis=1)
        # Extracts the label at the highest logit index for each image.
        one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)]

        top1_correct = jnp.take_along_axis(one_hot_labels,
                                           top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
        n = jax.lax.psum(one_hot_labels, axis_name='batch')

        metric_args = jax.lax.all_gather(
            [logits, labels, out['pre_logits'], mask], axis_name='batch')
        return ncorrect, loss, n, metric_args

    # Setup function for computing representation.
    @partial(jax.pmap, axis_name='batch')
    def representation_fn(params, images, labels, mask, states):
        variable_dict = {'params': flax.core.freeze(params), **states}
        _, outputs = model.apply(variable_dict,
                                 images,
                                 train=False,
                                 rngs={
                                     'dropout':
                                     rng_dropout,
                                     'diag_noise_samples':
                                     diag_noise_rng,
                                     'standard_norm_noise_samples':
                                     standard_noise_rng
                                 })
        representation = outputs[config.fewshot.representation_layer]
        representation = jax.lax.all_gather(representation, 'batch')
        labels = jax.lax.all_gather(labels, 'batch')
        mask = jax.lax.all_gather(mask, 'batch')
        return representation, labels, mask

    # Load the optimizer from flax.
    opt_name = config.get('optim_name')
    write_note(f'Initializing {opt_name} optimizer...')
    opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

    # We jit this, such that the arrays that are created are created on the same
    # device as the input is, in this case the CPU. Else they'd be on device[0].
    opt_cpu = jax.jit(opt_def.create)(params_cpu)

    weight_decay_rules = config.get('weight_decay', []) or []
    rescale_value = config.lr.base if config.get(
        'weight_decay_decouple') else 1.
    weight_decay_fn = train_utils.get_weight_decay_fn(
        weight_decay_rules=weight_decay_rules, rescale_value=rescale_value)

    @partial(jax.pmap, axis_name='batch', donate_argnums=(0, ))
    def update_fn(opt, states, lr, reset_covmat, images, labels, rng):
        """Update step."""
        measurements = {}

        # Get device-specific loss rng.
        rng, rng_model = jax.random.split(rng, 2)
        rng_model_local = jax.random.fold_in(rng_model,
                                             jax.lax.axis_index('batch'))
        rng_model_local, diag_noise_rng, standard_noise_rng = jax.random.split(
            rng_model_local, num=3)

        def loss_fn(params, states, images, labels):
            # Specify mutable collection to update untrainable GP parameters.
            variable_dict = {'params': flax.core.freeze(params), **states}
            model_results, updated_states = model.apply(
                variable_dict,
                images,
                train=True,
                rngs={
                    'dropout': rng_model_local,
                    'diag_noise_samples': diag_noise_rng,
                    'standard_norm_noise_samples': standard_noise_rng
                },
                mutable=list(states.keys()))

            logits, _ = model_results
            loss = getattr(train_utils,
                           config.get('loss', 'sigmoid_xent'))(logits=logits,
                                                               labels=labels)
            return loss, updated_states

        # Performs an exact covariance update (i.e., resets the precision
        # matrix at the beginning of each new epoch) if covmat_momentum is a
        # null (negative) value.
        if gp_config.get('covmat_momentum', -1.) < 0:
            # Resets the precision matrix to identity * ridge_penalty at the
            # beginning of a new epoch. This should be done before gradient
            # accumulation.
            ridge_penalty = gp_config.get('ridge_penalty', 1.)
            prec_mat_old = states['laplace_covariance']['head'][
                'covmat_layer']['precision_matrix']
            prec_mat_new = (
                (1. - reset_covmat) * prec_mat_old +
                reset_covmat * jnp.eye(prec_mat_old.shape[0]) * ridge_penalty)
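            # With reset_covmat == 1.0 this resets the precision matrix to
            # ridge_penalty * I; with reset_covmat == 0.0 it is left unchanged.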

            states = flax.core.unfreeze(states)
            states['laplace_covariance']['head']['covmat_layer'][
                'precision_matrix'] = prec_mat_new
            states = flax.core.freeze(states)

        # Implementation considerations compared and summarized at
        # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
        (l, s), g = train_utils.accumulate_gradient_with_states(
            jax.value_and_grad(loss_fn, has_aux=True), opt.target, states,
            images, labels, config.get('grad_accum_steps'))
        l, g = jax.lax.pmean((l, g), axis_name='batch')

        # Log the gradient norm only if we need to compute it anyway (for
        # clipping), or if we don't use grad_accum_steps, as they interact badly.
        do_grad_clip = config.get('grad_clip_norm', -1.) > 0.
        if config.get('grad_accum_steps', 1) == 1 or do_grad_clip:
            grads, _ = jax.tree_flatten(g)
            l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
            measurements['l2_grads'] = l2_g

        # Optionally resize the global gradient to a maximum norm. We found this
        # useful in some cases across optimizers, hence it's in the main loop.
        if do_grad_clip:
            g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
            g = jax.tree_map(lambda p: g_factor * p, g)
        opt = opt.apply_gradient(g, learning_rate=lr)
        opt = opt.replace(target=weight_decay_fn(opt.target, lr))

        params, _ = jax.tree_flatten(opt.target)
        measurements['l2_params'] = jnp.sqrt(
            sum([jnp.vdot(p, p) for p in params]))

        return opt, s, l, rng, measurements

    default_reinit_params = ('head/output_layer/kernel',
                             'head/output_layer/bias', 'head/kernel',
                             'head/bias')
    rng, train_loop_rngs = jax.random.split(rng)
    checkpoint_data = checkpoint_utils.maybe_load_checkpoint(
        train_loop_rngs=train_loop_rngs,
        save_checkpoint_path=save_checkpoint_path,
        init_optimizer=opt_cpu,
        init_params=params_cpu,
        init_fixed_model_states=states_cpu,
        default_reinit_params=default_reinit_params,
        config=config)
    train_loop_rngs = checkpoint_data.train_loop_rngs
    opt_cpu = checkpoint_data.optimizer
    states_cpu = checkpoint_data.fixed_model_states
    accumulated_train_time = checkpoint_data.accumulated_train_time

    write_note('Adapting the checkpoint model...')
    adapted_params = checkpoint_utils.adapt_upstream_architecture(
        init_params=params_cpu, loaded_params=opt_cpu.target)
    opt_cpu = opt_cpu.replace(target=adapted_params)

    write_note('Kicking off misc stuff...')
    first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
    if first_step == 0 and jax.process_index() == 0:
        writer.write_hparams(dict(config))
    chrono = train_utils.Chrono(first_step, total_steps, batch_size,
                                accumulated_train_time)
    # Note: switch to ProfileAllHosts() if you need to profile all hosts.
    # (Xprof data become much larger and take longer to load for analysis)
    profiler = periodic_actions.Profile(
        # Create a profile after every restart to analyze pre-emption-related
        # problems and to ensure we get similar performance in every run.
        logdir=output_dir,
        first_profile=first_step + 10)

    # Prepare the learning-rate schedule and pre-fetch it to device to avoid delays.
    lr_fn = train_utils.create_learning_rate_schedule(total_steps,
                                                      **config.get('lr', {}))
    # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be
    # necessary for TPUs.
    lr_iter = train_utils.prefetch_scalar(map(lr_fn, range(total_steps)),
                                          config.get('prefetch_to_device', 1))

    # Prepare the precision matrix resetting schedule, and pre-fetch it to device.
    reset_covmat_fn = lambda step: float(step % steps_per_epoch == 0)
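    # reset_covmat_fn emits 1.0 at the first step of each epoch and 0.0
    # otherwise; update_fn uses this flag to decide whether to reset the GP
    # precision matrix for that step.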
    reset_covmat_iter = train_utils.prefetch_scalar(
        map(reset_covmat_fn, range(first_step, total_steps)),
        nprefetch=config.get('prefetch_to_device', 1))

    write_note(f'Replicating...\n{chrono.note}')
    opt_repl = flax_utils.replicate(opt_cpu)
    states_repl = flax_utils.replicate(states_cpu)

    write_note(f'Initializing few-shotters...\n{chrono.note}')
    if 'fewshot' in config:
        fewshotter = fewshot.FewShotEvaluator(
            representation_fn, config.fewshot,
            config.fewshot.get('batch_size') or batch_size_eval)

    checkpoint_writer = None

    # Note: we return the train loss, val loss, and fewshot best l2s for use in
    # reproducibility unit tests.
    train_loss = -jnp.inf
    val_loss = {val_name: -jnp.inf for val_name, _ in val_ds_splits.items()}
    fewshot_results = {'dummy': {(0, 1): -jnp.inf}}

    write_note(f'First step compilations...\n{chrono.note}')
    logging.info('first_step = %s', first_step)

    # Makes sure log_eval_steps is the same as steps_per_epoch. This is because
    # the precision matrix needs to be fully updated (at the end of each epoch)
    # when eval takes place.
    log_eval_steps = steps_per_epoch

    # Advance the iterators if we are restarting from an earlier checkpoint.
    # TODO(dusenberrymw): Look into checkpointing dataset state instead.
    if first_step > 0:
        write_note('Advancing iterators after resuming from a checkpoint...')
        lr_iter = itertools.islice(lr_iter, first_step, None)
        train_iter = itertools.islice(train_iter, first_step, None)

    # Using a python integer for step here, because opt.state.step is allocated
    # on TPU during replication.
    for step, train_batch, lr_repl, reset_covmat_repl in zip(
            range(first_step + 1, total_steps + 1), train_iter, lr_iter,
            reset_covmat_iter):

        with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1):
            # TODO(jereliu): Expand to allow precision matrix resetting.
            (opt_repl, states_repl, loss_value, train_loop_rngs,
             extra_measurements) = update_fn(opt_repl,
                                             states_repl,
                                             lr_repl,
                                             reset_covmat_repl,
                                             train_batch['image'],
                                             train_batch['labels'],
                                             rng=train_loop_rngs)

        if jax.process_index() == 0:
            profiler(step)

        # Checkpoint saving
        if train_utils.itstime(step,
                               config.get('checkpoint_steps'),
                               total_steps,
                               process=0):
            write_note('Checkpointing...')
            chrono.pause()
            train_utils.checkpointing_timeout(
                checkpoint_writer, config.get('checkpoint_timeout', 1))
            accumulated_train_time = chrono.accum_train_time
            # We need to transfer the weights over now or else we risk keeping
            # them alive while they are updated in a future step, creating
            # hard-to-debug memory errors (see b/160593526). Also, this takes
            # device 0's params only. For the GP layer, we do the same for the
            # untrainable parameters (`states`). This is OK since the random
            # features are frozen throughout pre-training, and the precision
            # matrix is a finetuning-specific parameter that will be re-learned
            # in the finetuning task.
            opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
            states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)

            # Check whether we want to keep a copy of the current checkpoint.
            copy_step = None
            if train_utils.itstime(step, config.get('keep_checkpoint_steps'),
                                   total_steps):
                write_note('Keeping a checkpoint copy...')
                copy_step = step

            # The checkpoint should be a nested dictionary or FLAX dataclasses
            # from `flax.struct`. Both can be present in a checkpoint.
            checkpoint_data = checkpoint_utils.CheckpointData(
                optimizer=opt_cpu,
                fixed_model_states=states_cpu,
                train_loop_rngs=train_loop_rngs,
                accumulated_train_time=accumulated_train_time)
            checkpoint_writer = pool.apply_async(
                checkpoint_utils.checkpoint_trained_model,
                (checkpoint_data, save_checkpoint_path, copy_step))
            chrono.resume()

        # Report training progress
        if train_utils.itstime(step,
                               config.log_training_steps,
                               total_steps,
                               process=0):
            write_note('Reporting training progress...')
            train_loss = loss_value[
                0]  # Keep to return for reproducibility tests.
            timing_measurements, note = chrono.tick(step)
            write_note(note)
            train_measurements = {}
            train_measurements.update({
                'learning_rate': lr_repl[0],
                'training_loss': train_loss,
            })
            train_measurements.update(
                flax.jax_utils.unreplicate(extra_measurements))
            train_measurements.update(timing_measurements)
            writer.write_scalars(step, train_measurements)

        # Report validation performance
        if train_utils.itstime(step, log_eval_steps, total_steps):
            write_note('Evaluating on the validation set...')
            chrono.pause()
            for val_name, val_ds in val_ds_splits.items():
                # Sets up evaluation metrics.
                ece_num_bins = config.get('ece_num_bins', 15)
                auc_num_bins = config.get('auc_num_bins', 1000)
                ece = rm.metrics.ExpectedCalibrationError(
                    num_bins=ece_num_bins)
                calib_auc = rm.metrics.CalibrationAUC(
                    correct_pred_as_pos_label=False)
                # TODO(jereliu): Extend to support soft multi-class probabilities.
                oc_auc_0_5 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.005, num_bins=auc_num_bins)
                oc_auc_1 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.01, num_bins=auc_num_bins)
                oc_auc_2 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.02, num_bins=auc_num_bins)
                oc_auc_5 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.05, num_bins=auc_num_bins)
                label_diversity = tf.keras.metrics.Mean()
                sample_diversity = tf.keras.metrics.Mean()
                ged = tf.keras.metrics.Mean()

                # Runs evaluation loop.
                val_iter = input_utils.start_input_pipeline(
                    val_ds, config.get('prefetch_to_device', 1))
                ncorrect, loss, nseen = 0, 0, 0
                for batch in val_iter:
                    if val_name == 'cifar_10h':
                        batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                            cifar_10h_evaluation_fn(opt_repl.target,
                                                    states_repl,
                                                    batch['image'],
                                                    batch['labels'],
                                                    batch['mask']))
                    else:
                        batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                            evaluation_fn(opt_repl.target, states_repl,
                                          batch['image'], batch['labels'],
                                          batch['mask']))
                    # All results are a replicated array shaped as follows:
                    # (local_devices, per_device_batch_size, elem_shape...)
                    # with each local device's entry being identical as they got psum'd.
                    # So let's just take the first one to the host as numpy.
                    ncorrect += np.sum(np.array(batch_ncorrect[0]))
                    loss += np.sum(np.array(batch_losses[0]))
                    nseen += np.sum(np.array(batch_n[0]))
                    if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
                        # Here we parse batch_metric_args to compute uncertainty metrics.
                        # (e.g., ECE or Calibration AUC).
                        logits, labels, _, masks = batch_metric_args
                        masks = np.array(masks[0], dtype=bool)
                        logits = np.array(logits[0])
                        probs = jax.nn.softmax(logits)
                        # From one-hot to integer labels, as required by ECE.
                        int_labels = np.argmax(np.array(labels[0]), axis=-1)
                        int_preds = np.argmax(logits, axis=-1)
                        confidence = np.max(probs, axis=-1)
                        for p, c, l, d, m, label in zip(
                                probs, confidence, int_labels, int_preds,
                                masks, labels[0]):
                            ece.add_batch(p[m, :], label=l[m])
                            calib_auc.add_batch(d[m],
                                                label=l[m],
                                                confidence=c[m])
                            oc_auc_0_5.add_batch(d[m],
                                                 label=l[m],
                                                 custom_binning_score=c[m])
                            oc_auc_1.add_batch(d[m],
                                               label=l[m],
                                               custom_binning_score=c[m])
                            oc_auc_2.add_batch(d[m],
                                               label=l[m],
                                               custom_binning_score=c[m])
                            oc_auc_5.add_batch(d[m],
                                               label=l[m],
                                               custom_binning_score=c[m])

                            if val_name == 'cifar_10h' or val_name == 'imagenet_real':
                                batch_label_diversity, batch_sample_diversity, batch_ged = data_uncertainty_utils.generalized_energy_distance(
                                    label[m], p[m, :], config.num_classes)
                                label_diversity.update_state(
                                    batch_label_diversity)
                                sample_diversity.update_state(
                                    batch_sample_diversity)
                                ged.update_state(batch_ged)

                val_loss[
                    val_name] = loss / nseen  # Keep for reproducibility tests.
                val_measurements = {
                    f'{val_name}_prec@1': ncorrect / nseen,
                    f'{val_name}_loss': val_loss[val_name],
                }
                if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
                    val_measurements[f'{val_name}_ece'] = ece.result()['ece']
                    val_measurements[
                        f'{val_name}_calib_auc'] = calib_auc.result(
                        )['calibration_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_0.5%'] = oc_auc_0_5.result(
                        )['collaborative_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_1%'] = oc_auc_1.result(
                        )['collaborative_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_2%'] = oc_auc_2.result(
                        )['collaborative_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_5%'] = oc_auc_5.result(
                        )['collaborative_auc']
                writer.write_scalars(step, val_measurements)

                if val_name == 'cifar_10h' or val_name == 'imagenet_real':
                    cifar_10h_measurements = {
                        f'{val_name}_label_diversity':
                        label_diversity.result(),
                        f'{val_name}_sample_diversity':
                        sample_diversity.result(),
                        f'{val_name}_ged': ged.result(),
                    }
                    writer.write_scalars(step, cifar_10h_measurements)

            # OOD eval
            # There are two entries in the ood_ds dict (in-dist, ood), and this
            # section computes metrics using both pieces. This is in contrast
            # to the normal validation eval above, where we compute metrics
            # separately for each val split in val_ds.
            if ood_ds and config.ood_methods:

                def make_sngp_eval_fn(states):
                    def sngp_eval_fn(params, images, labels, mask):
                        return evaluation_fn(params=params,
                                             states=states,
                                             images=images,
                                             labels=labels,
                                             mask=mask)

                    return sngp_eval_fn

                ood_measurements = ood_utils.eval_ood_metrics(
                    ood_ds,
                    ood_ds_names,
                    config.ood_methods,
                    make_sngp_eval_fn(states_repl),
                    opt_repl.target,
                    n_prefetch=config.get('prefetch_to_device', 1))
                writer.write_scalars(step, ood_measurements)

            chrono.resume()

        if 'fewshot' in config:
            # Compute few-shot on-the-fly evaluation.
            if train_utils.itstime(step, config.fewshot.log_steps,
                                   total_steps):
                chrono.pause()
                write_note(f'Few-shot evaluation...\n{chrono.note}')
                # Keep `results` to return for reproducibility tests.
                fewshot_results, best_l2 = fewshotter.run_all(
                    opt_repl.target,
                    datasets=config.fewshot.datasets,
                    states=states_repl)

                # TODO(dusenberrymw): Remove this once fewshot.py is updated.
                def make_writer_measure_fn(step):
                    def writer_measure(name, value):
                        writer.write_scalars(step, {name: value})

                    return writer_measure

                fewshotter.walk_results(make_writer_measure_fn(step),
                                        fewshot_results, best_l2)
                chrono.resume()

        # End of step.
        if config.get('testing_failure_step'):
            # Break early to simulate infra failures in test cases.
            if config.testing_failure_step == step:
                break

    write_note(f'Done!\n{chrono.note}')
    pool.close()
    pool.join()
    writer.close()

    # Return final training loss, validation loss, and fewshot results for
    # reproducibility test cases.
    return train_loss, val_loss, fewshot_results
Example #17
def train_controller(
    controller,
    sim,
    pip_feed="parallel",  # or "sequential"
    mode="multipip",  # or "singular"
    duration=0.87,
    dt=0.03,
    epochs=100,
    use_noise=False,
    optimizer=optax.adamw,
    optimizer_params={
        "learning_rate": 1e-3,
        "weight_decay": 1e-4
    },
    loss_fn=lambda x, y: (jnp.abs(x - y)).mean(),
    scheduler="Cosine",
    tensorboard_dir=None,
    model_parameters={},  # used for tensorboard
    print_loss=1,
):
    """train controller."""
    peep = 5
    if mode == "multipip":
        pips = [10, 15, 20, 25, 30, 35]
    elif mode == "singular":
        pips = [35]

    # setup optimizer
    optim_params = copy.deepcopy(optimizer_params)
    if scheduler == "Cosine":
        if pip_feed == "parallel":
            steps_per_epoch = 1
        elif pip_feed == "sequential":
            steps_per_epoch = len(pips)
        decay_steps = int(epochs * steps_per_epoch)
        print("steps_per_epoch:" + str(steps_per_epoch))
        print("decay_steps:" + str(decay_steps))
        cosine_scheduler_fn = optax.cosine_decay_schedule(
            init_value=optim_params["learning_rate"], decay_steps=decay_steps)
        optim_params["learning_rate"] = cosine_scheduler_fn
        print("optim_params:" + str(optim_params))
        optim = optimizer(**optim_params)
    optim_state = optim.init(controller)

    # setup Tensorboard writer
    if tensorboard_dir is not None:
        trial_name = str(model_parameters)
        write_path = tensorboard_dir + trial_name
        summary_writer = metric_writers.create_default_writer(
            logdir=write_path, just_logging=jax.process_index() != 0)
        # summary_writer = tensorboard.SummaryWriter(write_path)
        summary_writer.write_hparams(model_parameters)

    tt = jnp.linspace(0, duration, int(duration / dt))
    losses = []
    for epoch in range(epochs):
        if pip_feed == "parallel":
            value, grad = jax.value_and_grad(rollout_parallel)(controller, sim,
                                                               tt, use_noise,
                                                               peep,
                                                               jnp.array(pips),
                                                               loss_fn)
            updates, optim_state = optim.update(grad, optim_state, controller)
            controller = optax.apply_updates(controller, updates)
            per_step_loss = value / len(tt)
            losses.append(per_step_loss)
            if epoch % print_loss == 0:
                # make new controller with trained parameters and normal clamp
                score = test_controller(controller, sim, pips, peep)
                print(f"Epoch: {epoch}\tLoss: {score:.2f}")
                if tensorboard_dir is not None:
                    summary_writer.write_scalars(epoch, {"score": score})
        if pip_feed == "sequential":
            for pip in pips:
                value, grad = jax.value_and_grad(rollout)(controller, sim, tt,
                                                          use_noise, peep,
                                                          pip, loss_fn,
                                                          jnp.array(0.))
                updates, optim_state = optim.update(grad, optim_state,
                                                    controller)
                controller = optax.apply_updates(controller, updates)
                per_step_loss = value / len(tt)
                losses.append(per_step_loss)
                if epoch % print_loss == 0:
                    # make new controller with trained parameters and normal clamp
                    score = test_controller(controller, sim, pips, peep)
                    print(f"Epoch: {epoch}, pip: {pip}\tLoss: {score:.2f}")
                    if tensorboard_dir is not None:
                        summary_writer.write_scalars(epoch,
                                                     {"per_step_loss": score})
    return controller, per_step_loss, score
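# Hypothetical usage sketch (not part of the original example): `controller`
# and `sim` are assumed to be constructed elsewhere in the surrounding
# codebase; the call below only illustrates the function signature.
#
#   controller, per_step_loss, score = train_controller(
#       controller=controller, sim=sim, mode="singular", epochs=50,
#       scheduler="Cosine", tensorboard_dir=None)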
Example #18
def train_and_evaluate(config: ml_collections.ConfigDict,
                       workdir: str) -> TrainState:
    """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.

  Returns:
    Final TrainState.
  """

    writer = metric_writers.create_default_writer(
        logdir=workdir, just_logging=jax.process_index() != 0)

    rng = random.PRNGKey(0)

    image_size = 224

    if config.batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    local_batch_size = config.batch_size // jax.process_count()

    platform = jax.local_devices()[0].platform

    if config.half_precision:
        if platform == 'tpu':
            input_dtype = tf.bfloat16
        else:
            input_dtype = tf.float16
    else:
        input_dtype = tf.float32

    dataset_builder = tfds.builder(config.dataset)
    train_iter = create_input_iter(dataset_builder,
                                   local_batch_size,
                                   image_size,
                                   input_dtype,
                                   train=True,
                                   cache=config.cache)
    eval_iter = create_input_iter(dataset_builder,
                                  local_batch_size,
                                  image_size,
                                  input_dtype,
                                  train=False,
                                  cache=config.cache)

    steps_per_epoch = (dataset_builder.info.splits['train'].num_examples //
                       config.batch_size)

    if config.num_train_steps == -1:
        num_steps = int(steps_per_epoch * config.num_epochs)
    else:
        num_steps = config.num_train_steps

    if config.steps_per_eval == -1:
        num_validation_examples = dataset_builder.info.splits[
            'validation'].num_examples
        steps_per_eval = num_validation_examples // config.batch_size
    else:
        steps_per_eval = config.steps_per_eval

    steps_per_checkpoint = steps_per_epoch * 10

    base_learning_rate = config.learning_rate * config.batch_size / 256.
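    # Linear scaling rule relative to a reference batch size of 256. E.g.
    # (illustrative numbers): config.learning_rate = 0.1 with
    # config.batch_size = 1024 gives base_learning_rate = 0.1 * 1024 / 256 = 0.4.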

    model_cls = getattr(models, config.model)
    model = create_model(model_cls=model_cls,
                         half_precision=config.half_precision)

    learning_rate_fn = create_learning_rate_fn(config, base_learning_rate,
                                               steps_per_epoch)

    state = create_train_state(rng, config, model, image_size,
                               learning_rate_fn)
    state = restore_checkpoint(state, workdir)
    # step_offset > 0 if restarting from checkpoint
    step_offset = int(state.step)
    state = jax_utils.replicate(state)

    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    train_metrics = []
    hooks = []
    if jax.process_index() == 0:
        hooks += [
            periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
        ]
    train_metrics_last_t = time.time()
    logging.info('Initial compilation, this might take some minutes...')
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        state, metrics = p_train_step(state, batch)
        for h in hooks:
            h(step)
        if step == step_offset:
            logging.info('Initial compilation completed.')

        if config.get('log_every_steps'):
            train_metrics.append(metrics)
            if (step + 1) % config.log_every_steps == 0:
                train_metrics = common_utils.get_metrics(train_metrics)
                summary = {
                    f'train_{k}': v
                    for k, v in jax.tree_map(lambda x: x.mean(),
                                             train_metrics).items()
                }
                summary['steps_per_second'] = config.log_every_steps / (
                    time.time() - train_metrics_last_t)
                writer.write_scalars(step + 1, summary)
                train_metrics = []
                train_metrics_last_t = time.time()

        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            eval_metrics = []

            # sync batch statistics across replicas
            state = sync_batch_stats(state)
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                metrics = p_eval_step(state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            summary = jax.tree_map(lambda x: x.mean(), eval_metrics)
            logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            writer.write_scalars(
                step + 1, {f'eval_{key}': val
                           for key, val in summary.items()})
            writer.flush()
        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            state = sync_batch_stats(state)
            save_checkpoint(state, workdir)

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()

    return state
Example #19
def train_and_evaluate(config, workdir):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """
    tf.io.gfile.makedirs(workdir)

    rng = jax.random.PRNGKey(config.seed)

    # Build input pipeline.
    rng, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.host_id())
    splits = input_pipeline.create_datasets(config, data_rng)
    num_classes = splits.info.features["label"].num_classes
    train_iter = iter(splits.train)  # pytype: disable=wrong-arg-types

    # Learning rate schedule.
    num_train_steps = config.num_train_steps
    if num_train_steps == -1:
        num_train_steps = splits.train.cardinality().numpy()
    steps_per_epoch = num_train_steps // config.num_epochs
    logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
                 steps_per_epoch)
    # We treat the learning rate in the config as the learning rate for batch size
    # 32 but scale it according to our batch size.
    global_batch_size = config.per_device_batch_size * jax.device_count()
    base_learning_rate = config.learning_rate * global_batch_size / 32.0
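    # E.g. (illustrative numbers): config.learning_rate = 0.001 with a global
    # batch size of 256 gives base_learning_rate = 0.001 * 256 / 32 = 0.008.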
    learning_rate_fn = functools.partial(get_learning_rate,
                                         base_learning_rate=base_learning_rate,
                                         steps_per_epoch=steps_per_epoch,
                                         num_epochs=config.num_epochs,
                                         warmup_epochs=config.warmup_epochs)

    # Initialize model.
    rng, model_rng = jax.random.split(rng)
    model, state = create_train_state(
        config,
        model_rng,
        input_shape=splits.train.element_spec["input"].shape[1:],
        num_classes=num_classes)

    # Set up checkpointing of the model and the input pipeline.
    checkpoint_dir = os.path.join(workdir, "checkpoints")
    ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir,
                                          {"train_iter": train_iter},
                                          max_to_keep=2)
    state = ckpt.restore_or_initialize(state)
    initial_step = int(state.step) + 1

    # Count number of trainable parameters. This must be done before replicating
    # the state to avoid double-counting replicated parameters.
    param_count = sum(p.size for p in jax.tree_leaves(state.optimizer.target))

    # Distribute training over local devices.
    state = flax_utils.replicate(state)

    p_train_step = jax.pmap(functools.partial(
        train_step,
        model=model,
        learning_rate_fn=learning_rate_fn,
        weight_decay=config.weight_decay),
                            axis_name=_PMAP_AXIS_NAME)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.host_id() > 0)
    if initial_step == 1:
        writer.write_hparams(dict(config))
        # Log the number of trainable params.
        writer.write_scalars(initial_step, {"param_count": param_count})

    logging.info("Starting training loop at step %d.", initial_step)
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=num_train_steps, writer=writer)
    if jax.host_id() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
        ]
    train_metrics = None
    with metric_writers.ensure_flushes(writer):
        for step in range(initial_step, num_train_steps + 1):
            # `step` is a Python integer. `state.step` is a JAX integer on the
            # GPU/TPU devices.
            is_last_step = step == num_train_steps

            with jax.profiler.StepTraceContext("train", step_num=step):
                batch = jax.tree_map(np.asarray, next(train_iter))
                state, metrics_update = p_train_step(state=state, batch=batch)
                metric_update = flax_utils.unreplicate(metrics_update)
                train_metrics = (metric_update if train_metrics is None else
                                 train_metrics.merge(metric_update))

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            if step % config.log_loss_every_steps == 0 or is_last_step:
                writer.write_scalars(step, train_metrics.compute())
                train_metrics = None

            # When combining train and eval, we do not evaluate while training.
            if ((step % config.eval_every_steps == 0 or is_last_step)
                    and not config.combine_train_val_and_eval_on_test):
                with report_progress.timed("eval"):
                    eval_metrics = evaluate(model, state, splits.validation,
                                            config.num_eval_steps)
                writer.write_scalars(step, eval_metrics.compute())

            if step % config.checkpoint_every_steps == 0 or is_last_step:
                with report_progress.timed("checkpoint"):
                    ckpt.save(flax_utils.unreplicate(state))

            if is_last_step and config.combine_train_val_and_eval_on_test:
                # Evaluate a single time on the test set when requested.
                with report_progress.timed("test"):
                    test_metrics = evaluate(model, state, splits.test,
                                            config.num_eval_steps)
                writer.write_scalars(step, test_metrics.compute())

    logging.info("Finishing training at step %d", num_train_steps)
Example #20
def train_simulator(
    dataset,
    model,
    num_boundary_models,
    activation_fn_name,
    R,
    C,
    # idx 0 to num_boundary_models-1 are boundary models,
    # idx num_boundary_models is default_model
    train_key="train",
    test_key="test",
    batch_size=512,
    epochs=500,
    optimizer=optax.adamw,
    optimizer_params={
        "learning_rate": 1e-3,
        "weight_decay": 1e-4
    },
    patience=10,
    lr_decay_factor=0.1,
    scheduler="ReduceLROnPlateau",  # or "Cosine"
    loss_fn=lambda x, y: (jnp.abs(x - y)).mean(),
    print_loss=10,
    use_tensorboard=False,
    mode="train",
    user_name="alexjyu-brain",
    tb_dir=None,
):
  """train simulator."""
  # evaluate on these at end of epoch
  for key in ["train", "test"]:
    dataset.data[key] = (jnp.array(dataset.data[key][0]),
                         jnp.array(dataset.data[key][1]))
  X_train, y_train = dataset.data[train_key]
  X_test, y_test = dataset.data[test_key]

  # set up optimizer and lr scheduler
  lr_mult = 1.0
  if scheduler == "ReduceLROnPlateau":
    optim = optimizer(**optimizer_params)
    patience_cnt = 0
    prev_loss = float("inf")
  elif scheduler == "Cosine":
    steps_per_epoch = float(X_train.shape[0] / batch_size)
    decay_steps = int((epochs + 1) * steps_per_epoch)
    logging.info("steps_per_epoch: %s", str(steps_per_epoch))
    logging.info("decay_steps: %s", str(decay_steps))
    cosine_scheduler_fn = optax.cosine_decay_schedule(
        init_value=optimizer_params["learning_rate"], decay_steps=decay_steps)
    optimizer_params["learning_rate"] = cosine_scheduler_fn
    logging.info("optimizer_params: %s", str(optimizer_params))
    optim = optimizer(**optimizer_params)
  optim_state = optim.init(model)

  loop_over_loader_partial = functools.partial(
      loop_over_loader, optim=optim, rollout_fn=rollout, scheduler=scheduler)

  # Tensorboard writer
  if use_tensorboard:
    config = copy.deepcopy(model.default_model_parameters)
    del config["activation_fn"]
    config["activation_fn_name"] = activation_fn_name

    if mode == "train":
      file_name = str(config)
    write_path = tb_dir + file_name
    summary_writer = metric_writers.create_default_writer(
        logdir=write_path, just_logging=jax.process_index() != 0)
    summary_writer.write_hparams(dict(config))

  # Main Training Loop
  prng_key = jax.random.PRNGKey(0)
  for epoch in range(epochs + 1):
    if epoch % 10 == 0:
      logging.info("epoch: %s", str(epoch))
    X, y, prng_key = get_shuffled_and_batched_data(dataset, batch_size,
                                                   train_key, prng_key)
    if epoch == 0:
      logging.info("X.shape: %s", str(X.shape))
      logging.info("y.shape: %s", str(y.shape))

    (model, optim_state, lr_mult,
     loss), _ = jax.lax.scan(loop_over_loader_partial,
                             (model, optim_state, lr_mult, 0.), (X, y))
    """for i in range(X.shape[0]):

      carry = (model, optim_state, lr_mult, 0.)
      carry, _ = loop_over_loader_partial(carry, (X[i], y[i]))
    model, optim_state, lr_mult, loss = carry
    """
    if scheduler == "ReduceLROnPlateau":
      if loss > prev_loss:
        patience_cnt = patience_cnt + 1
      else:
        patience_cnt = 0
      if patience_cnt == patience:
        lr_mult = lr_mult * lr_decay_factor
        patience_cnt = 0
      prev_loss = loss

    if epoch % print_loss == 0:
      if scheduler == "ReduceLROnPlateau":
        logging.info("loss: %s", str(loss))
        logging.info("prev_loss: %s", str(prev_loss))
        logging.info("patience_cnt: %s", str(patience_cnt))
        logging.info("lr_mult: %s", str(lr_mult))
      # expensive end-of-epoch eval, just for intuition
      train_loss = map_rollout_over_batch(model, (X_train, y_train), rollout)
      # cross-validation
      test_loss = map_rollout_over_batch(model, (X_test, y_test), rollout)

      if epoch % print_loss == 0:
        logging.info(
            f"Epoch {epoch:2d}: train={train_loss.item():.5f}, test_loss={test_loss.item():.5f}"
        )
        logging.info("-----------------------------------")
      if use_tensorboard:
        summary_writer.write_scalars(epoch, {"train_loss": train_loss})
        summary_writer.write_scalars(epoch, {"test_loss": test_loss})
  if use_tensorboard:
    summary_writer.flush()
  logging.info("finished looping over epochs")
  return model, test_loss
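
train_simulator drives its inner batch loop with jax.lax.scan, threading (model, optimizer state, loss) through as the carry while (X, y) supplies one batch per scan step; the commented-out Python loop is the unscanned equivalent. Here is a hedged sketch of that carry structure with a toy linear model; scan_epoch and the shapes are hypothetical and do not use the loop_over_loader/rollout helpers above.

import jax
import jax.numpy as jnp
import optax

def scan_epoch(params, opt_state, X, y, optim):
    # X, y are stacked batches: X.shape == (num_batches, batch_size, dim).
    def body(carry, batch):
        params, opt_state, _ = carry
        xb, yb = batch
        loss, grads = jax.value_and_grad(
            lambda p: jnp.mean(jnp.abs(xb @ p - yb)))(params)
        updates, opt_state = optim.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return (params, opt_state, loss), None

    (params, opt_state, loss), _ = jax.lax.scan(
        body, (params, opt_state, jnp.zeros(())), (X, y))
    return params, opt_state, loss

optim = optax.adamw(learning_rate=1e-3, weight_decay=1e-4)
params = jnp.zeros(3)
opt_state = optim.init(params)
X, y = jnp.ones((5, 8, 3)), jnp.ones((5, 8))
params, opt_state, loss = scan_epoch(params, opt_state, X, y, optim)
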
Example #21
def main(config, output_dir):

    seed = config.get('seed', 0)
    tf.random.set_seed(seed)

    if config.get('data_dir'):
        logging.info('data_dir=%s', config.data_dir)
    logging.info('Output dir: %s', output_dir)
    tf.io.gfile.makedirs(output_dir)

    # Create an asynchronous multi-metric writer.
    writer = metric_writers.create_default_writer(
        output_dir, just_logging=jax.process_index() > 0)

    # The pool is used to perform misc operations such as logging in async way.
    pool = multiprocessing.pool.ThreadPool()

    def write_note(note):
        if jax.process_index() == 0:
            logging.info('NOTE: %s', note)

    write_note('Initializing...')

    batch_size = config.batch_size
    batch_size_eval = config.get('batch_size_eval', batch_size)
    if (batch_size % jax.device_count() != 0
            or batch_size_eval % jax.device_count() != 0):
        raise ValueError(
            f'Batch sizes ({batch_size} and {batch_size_eval}) must '
            f'be divisible by device number ({jax.device_count()})')

    local_batch_size = batch_size // jax.process_count()
    local_batch_size_eval = batch_size_eval // jax.process_count()
    logging.info(
        'Global batch size %d on %d hosts results in %d local batch size. '
        'With %d devices per host (%d devices total), that\'s a %d per-device '
        'batch size.', batch_size, jax.process_count(), local_batch_size,
        jax.local_device_count(), jax.device_count(),
        local_batch_size // jax.local_device_count())

    write_note('Initializing val dataset(s)...')

    def _get_val_split(dataset, split, pp_eval, data_dir=None):
        # We do ceil rounding such that we include the last incomplete batch.
        nval_img = input_utils.get_num_examples(
            dataset,
            split=split,
            process_batch_size=local_batch_size_eval,
            drop_remainder=False,
            data_dir=data_dir)
        val_steps = int(np.ceil(nval_img / batch_size_eval))
        logging.info('Running validation for %d steps for %s, %s', val_steps,
                     dataset, split)

        if isinstance(pp_eval, str):
            pp_eval = preprocess_spec.parse(
                spec=pp_eval, available_ops=preprocess_utils.all_ops())

        val_ds = input_utils.get_data(dataset=dataset,
                                      split=split,
                                      rng=None,
                                      process_batch_size=local_batch_size_eval,
                                      preprocess_fn=pp_eval,
                                      cache=config.get('val_cache', 'batched'),
                                      num_epochs=1,
                                      repeat_after_batching=True,
                                      shuffle=False,
                                      prefetch_size=config.get(
                                          'prefetch_to_host', 2),
                                      drop_remainder=False,
                                      data_dir=data_dir)

        return val_ds

    val_ds_splits = {
        'val':
        _get_val_split(config.dataset,
                       split=config.val_split,
                       pp_eval=config.pp_eval,
                       data_dir=config.get('data_dir'))
    }

    if config.get('test_split'):
        val_ds_splits.update({
            'test':
            _get_val_split(config.dataset,
                           split=config.test_split,
                           pp_eval=config.pp_eval,
                           data_dir=config.get('data_dir'))
        })

    if config.get('eval_on_cifar_10h'):
        cifar10_to_cifar10h_fn = data_uncertainty_utils.create_cifar10_to_cifar10h_fn(
            config.get('data_dir', None))
        preprocess_fn = preprocess_spec.parse(
            spec=config.pp_eval_cifar_10h,
            available_ops=preprocess_utils.all_ops())
        pp_eval = lambda ex: preprocess_fn(cifar10_to_cifar10h_fn(ex))
        val_ds_splits['cifar_10h'] = _get_val_split(
            'cifar10',
            split=config.get('cifar_10h_split') or 'test',
            pp_eval=pp_eval,
            data_dir=config.get('data_dir'))
    elif config.get('eval_on_imagenet_real'):
        imagenet_to_real_fn = data_uncertainty_utils.create_imagenet_to_real_fn(
        )
        preprocess_fn = preprocess_spec.parse(
            spec=config.pp_eval_imagenet_real,
            available_ops=preprocess_utils.all_ops())
        pp_eval = lambda ex: preprocess_fn(imagenet_to_real_fn(ex))
        val_ds_splits['imagenet_real'] = _get_val_split(
            'imagenet2012_real',
            split=config.get('imagenet_real_split') or 'validation',
            pp_eval=pp_eval,
            data_dir=config.get('data_dir'))

    ood_ds = {}
    if config.get('ood_datasets') and config.get('ood_methods'):
        if config.get(
                'ood_methods'):  # config.ood_methods is not an empty list
            logging.info('loading OOD dataset = %s',
                         config.get('ood_datasets'))
            ood_ds, ood_ds_names = ood_utils.load_ood_datasets(
                config.dataset,
                config.ood_datasets,
                config.ood_split,
                config.pp_eval,
                config.pp_eval_ood,
                config.ood_methods,
                config.train_split,
                config.get('data_dir'),
                _get_val_split,
            )

    write_note('Initializing model...')
    logging.info('config.model = %s', config.model)
    model = ub.models.vision_transformer(num_classes=config.num_classes,
                                         **config.model)

    ensemble_pred_fn = functools.partial(ensemble_prediction_fn, model.apply)

    @functools.partial(jax.pmap, axis_name='batch')
    def evaluation_fn(params, images, labels, mask):
        # params is a dict of the form:
        #   {'model_1': params_model_1, 'model_2': params_model_2, ...}
        # Ignore the entries with all zero labels for evaluation.
        mask *= labels.max(axis=1)
        loss_as_str = config.get('loss', 'sigmoid_xent')
        ens_logits, ens_prelogits = ensemble_pred_fn(params, images,
                                                     loss_as_str)

        label_indices = config.get('label_indices')
        logging.info('!!! mask %s, label_indices %s', mask, label_indices)
        if label_indices:
            ens_logits = ens_logits[:, label_indices]

        # Note that logits and labels are usually of the shape [batch,num_classes].
        # But for OOD data, when num_classes_ood > num_classes_ind, we need to
        # adjust labels to labels[:, :config.num_classes] to match the shape of
        # logits. That is just to avoid a shape mismatch. The output losses do not
        # have any meaning for OOD data, because OOD examples do not belong to any
        # IND class.
        losses = getattr(train_utils, loss_as_str)(
            logits=ens_logits,
            labels=labels[:, :(
                len(label_indices) if label_indices else config.num_classes)],
            reduction=False)
        loss = jax.lax.psum(losses * mask, axis_name='batch')

        top1_idx = jnp.argmax(ens_logits, axis=1)
        # Extracts the label at the highest logit index for each image.
        top1_correct = jnp.take_along_axis(labels, top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch')
        n = jax.lax.psum(mask, axis_name='batch')

        metric_args = jax.lax.all_gather(
            [ens_logits, labels, ens_prelogits, mask], axis_name='batch')
        return ncorrect, loss, n, metric_args

    @functools.partial(jax.pmap, axis_name='batch')
    def cifar_10h_evaluation_fn(params, images, labels, mask):
        loss_as_str = config.get('loss', 'softmax_xent')
        ens_logits, ens_prelogits = ensemble_pred_fn(params, images,
                                                     loss_as_str)
        label_indices = config.get('label_indices')
        if label_indices:
            ens_logits = ens_logits[:, label_indices]

        losses = getattr(train_utils, loss_as_str)(logits=ens_logits,
                                                   labels=labels,
                                                   reduction=False)
        loss = jax.lax.psum(losses, axis_name='batch')

        top1_idx = jnp.argmax(ens_logits, axis=1)
        # Extracts the label at the highest logit index for each image.
        one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)]

        top1_correct = jnp.take_along_axis(one_hot_labels,
                                           top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
        n = jax.lax.psum(one_hot_labels, axis_name='batch')

        metric_args = jax.lax.all_gather(
            [ens_logits, labels, ens_prelogits, mask], axis_name='batch')
        return ncorrect, loss, n, metric_args

    # Setup function for computing representation.
    @functools.partial(jax.pmap, axis_name='batch')
    def representation_fn(params, images, labels, mask):
        # Return shape [batch_size, representation_size * ensemble_size]. During
        # few-shot eval, a single linear regressor is applied over all dimensions.
        representation = []
        for p in params.values():
            _, outputs = model.apply({'params': flax.core.freeze(p)},
                                     images,
                                     train=False)
            representation += [outputs[config.fewshot.representation_layer]]
        representation = jnp.concatenate(representation, axis=1)
        representation = jax.lax.all_gather(representation, 'batch')
        labels = jax.lax.all_gather(labels, 'batch')
        mask = jax.lax.all_gather(mask, 'batch')
        return representation, labels, mask

    write_note('Load checkpoints...')
    ensemble_params = load_checkpoints(config)

    write_note('Replicating...')
    ensemble_params = flax.jax_utils.replicate(ensemble_params)

    if jax.process_index() == 0:
        writer.write_hparams(dict(config))

    write_note('Initializing few-shotters...')
    fewshotter = None
    if 'fewshot' in config and fewshot is not None:
        fewshotter = fewshot.FewShotEvaluator(
            representation_fn, config.fewshot,
            config.fewshot.get('batch_size') or batch_size_eval)

    # Note: we return the train loss, val loss, and fewshot best l2s for use in
    # reproducibility unit tests.
    val_loss = {val_name: -jnp.inf for val_name, _ in val_ds_splits.items()}
    fewshot_results = {'dummy': {(0, 1): -jnp.inf}}
    step = 1

    # Report validation performance.
    write_note('Evaluating on the validation set...')
    for val_name, val_ds in val_ds_splits.items():
        # Sets up evaluation metrics.
        ece_num_bins = config.get('ece_num_bins', 15)
        auc_num_bins = config.get('auc_num_bins', 1000)
        ece = rm.metrics.ExpectedCalibrationError(num_bins=ece_num_bins)
        calib_auc = rm.metrics.CalibrationAUC(correct_pred_as_pos_label=False)
        oc_auc_0_5 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.005,
                                                       num_bins=auc_num_bins)
        oc_auc_1 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.01,
                                                     num_bins=auc_num_bins)
        oc_auc_2 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.02,
                                                     num_bins=auc_num_bins)
        oc_auc_5 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.05,
                                                     num_bins=auc_num_bins)
        label_diversity = tf.keras.metrics.Mean()
        sample_diversity = tf.keras.metrics.Mean()
        ged = tf.keras.metrics.Mean()

        # Runs evaluation loop.
        val_iter = input_utils.start_input_pipeline(
            val_ds, config.get('prefetch_to_device', 1))
        ncorrect, loss, nseen = 0, 0, 0
        for batch in val_iter:
            if val_name == 'cifar_10h':
                batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                    cifar_10h_evaluation_fn(ensemble_params, batch['image'],
                                            batch['labels'], batch['mask']))
            else:
                batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                    evaluation_fn(ensemble_params, batch['image'],
                                  batch['labels'], batch['mask']))
            # All results are a replicated array shaped as follows:
            # (local_devices, per_device_batch_size, elem_shape...)
            # with each local device's entry being identical as they got psum'd.
            # So let's just take the first one to the host as numpy.
            ncorrect += np.sum(np.array(batch_ncorrect[0]))
            loss += np.sum(np.array(batch_losses[0]))
            nseen += np.sum(np.array(batch_n[0]))
            if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
                # Here we parse batch_metric_args to compute uncertainty metrics.
                # (e.g., ECE or Calibration AUC).
                logits, labels, _, masks = batch_metric_args
                masks = np.array(masks[0], dtype=bool)  # np.bool is removed in modern NumPy
                logits = np.array(logits[0])
                probs = jax.nn.softmax(logits)
                # From one-hot to integer labels, as required by ECE.
                int_labels = np.argmax(np.array(labels[0]), axis=-1)
                int_preds = np.argmax(logits, axis=-1)
                confidence = np.max(probs, axis=-1)
                for p, c, l, d, m, label in zip(probs, confidence, int_labels,
                                                int_preds, masks, labels[0]):
                    ece.add_batch(p[m, :], label=l[m])
                    calib_auc.add_batch(d[m], label=l[m], confidence=c[m])
                    # TODO(jereliu): Extend to support soft multi-class probabilities.
                    oc_auc_0_5.add_batch(d[m],
                                         label=l[m],
                                         custom_binning_score=c[m])
                    oc_auc_1.add_batch(d[m],
                                       label=l[m],
                                       custom_binning_score=c[m])
                    oc_auc_2.add_batch(d[m],
                                       label=l[m],
                                       custom_binning_score=c[m])
                    oc_auc_5.add_batch(d[m],
                                       label=l[m],
                                       custom_binning_score=c[m])

                    if val_name == 'cifar_10h' or val_name == 'imagenet_real':
                        batch_label_diversity, batch_sample_diversity, batch_ged = data_uncertainty_utils.generalized_energy_distance(
                            label[m], p[m, :], config.num_classes)
                        label_diversity.update_state(batch_label_diversity)
                        sample_diversity.update_state(batch_sample_diversity)
                        ged.update_state(batch_ged)

        val_loss[val_name] = loss / nseen  # Keep for reproducibility tests.
        val_measurements = {
            f'{val_name}_prec@1': ncorrect / nseen,
            f'{val_name}_loss': val_loss[val_name],
        }
        if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
            val_measurements[f'{val_name}_ece'] = ece.result()['ece']
            val_measurements[f'{val_name}_calib_auc'] = calib_auc.result(
            )['calibration_auc']
            val_measurements[f'{val_name}_oc_auc_0.5%'] = oc_auc_0_5.result(
            )['collaborative_auc']
            val_measurements[f'{val_name}_oc_auc_1%'] = oc_auc_1.result(
            )['collaborative_auc']
            val_measurements[f'{val_name}_oc_auc_2%'] = oc_auc_2.result(
            )['collaborative_auc']
            val_measurements[f'{val_name}_oc_auc_5%'] = oc_auc_5.result(
            )['collaborative_auc']
        writer.write_scalars(step, val_measurements)

        if val_name == 'cifar_10h' or val_name == 'imagenet_real':
            cifar_10h_measurements = {
                f'{val_name}_label_diversity': label_diversity.result(),
                f'{val_name}_sample_diversity': sample_diversity.result(),
                f'{val_name}_ged': ged.result(),
            }
            writer.write_scalars(step, cifar_10h_measurements)

    # OOD eval
    # Entries in the ood_ds dict include:
    # (ind_dataset, ood_dataset1, ood_dataset2, ...).
    # OOD metrics are computed using ind_dataset paired with each of the
    # ood_dataset. When Mahalanobis distance method is applied, train_ind_ds
    # is also included in the ood_ds.
    if ood_ds and config.ood_methods:
        ood_measurements = ood_utils.eval_ood_metrics(ood_ds,
                                                      ood_ds_names,
                                                      config.ood_methods,
                                                      evaluation_fn,
                                                      ensemble_params,
                                                      n_prefetch=config.get(
                                                          'prefetch_to_device',
                                                          1))
        writer.write_scalars(step, ood_measurements)

    if 'fewshot' in config and fewshotter is not None:
        # Compute few-shot on-the-fly evaluation.
        write_note('Few-shot evaluation...')
        # Keep `results` to return for reproducibility tests.
        fewshot_results, best_l2 = fewshotter.run_all(ensemble_params,
                                                      config.fewshot.datasets)

        # TODO(dusenberrymw): Remove this once fewshot.py is updated.
        def make_writer_measure_fn(step):
            def writer_measure(name, value):
                writer.write_scalars(step, {name: value})

            return writer_measure

        fewshotter.walk_results(make_writer_measure_fn(step), fewshot_results,
                                best_l2)

    write_note('Done!')
    pool.close()
    pool.join()
    writer.close()

    # Return final training loss, validation loss, and fewshot results for
    # reproducibility test cases.
    return val_loss, fewshot_results
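
evaluation_fn above relies on the standard pmap/psum reduction: each device computes its own partial counts, jax.lax.psum over the 'batch' axis makes the total identical on every device, and the host simply reads entry [0] of the replicated result. A minimal sketch of that pattern with toy data (count_correct and the shapes are illustrative only):

import functools
import jax
import jax.numpy as jnp
import numpy as np

@functools.partial(jax.pmap, axis_name="batch")
def count_correct(preds, labels):
    # Per-device count, then a cross-device sum; the output is replicated.
    return jax.lax.psum(jnp.sum(preds == labels), axis_name="batch")

n_dev = jax.local_device_count()
preds = jnp.zeros((n_dev, 4), dtype=jnp.int32)
labels = jnp.zeros((n_dev, 4), dtype=jnp.int32)
ncorrect = np.array(count_correct(preds, labels))[0]  # identical on every device
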
Example #22
def main(config, output_dir):
    # Note: switch to ProfileAllHosts() if you need to profile all hosts.
    # (Xprof data become much larger and take longer to load for analysis)
    profiler = periodic_actions.Profile(
        # Create profile after every restart to analyze pre-emption related
        # problems and assure we get similar performance in every run.
        logdir=output_dir,
        first_profile=10)

    logging.info(config)

    acquisition_method = config.get('acquisition_method')

    # Create an asynchronous multi-metric writer.
    writer = metric_writers.create_default_writer(
        output_dir, just_logging=jax.process_index() > 0)
    writer.write_hparams(dict(config))

    # The pool is used to perform misc operations such as logging in async way.
    pool = multiprocessing.pool.ThreadPool()

    def write_note(note):
        if jax.process_index() == 0:
            logging.info('NOTE: %s', note)

    write_note(f'Initializing for {acquisition_method}')

    # Download dataset
    data_builder = tfds.builder(config.dataset)
    data_builder.download_and_prepare()

    seed = config.get('seed', 0)
    rng = jax.random.PRNGKey(seed)

    batch_size = config.batch_size
    batch_size_eval = config.get('batch_size_eval', batch_size)

    local_batch_size = batch_size // jax.process_count()
    local_batch_size_eval = batch_size_eval // jax.process_count()

    val_ds = input_utils.get_data(
        dataset=config.dataset,
        split=config.val_split,
        rng=None,
        process_batch_size=local_batch_size_eval,
        preprocess_fn=preprocess_spec.parse(
            spec=config.pp_eval, available_ops=preprocess_utils.all_ops()),
        shuffle=False,
        prefetch_size=config.get('prefetch_to_host', 2),
        num_epochs=1,  # Only repeat once.
    )

    test_ds = input_utils.get_data(
        dataset=config.dataset,
        split=config.test_split,
        rng=None,
        process_batch_size=local_batch_size_eval,
        preprocess_fn=preprocess_spec.parse(
            spec=config.pp_eval, available_ops=preprocess_utils.all_ops()),
        shuffle=False,
        prefetch_size=config.get('prefetch_to_host', 2),
        num_epochs=1,  # Only repeat once.
    )

    # Init model
    if config.model_type == 'deterministic':
        model_utils = deterministic_utils
        reinit_params = config.get('model_reinit_params',
                                   ('head/kernel', 'head/bias'))
        model = ub.models.vision_transformer(num_classes=config.num_classes,
                                             **config.get('model', {}))
    elif config.model_type == 'batchensemble':
        model_utils = batchensemble_utils
        reinit_params = ('batchensemble_head/bias',
                         'batchensemble_head/kernel',
                         'batchensemble_head/fast_weight_alpha',
                         'batchensemble_head/fast_weight_gamma')
        model = ub.models.PatchTransformerBE(num_classes=config.num_classes,
                                             **config.model)
    else:
        raise ValueError('Expected config.model_type to be "deterministic" or '
                         f'"batchensemble", but received {config.model_type}.')

    init = model_utils.create_init(model, config, test_ds)

    rng, rng_init = jax.random.split(rng)
    params_cpu = init(rng_init)

    if jax.process_index() == 0:
        num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
        parameter_overview.log_parameter_overview(params_cpu)
        writer.write_scalars(step=0, scalars={'num_params': num_params})

    # Load the optimizer from flax.
    opt_name = config.get('optim_name')
    opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

    # We jit this so that the arrays are created on the same device as the
    # input, in this case the CPU. Otherwise they would end up on device[0].
    opt_cpu = jax.jit(opt_def.create)(params_cpu)

    loaded_params = checkpoint_utils.load_checkpoint(tree=None,
                                                     path=config.model_init)
    loaded = checkpoint_utils.restore_from_pretrained_params(
        params_cpu,
        loaded_params,
        config.model.representation_size,
        config.model.classifier,
        reinit_params,
    )

    opt_cpu = opt_cpu.replace(target=loaded)

    # TODO(joost,andreas): This shouldn't be needed but opt_cpu is being
    # donated otherwise. Ensure opt_cpu is really on the cpu this way.
    opt_cpu = jax.device_get(opt_cpu)

    update_fn = model_utils.create_update_fn(model, config)
    evaluation_fn = model_utils.create_evaluation_fn(model, config)

    # NOTE: We need this because we need an Id field of type int.
    # TODO(andreas): Rename to IdSubsetDatasetBuilder?
    pool_subset_data_builder = al_utils.SubsetDatasetBuilder(data_builder,
                                                             subset_ids=None)

    rng, pool_ds_rng = jax.random.split(rng)

    # NOTE: the (commented-out) line below is necessary on a multi-host setup.
    # pool_ds_rng = jax.random.fold_in(pool_ds_rng, jax.process_index())

    pool_train_ds = input_utils.get_data(
        dataset=pool_subset_data_builder,
        split=config.train_split,
        rng=pool_ds_rng,
        process_batch_size=local_batch_size,
        preprocess_fn=preprocess_spec.parse(
            spec=config.pp_eval, available_ops=preprocess_utils.all_ops()),
        shuffle=False,
        drop_remainder=False,
        prefetch_size=config.get('prefetch_to_host', 2),
        num_epochs=1,  # Don't repeat
    )

    # Potentially acquire an initial training set.
    initial_training_set_size = config.get('initial_training_set_size', 10)

    if initial_training_set_size > 0:
        current_opt_repl = flax_utils.replicate(opt_cpu)
        pool_ids, _, _, pool_masks = get_ids_logits_masks(
            model=model,
            opt_repl=current_opt_repl,
            ds=pool_train_ds,
            config=config)

        rng, initial_uniform_rng = jax.random.split(rng)
        pool_scores = get_uniform_scores(pool_masks, initial_uniform_rng)

        initial_training_set_batch_ids, _ = select_acquisition_batch_indices(
            acquisition_batch_size=initial_training_set_size,
            scores=pool_scores,
            ids=pool_ids,
            ignored_ids=set(),
        )
    else:
        initial_training_set_batch_ids = []

    # NOTE: if we could `enumerate` before `filter` in `create_dataset` of CLU
    # then this dataset creation could be simplified.
    # https://github.com/google/CommonLoopUtils/blob/main/clu/deterministic_data.py#L340
    # CLU is explicitly not accepting outside contributions at the moment.
    train_subset_data_builder = al_utils.SubsetDatasetBuilder(
        data_builder, subset_ids=set(initial_training_set_batch_ids))

    test_accuracies = []
    training_sizes = []

    rng, rng_loop = jax.random.split(rng)
    rngs_loop = flax_utils.replicate(rng_loop)
    if config.model_type == 'batchensemble':
        rngs_loop = {'dropout': rngs_loop}

    # TODO(joost,andreas): double check if below is still necessary
    # (train_split is independent of this)
    # NOTE: train_ds_rng is re-used for all train_ds creations
    rng, train_ds_rng = jax.random.split(rng)

    measurements = {}
    accumulated_steps = 0
    while True:
        current_train_ds_length = len(train_subset_data_builder.subset_ids)
        if current_train_ds_length >= config.get('max_training_set_size', 150):
            break
        write_note(f'Training set size: {current_train_ds_length}')

        current_opt_repl = flax_utils.replicate(opt_cpu)

        # Only fine-tune if there is anything to fine-tune with.
        if current_train_ds_length > 0:
            # Repeat dataset to have oversampled epochs and bootstrap more batches
            number_of_batches = current_train_ds_length / config.batch_size
            num_repeats = math.ceil(config.total_steps / number_of_batches)
            write_note(f'Repeating dataset {num_repeats} times')

            # We repeat the dataset several times, such that we can obtain batches
            # of size batch_size, even at start of training. These batches will be
            # effectively 'bootstrap' sampled, meaning they are sampled with
            # replacement from the original training set.
            repeated_train_ds = input_utils.get_data(
                dataset=train_subset_data_builder,
                split=config.train_split,
                rng=train_ds_rng,
                process_batch_size=local_batch_size,
                preprocess_fn=preprocess_spec.parse(
                    spec=config.pp_train,
                    available_ops=preprocess_utils.all_ops()),
                shuffle_buffer_size=config.shuffle_buffer_size,
                prefetch_size=config.get('prefetch_to_host', 2),
                # TODO(joost,andreas): double check if below leads to bootstrap
                # sampling.
                num_epochs=num_repeats,
            )

            # We use this dataset to evaluate how well we perform on the training
            # set, and to check whether we fit well within the max_steps budget.
            train_eval_ds = input_utils.get_data(
                dataset=train_subset_data_builder,
                split=config.train_split,
                rng=train_ds_rng,
                process_batch_size=local_batch_size,
                preprocess_fn=preprocess_spec.parse(
                    spec=config.pp_eval,
                    available_ops=preprocess_utils.all_ops()),
                shuffle=False,
                drop_remainder=False,
                prefetch_size=config.get('prefetch_to_host', 2),
                num_epochs=1,
            )

            # NOTE: warmup and decay are not a good fit for the small training set
            # lr_fn = train_utils.create_learning_rate_schedule(config.total_steps,
            #                                                   **config.get('lr', {})
            #                                                   )
            lr_fn = lambda x: config.lr.base

            early_stopping_patience = config.get('early_stopping_patience', 15)
            current_opt_repl, rngs_loop, measurements = finetune(
                update_fn=update_fn,
                opt_repl=current_opt_repl,
                lr_fn=lr_fn,
                ds=repeated_train_ds,
                rngs_loop=rngs_loop,
                total_steps=config.total_steps,
                train_eval_ds=train_eval_ds,
                val_ds=val_ds,
                evaluation_fn=evaluation_fn,
                early_stopping_patience=early_stopping_patience,
                profiler=profiler)
            train_val_accuracies = measurements.pop('train_val_accuracies')
            current_steps = 0
            for step, train_acc, val_acc in train_val_accuracies:
                writer.write_scalars(accumulated_steps + step, {
                    'train_accuracy': train_acc,
                    'val_accuracy': val_acc
                })
                current_steps = step
            accumulated_steps += current_steps + 10

        test_accuracy = get_accuracy(evaluation_fn=evaluation_fn,
                                     opt_repl=current_opt_repl,
                                     ds=test_ds)

        write_note(f'Accuracy at {current_train_ds_length}: {test_accuracy}')

        test_accuracies.append(test_accuracy)
        training_sizes.append(current_train_ds_length)

        pool_ids, pool_outputs, _, pool_masks = get_ids_logits_masks(
            model=model,
            opt_repl=current_opt_repl,
            ds=pool_train_ds,
            use_pre_logits=acquisition_method == 'density',
            config=config)

        if acquisition_method == 'uniform':
            rng_loop, rng_acq = jax.random.split(rng_loop, 2)
            pool_scores = get_uniform_scores(pool_masks, rng_acq)
        elif acquisition_method == 'entropy':
            pool_scores = get_entropy_scores(pool_outputs, pool_masks)
        elif acquisition_method == 'margin':
            pool_scores = get_margin_scores(pool_outputs, pool_masks)
        elif acquisition_method == 'density':
            if current_train_ds_length > 0:
                pool_scores = get_density_scores(model=model,
                                                 opt_repl=current_opt_repl,
                                                 train_ds=train_eval_ds,
                                                 pool_pre_logits=pool_outputs,
                                                 pool_masks=pool_masks,
                                                 config=config)
            else:
                rng_loop, rng_acq = jax.random.split(rng_loop, 2)
                pool_scores = get_uniform_scores(pool_masks, rng_acq)
        else:
            raise ValueError('Acquisition method not found.')

        acquisition_batch_ids, _ = select_acquisition_batch_indices(
            acquisition_batch_size=config.get('acquisition_batch_size', 10),
            scores=pool_scores,
            ids=pool_ids,
            ignored_ids=train_subset_data_builder.subset_ids)

        train_subset_data_builder.subset_ids.update(acquisition_batch_ids)

        measurements.update({'test_accuracy': test_accuracy})
        writer.write_scalars(current_train_ds_length, measurements)

    write_note(f'Final acquired training ids: '
               f'{train_subset_data_builder.subset_ids}; '
               f'accuracies: {test_accuracies}')

    pool.close()
    pool.join()
    writer.close()
    # TODO(joost,andreas): save the final checkpoint
    return (train_subset_data_builder.subset_ids, test_accuracies)
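
Each acquisition round above scores every pool example and then selects the highest-scoring ids that are not already in the training subset. The sketch below shows the idea behind the 'entropy' branch; entropy_scores and select_top_k are hypothetical stand-ins, not the repo's get_entropy_scores or select_acquisition_batch_indices.

import jax
import jax.numpy as jnp

def entropy_scores(logits):
    # Predictive entropy per example: higher means more uncertain.
    log_p = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.sum(jnp.exp(log_p) * log_p, axis=-1)

def select_top_k(scores, ids, k):
    order = jnp.argsort(-scores)  # most uncertain first
    return ids[order[:k]]

logits = jnp.array([[4.0, 0.1, 0.1], [0.5, 0.5, 0.5]])
ids = jnp.array([17, 42])
print(select_top_k(entropy_scores(logits), ids, k=1))  # -> [42], the uniform row
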
Example #23
def train_and_evaluate(config, workdir, strategy):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
    strategy: Distribution strategy to use for distributing the model.
  """
    tf.io.gfile.makedirs(workdir)

    tf_rng, data_rng = tf.random.experimental.stateless_split((config.seed, 0),
                                                              2)
    tf.random.set_seed(tf_rng.numpy()[0])

    # Input pipeline.
    ds_info, train_ds, val_ds, test_ds = input_pipeline.create_datasets(
        config, data_rng, strategy=strategy)
    train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types

    # Learning rate schedule.
    num_train_steps = config.num_train_steps
    if num_train_steps == -1:
        num_train_steps = (ds_info.splits["train"].num_examples //
                           config.global_batch_size * config.num_epochs)
    steps_per_epoch = num_train_steps // config.num_epochs
    logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
                 steps_per_epoch)

    # We treat the learning rate in the config as the learning rate for batch size
    # 256 but scale it according to our batch size.
    base_learning_rate = config.learning_rate * config.global_batch_size / 256.0

    # Initialize model.
    num_classes = ds_info.features["label"].num_classes

    if config.distill_teacher:
        do_distill = True
        teacher_file_list = (config.distill_teacher).split(",")
        teacher_models = load_teacher_models(teacher_file_list, num_classes,
                                             config, strategy)
        distill_params = {}
        distill_params["alpha"] = config.distill_alpha
        distill_params["beta"] = config.distill_fd_beta
        distill_params["teacher_model"] = TeacherModel(teacher_models,
                                                       name="teacher")
    else:
        do_distill = False
        distill_params = None

    state = create_state(config, num_classes=num_classes, strategy=strategy)

    ckpt_manager = tf.train.CheckpointManager(checkpoint=state,
                                              directory=workdir,
                                              max_to_keep=5)

    if ckpt_manager.latest_checkpoint:
        state.restore(ckpt_manager.latest_checkpoint)
        logging.info("Restored from %s", ckpt_manager.latest_checkpoint)
    else:
        logging.info("Initializing from scratch.")
    initial_step = state.global_step.numpy().item()

    learning_rate_fn = functools.partial(get_learning_rate,
                                         base_learning_rate=base_learning_rate,
                                         steps_per_epoch=steps_per_epoch,
                                         num_epochs=config.num_epochs,
                                         warmup_epochs=config.warmup_epochs)

    writer = metric_writers.create_default_writer(workdir)
    writer.write_hparams(dict(config))

    logging.info("Starting training loop at step %d.", initial_step)
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=num_train_steps, writer=writer)
    with metric_writers.ensure_flushes(writer):
        for step in range(initial_step, num_train_steps + 1):
            state.model.trainable = True

            # `step` is a Python integer. `global_step` is a TF variable on the
            # GPU/TPU devices.
            is_last_step = step == num_train_steps

            train_step(state, train_iter, config.weight_decay,
                       learning_rate_fn, do_distill, distill_params, strategy)

            state.train_metrics.update_state_lr(
                learning_rate_fn(state.global_step.numpy().item()))

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            report_progress(step)

            if step == initial_step:
                parameter_overview.log_parameter_overview(state.model)

            if step % config.log_loss_every_steps == 0 or is_last_step:
                writer.write_scalars(step, state.train_metrics.result())
                state.train_metrics.reset_states()
                state.train_metrics.reset_lr()

            if step % config.eval_every_steps == 0 or is_last_step:
                state.model.trainable = False
                if config.dataset == "imagenet-lt":
                    evaluate(state, val_ds, state.val_metrics, strategy)
                    writer.write_scalars(step, state.val_metrics.result())
                    logging.info("Num val images %d",
                                 state.val_metrics.accuracy.count.numpy())

                evaluate(state, test_ds, state.test_metrics, strategy)
                writer.write_scalars(step, state.test_metrics.result())

                logging.info("Num test images %d",
                             state.test_metrics.accuracy.count.numpy())

            if step % config.checkpoint_every_steps == 0 or is_last_step:
                checkpoint_path = ckpt_manager.save(step)
                logging.info("Saved checkpoint %s", checkpoint_path)

    logging.info("Finishing training at step %d", step)
    logging.info("Saving the final weights")
    file_path = "%s/final_weights" % workdir
    state.model.save_weights(file_path, save_format="tf")
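
The schedule above applies the linear scaling rule (the configured learning rate is treated as the rate for batch size 256) and wraps it with an epoch-based warmup via get_learning_rate. A hedged sketch of such a schedule follows; the function name, the constant post-warmup shape, and the numbers are illustrative, since the repo's get_learning_rate may decay differently.

def scaled_warmup_lr(step, *, config_lr=0.1, global_batch_size=1024,
                     steps_per_epoch=1000, warmup_epochs=5):
    base_lr = config_lr * global_batch_size / 256.0  # linear scaling rule
    epoch = step / steps_per_epoch
    if epoch < warmup_epochs:
        return base_lr * epoch / warmup_epochs       # linear warmup from 0
    return base_lr                                   # constant afterwards

# e.g. config_lr=0.1 at global batch size 1024 gives base_lr = 0.4
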
Example #24
def training_loop(
    *,
    module,
    rng,
    train_ds,
    eval_ds,
    loss_fn,
    optimizer,
    train_metrics_dict,
    eval_metrics_dict,
    stats_aggregators,
    config,
    workdir,
):
  """Runs a training and evaluation loop.

  Args:
    module: The module that should be trained.
    rng: A jax pseudo-random number generator key.
    train_ds: Dataset used for training.
    eval_ds: Dataset used for evaluation.
    loss_fn: Loss function to use for training.
    optimizer: Optax optimizer to use for training.
    train_metrics_dict: Collection of metrics to be collected during training.
    eval_metrics_dict: Collection of metrics to be collected during evaluation.
    stats_aggregators: Dictionary of statistics aggregator functions to be run
      on the first evaluation batch. These functions ingest the stats returned
      by the model and output a Dict[str, image/scalar] that will be logged.
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.

  Raises:
    RuntimeError: If a training metric is NaN or inf.

  Returns:
    Training state.
  """
  rng, model_rng = jax.random.split(rng)
  input_shape = tuple(train_ds.element_spec["image"].shape[1:])
  model, init_params, init_state = create_model(module, input_shape, model_rng)
  parameter_overview.log_parameter_overview(model.params)

  # Load pretrained model parameters and state. Ignore the step and the
  # optimizer state in the checkpoint.
  pretrained_path = config.get("pretrained_checkpoint", "")
  if pretrained_path:
    logging.info("Load pretrained weights from '%s'", pretrained_path)
    state_dict = checkpoint.load_state_dict(pretrained_path)
    flatten_model_params = utils.flatten_dict(state_dict["model_params"],
                                              sep="/")
    model_state = state_dict["model_state"]

    # A prefix can be used to replace only a subpart of the network (e.g the
    # encoder). Prepend the prefix (if any) to model parameters and states.
    prefix = config.get("pretrained_prefix", "")
    if prefix:
      flatten_model_params = utils.add_prefix_to_dict_keys(
          flatten_model_params, f"{prefix}/")
      model_state = utils.add_prefix_to_dict_keys(
          model_state, f"/{prefix}")

    # Merge the params/state from the checkpoint into the initial params/state.
    flatten_init_params = utils.flatten_dict(init_params, sep="/")
    flatten_init_params, ignored_params = utils.override_dict(
        flatten_init_params, flatten_model_params)
    init_params = utils.unflatten_dict(flatten_init_params, delimiter="/")
    init_state, _ = utils.override_dict(init_state, model_state)

    if ignored_params:
      logging.warning("%d/%d parameters from the pretrained checkpoint "
                      "were ignored: %s", len(ignored_params),
                      len(flatten_init_params), ignored_params)

  optimizer_state = optimizer.init(init_params)

  state = TrainState(
      step=1,
      model_params=init_params,
      model_state=init_state,
      optimizer_state=optimizer_state)  # type: ignore
  # Do not keep a copy of the initial model.
  del init_params, init_state, optimizer_state

  train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
  checkpoint_dir = os.path.join(workdir, "checkpoints")

  ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
  state = ckpt.restore_or_initialize(state)
  initial_step = int(state.step)
  # Replicate our parameters.
  state = flax.jax_utils.replicate(state)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)
  step_timer = utils.StepTimer(
      batch_size=config.batch_size, initial_step=initial_step)

  # Write config to the summary files. This makes the hyperparameters available
  # in TensorBoard and makes comparison of runs with tensorboard/ easier.
  if initial_step == 1:
    writer.write_hparams(utils.flatten_dict(config.to_dict()))

  # Generate per-device PRNG keys for the training loop.
  rng, train_rng = jax.random.split(rng)
  train_rngs = jax.random.split(train_rng, jax.local_device_count())

  # Generate per-device PRNG keys for model evaluation.
  rng, eval_rng = jax.random.split(rng)
  eval_rngs = jax.random.split(eval_rng, jax.local_device_count())

  logging.info("Starting training loop at step %d.", initial_step)
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  train_metrics = utils.Means()

  do_eval_only = config.get("do_eval_only", False)
  if do_eval_only:
    config.num_train_steps = 1

  debug_enabled = config.get("debug", False)
  previous_grads = grads = None
  previous_updates = updates = None
  previous_state = None
  for step in range(initial_step, config.num_train_steps + 1):
    is_last_step = step == config.num_train_steps
    if debug_enabled:
      previous_grads = grads
      previous_updates = updates
      previous_state = state

    # Skip the training if only do the eval.
    if not do_eval_only:
      # Use ._numpy() to avoid copy.
      batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))  # pylint: disable=protected-access
      state, grads, updates, metrics, training_stats, train_rngs = train_step(
          state, batch, module, loss_fn, optimizer, train_metrics_dict,
          train_rngs)
      train_metrics.append(flax.jax_utils.unreplicate(metrics))

      # Update topk temperature with linearly decreasing schedule if enabled.
      if (config.get("linear_decrease_perturbed_sigma", False) and
          config.get("selection_method", "") == "perturbed-topk"):

        model_state = state.model_state.as_dict()

        if "/PatchNet_0" in model_state:
          net_str = "/PatchNet_0"
        else:
          net_str = "/"

        progress = step / config.num_train_steps
        sigma_multiplier = 1. - progress
        previous_mult = model_state[net_str]["sigma_mutiplier"]
        sigma_multiplier = sigma_multiplier + jnp.zeros_like(previous_mult)
        model_state[net_str]["sigma_mutiplier"] = sigma_multiplier
        state = state.replace(model_state=nn.Collection(model_state))

      if debug_enabled:
        if utils.has_any_inf_or_nan(metrics):
          # Save checkpoint
          if previous_state:
            ckpt.save(flax.jax_utils.unreplicate(previous_state))
          ckpt.save(flax.jax_utils.unreplicate(state))

          # Log gradients and updates.
          if previous_grads or previous_updates:
            write_gradient_histogram(writer, step,
                                     grads=previous_grads,
                                     updates=previous_updates)
          write_gradient_histogram(writer, step + 1,
                                   grads=grads, updates=updates)

          raise RuntimeError("A training metric took an invalid value: "
                             f"{metrics}.")

      logging.log_first_n(logging.INFO, "Finished training step %d.", 3, step)
      report_progress(step)

      if step % config.log_loss_every_steps == 0 or is_last_step:
        results = train_metrics.result()
        writer.write_scalars(step, results)
        writer.write_scalars(step, step_timer.get_and_reset(step))
        if utils.has_any_inf_or_nan(results):
          raise ValueError("A training metric took an invalid value.")
        train_metrics.reset()

    if (step % config.checkpoint_every_steps == 0 or is_last_step):
      with step_timer.paused():
        ckpt.save(flax.jax_utils.unreplicate(state))

    # Evaluation
    if step % config.eval_every_steps == 0 or is_last_step:
      with step_timer.paused():
        eval_metrics, first_batch_stats, eval_rngs = evaluate(
            state, module, eval_ds, eval_metrics_dict, eval_rngs)

      if jax.host_id() == 0:
        log_histograms = config.get("log_histograms", False)
        log_images = config.get("log_images", True)
        # Log the last gradients and updates histograms.
        if not do_eval_only:
          write_stats_results(writer, step, training_stats, stats_aggregators,
                              prefix="train/", log_images=log_images)
          if log_histograms:
            write_gradient_histogram(writer, step, grads=grads, updates=updates)

        write_stats_results(writer, step, first_batch_stats,
                            stats_aggregators, prefix="eval/",
                            log_images=log_images)

        # write patch representation histograms
        if (log_histograms and first_batch_stats and
            "patch_representations" in first_batch_stats):
          patch_representations = first_batch_stats["patch_representations"]
          writer.write_histograms(step, {
              "patch_representations": patch_representations
          })

        if eval_metrics:
          writer.write_scalars(step, eval_metrics)

  writer.flush()
  return state
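
Restoring the pretrained checkpoint above works on flattened parameter dicts: checkpoint entries overwrite matching init entries, and entries with no match are reported as ignored. The sketch below assumes those semantics for utils.override_dict and uses plain dicts; it is an illustration, not the actual helper.

def override_dict(base, new):
    # Overwrite matching keys in `base`; report `new` keys that had no match.
    ignored = {k: v for k, v in new.items() if k not in base}
    merged = {k: new.get(k, v) for k, v in base.items()}
    return merged, ignored

init_flat = {"encoder/kernel": 0, "head/kernel": 1}
ckpt_flat = {"encoder/kernel": 7, "decoder/kernel": 9}
merged, ignored = override_dict(init_flat, ckpt_flat)
# merged == {"encoder/kernel": 7, "head/kernel": 1}; ignored == {"decoder/kernel": 9}
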
Example #25
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """
  tf.io.gfile.makedirs(workdir)

  vocab_path = config.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(workdir, "sentencepiece_model")
    config.vocab_path = vocab_path
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info("Initializing dataset.")
  train_ds, eval_ds, _, encoder = input_pipeline.get_datasets(
      n_devices=jax.local_device_count(),
      config=config,
      vocab_path=vocab_path)

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = temperature_sampler.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode("utf-8")

  def encode_strings(strs, max_len):
    tokenized_batch = np.zeros((len(strs), max_len), np.int32)
    for i, s in enumerate(strs):
      toks = encoder.tokenize(s).numpy()
      # Remove EOS token in prompt.
      tokenized_batch[i, :toks.shape[0]-1] = toks[:-1]
    return tokenized_batch

  tokenized_prompts = encode_strings(
      [config.prompts], config.max_predict_length)

  logging.info("Initializing model, optimizer, and step functions.")
  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      logits_via_embedding=config.logits_via_embedding,
      dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
      emb_dim=config.emb_dim,
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      qkv_dim=config.qkv_dim,
      mlp_dim=config.mlp_dim,
      max_len=max(config.max_target_length, config.max_eval_target_length),
      dropout_rate=config.dropout_rate,
      attention_dropout_rate=config.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng = jax.random.split(rng)
  rng, inference_rng = random.split(rng)
  input_shape = (config.per_device_batch_size, config.max_target_length)

  m = models.TransformerLM(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      config.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay)
  optimizer = optimizer_def.create(initial_variables["params"])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)
  if start_step == 0:
    writer.write_hparams(dict(config))

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps)

  # compile multidevice versions of train/eval/predict step fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn),
      axis_name="batch",
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step, config=eval_config),
      axis_name="batch")

  p_pred_step = jax.pmap(
      functools.partial(
          predict_step, config=predict_config,
          temperature=config.sampling_temperature,
          top_k=config.sampling_top_k),
      axis_name="batch",
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap"d training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info("Starting training loop.")
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if jax.host_id() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
    ]
  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceAnnotation("train", step_num=step):
        batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
        optimizer, metrics = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step % config.eval_every_steps == 0 or is_last_step:
        with report_progress.timed("training_metrics"):
          logging.info("Gathering training metrics.")
          train_metrics = common_utils.get_metrics(train_metrics)
          lr = train_metrics.pop("learning_rate").mean()
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          denominator = metrics_sums.pop("denominator")
          summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
          summary["learning_rate"] = lr
          summary["perplexity"] = jnp.clip(
              jnp.exp(summary["loss"]), a_max=1.0e4)
          summary = {"train_" + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed("eval"):
          eval_results = evaluate(
              p_eval_step=p_eval_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=config.num_eval_steps)
          # (Clipped) perplexity after averaging log-perplexities.
          eval_results["perplexity"] = jnp.clip(
              jnp.exp(eval_results["loss"]), a_max=1.0e4)
          writer.write_scalars(
              step, {"eval_" + k: v for k, v in eval_results.items()})

        with report_progress.timed("generate_text"):
          exemplars = generate_prediction(
              p_pred_step=p_pred_step,
              target=optimizer.target,
              tokenized_prompts=tokenized_prompts,
              eos_id=eos_id,
              inference_rng=inference_rng,
              decode_tokens=decode_tokens,
              max_predict_length=config.max_predict_length)
          writer.write_texts(step, {"samples": exemplars})

      # Save a checkpoint on one host after every checkpoint_every_steps steps.
      save_checkpoint = (step % config.checkpoint_every_steps == 0 or
                         is_last_step)
      if config.save_checkpoints and save_checkpoint and jax.host_id() == 0:
        with report_progress.timed("checkpoint"):
          checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(optimizer),
                                      step)
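

# Illustrative sketch (not from the original source): the
# `create_learning_rate_scheduler` call above is assumed to combine a linear
# warmup with inverse-square-root decay, as in the standard Flax language-model
# examples. A minimal self-contained version of that schedule might look like:
import jax.numpy as jnp

def linear_warmup_rsqrt_decay(base_learning_rate, warmup_steps):
  """Linear warmup followed by inverse-square-root decay."""
  def step_fn(step):
    lr = base_learning_rate
    lr = lr * jnp.minimum(1.0, step / warmup_steps)       # linear warmup
    lr = lr / jnp.sqrt(jnp.maximum(step, warmup_steps))   # rsqrt decay
    return jnp.asarray(lr, dtype=jnp.float32)
  return step_fn
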
def train(*,
          workdir,
          initial_step,
          chkpt_manager,
          Phi,
          Psi,
          optimal_subspace,
          num_epochs,
          learning_rate,
          key,
          method,
          lissa_kappa,
          optimizer,
          covariance_batch_size,
          main_batch_size,
          weight_batch_size,
          estimate_feature_norm=True):
    """Training function.

  For lissa, the total number of samples is
  2 x covariance_batch_size + main_batch_size + 2 x weight_batch_size.

  Args:
    workdir: Work directory, where we'll save logs.
    initial_step: Initial step.
    chkpt_manager: Checkpoint manager.
    Phi: The initial feature matrix.
    Psi: The target matrix whose PCA is to be determined.
    optimal_subspace: Top-d left singular vectors of Psi.
    num_epochs: How many gradient steps to perform (despite the name, these are
      individual steps, not epochs).
    learning_rate: The step size parameter for SGD.
    key: The JAX PRNG key.
    method: 'naive', 'lissa', or 'oracle'.
    lissa_kappa: The parameter of the lissa method, if used.
    optimizer: Which optimizer to use. Only 'sgd' is supported.
    covariance_batch_size: the 'J' parameter. For the naive method, this is how
      many states we sample to construct the inverse. For the lissa method,
      ditto -- these are also "iterations".
    main_batch_size: How many states to update at once.
    weight_batch_size: How many states to use when constructing the weight
      vector.
    estimate_feature_norm: Whether to use a running average of the max feature
      norm rather than the real maximum.

  Returns:
    A tuple of the representation and gradient arrays.
  """
    # Don't overwrite Phi.
    Phi = np.copy(Phi)
    Phis = [np.copy(Phi)]

    num_states, d = Phi.shape
    _, num_tasks = Psi.shape

    # Keep a running average of the max norm of a feature vector. None means
    # the estimate is disabled.
    if estimate_feature_norm:
        estimated_feature_norm = utils.compute_max_feature_norm(Phi)
    else:
        estimated_feature_norm = None

    # Create an explicit weight vector (needed for explicit method).
    key, weight_key = jax.random.split(key)
    explicit_weight_matrix = np.array(
        jax.random.normal(  # Wrapped in np.array above to get a mutable copy.
            weight_key, (d, num_tasks),
            dtype=jnp.float64))

    assert optimizer == 'sgd', 'Non-sgd not yet supported.'

    writer = metric_writers.create_default_writer(logdir=str(workdir))

    hooks = [
        periodic_actions.PeriodicCallback(
            every_steps=5_000,
            callback_fn=lambda step, t: chkpt_manager.save((step, Phi)))
    ]
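
# Illustrative sketch (not from the original source): `utils.compute_max_feature_norm`
# is not shown above; the hypothetical version below assumes it simply returns the
# largest L2 norm over the rows of Phi, which is all the `estimate_feature_norm`
# option needs as a starting value.
import numpy as np

def compute_max_feature_norm(Phi):
  """Largest per-state (row-wise) L2 norm of the feature matrix Phi."""
  return float(np.max(np.linalg.norm(Phi, axis=1)))
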
Example #27
def train_and_evaluate(config, workdir):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest checkpoint.
  """
  is_first_process = jax.process_index() == 0
  tf.io.gfile.makedirs(workdir)

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  train_ds, eval_ds, test_ds, encoder = input_pipeline.get_datasets(
      config)
  config.seq_length = 250
  vocab_size = int(encoder.vocab_size())
  config.num_classes = vocab_size
  config.data_shape = (config.seq_length, 1)

  logging.info('Training with vocab size %d', vocab_size)

  def decode_tokens(toks):
    return encoder.detokenize(toks)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng = jax.random.split(rng)
  config.per_device_batch_size = config.batch_size // jax.process_count()

  logging.info('Initializing model, optimizer, and step functions.')
  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  model, initial_variables = model_setup(init_rng, config)

  # Instead of passing the optimizer fns directly, we use a fn that returns
  # the optimizer given a learning rate.
  def tx_fn(lr):
    return optax.adamw(
        lr, b1=0.9, b2=0.99, eps=1e-08, eps_root=0.0,
        weight_decay=config.weight_decay)

  state = language_train_state.TrainState.create(
      params=initial_variables['params'], tx_fn=tx_fn)

  # We access model params only from state below via state.params.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated model state from last checkpoint.
    state = checkpoints.restore_checkpoint(workdir, state)
    # Grab last step.
    start_step = int(state.step)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=not is_first_process)
  if start_step == 0:
    config_dict = dict(config)
    writer.write_hparams(config_dict)

  if is_first_process and start_step == 0:
    # Dump config file to work dir for easy model loading.
    config_path = os.path.join(workdir, 'config')
    with tf.io.gfile.GFile(config_path, 'wb') as fp:
      pickle.dump(config, fp)

  logging.info('Using state %s', type(state))
  # Replicate state.
  state = jax_utils.replicate(state)

  learning_rate_fn = create_learning_rate_scheduler(
      factors=config.lr_factors,
      base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps)

  # Compile multidevice versions of train/eval/predict step fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          model=model,
          learning_rate_fn=learning_rate_fn,
          clip_grad=config.clip_grad,
          ema_momentum=config.get('ema_momentum', 0.999)),
      axis_name='batch',
      in_axes=(0, 0),
      donate_argnums=(0,))
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step, model=model),
      axis_name='batch')

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of train PRNG keys, but update them afterwards inside
  # the main pmap'd training update for performance.
  rng = jax.random.fold_in(rng, jax.process_index())
  rng1, rng2, rng3, extensive_eval_rngs, sample_rng = jax.random.split(rng, 5)
  train_rngs = jax.random.split(rng1, jax.local_device_count())
  eval_rngs = jax.random.split(rng2, jax.local_device_count())
  test_rngs = jax.random.split(rng3, jax.local_device_count())
  del rng, rng1, rng2, rng3

  logging.info('Starting training loop.')
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if is_first_process:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
    ]
  train_metrics = []

  # Iterator that does epoch-wise indefinite iteration.
  def iterate_train(train_ds):
    epoch = 1
    while True:
      msg = f'Starting epoch {epoch}'
      logging.info(msg)
      for batch in train_ds:
        yield batch
      epoch += 1

  train_iter = iterate_train(train_ds)

  kl_tracker_train = util_fns.KLTracker(num_steps=model.num_steps)
  kl_history = []

  with metric_writers.ensure_flushes(writer):
    step = start_step
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
        state, metrics = p_train_step(
            state, batch, rng=train_rngs)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step > 0 and (step % config.eval_every_steps == 0 or is_last_step):
        with report_progress.timed('training_metrics'):
          logging.info('Gathering training metrics.')
          train_metrics = common_utils.get_metrics(train_metrics)

          # First handle loss terms per step.
          t_batch = train_metrics.pop('t_batch')
          nelbo_per_t_batch = train_metrics.pop('nelbo_per_t_batch')
          kl_tracker_train.update(
              t_batch.reshape(-1), nelbo_per_t_batch.reshape(-1))
          kl_values = kl_tracker_train.get_kl_per_t()
          kl_history.append(np.array(kl_values))
          kl_history = kl_history[-100:]  # Keep last 100 items only.

          # Handle remaining `standard` metrics
          summary = jax.tree_map(jnp.mean, train_metrics)
          summary = {'train_' + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed('eval'):
          eval_results, eval_rngs = evaluate(
              p_eval_step=p_eval_step,
              params=state.ema_params,
              eval_ds=eval_ds,
              rng=eval_rngs)
          writer.write_scalars(
              step, {'eval_' + k: v for k, v in eval_results.items()})

          test_results, test_rngs = evaluate(
              p_eval_step=p_eval_step,
              params=state.ema_params,
              eval_ds=test_ds,
              rng=test_rngs)
          writer.write_scalars(
              step, {'test_' + k: v for k, v in test_results.items()})

        if step == 1000 or (step > 0 and
                            step % config.detailed_eval_every_steps == 0):
          if is_first_process:
            loss_components_path = os.path.join(workdir, 'loss_components')
            with tf.io.gfile.GFile(loss_components_path, 'wb') as fp:
              pickle.dump(kl_history[-1], fp)

          extensive_eval_rngs = extensive_eval(
              config, extensive_eval_rngs, writer, workdir,
              model, state, kl_history, test_ds, step,
              decode_tokens)

        with report_progress.timed('generate_text'):
          generate_prediction(sample_rng, config, model, state, writer,
                              decode_tokens, step)

      # Save a checkpoint on one host after every checkpoint_every_steps steps.
      save_checkpoint = (
          step > 0 and
          (step % config.checkpoint_every_steps == 0 or is_last_step))
      if config.save_checkpoints and save_checkpoint and is_first_process:
        with report_progress.timed('checkpoint'):
          checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(state),
                                      step, overwrite=True)
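
# Illustrative sketch (not from the original source): `common_utils.shard`, used in
# the training loops above, reshapes the leading batch axis of every array so the
# batch can be split across the local devices by `jax.pmap`. A minimal equivalent:
import jax

def shard(batch):
  """Reshape [batch, ...] arrays to [local_devices, batch // local_devices, ...]."""
  n = jax.local_device_count()
  return jax.tree_map(lambda x: x.reshape((n, -1) + x.shape[1:]), batch)
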
Example #28
def train_and_evaluate(config, workdir):
    """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.

  Returns:
    The train state (which includes the `.params`).
  """
    # Seed for reproducibility.
    rng = jax.random.PRNGKey(config.rng_seed)

    # Set up logging.
    summary_writer = metric_writers.create_default_writer(workdir)
    summary_writer.write_hparams(dict(config))

    # Get datasets.
    rng, dataset_rng = jax.random.split(rng)
    dataset = input_pipeline.get_dataset(config, dataset_rng)
    graph, labels, masks = jax.tree_map(jnp.asarray, dataset)
    labels = jax.nn.one_hot(labels, config.num_classes)
    train_mask = masks['train']
    train_indices = jnp.where(train_mask)[0]
    train_labels = labels[train_indices]
    num_training_nodes = len(train_indices)

    # Get subgraphs.
    if config.differentially_private_training:
        graph = jax.tree_map(np.asarray, graph)
        subgraphs = get_subgraphs(graph, pad_to=config.pad_subgraphs_to)
        graph = jax.tree_map(jnp.asarray, graph)

        # We only need the subgraphs for training nodes.
        train_subgraphs = subgraphs[train_indices]
        del subgraphs
    else:
        train_subgraphs = None

    # Initialize privacy accountant.
    training_privacy_accountant = privacy_accountants.get_training_privacy_accountant(
        config, num_training_nodes, compute_max_terms_per_node(config))

    # Construct and initialize model.
    rng, init_rng = jax.random.split(rng)
    estimation_indices = get_estimation_indices(train_indices, config)
    state = create_train_state(init_rng, config, graph, train_labels,
                               train_subgraphs, estimation_indices)

    # Set up checkpointing of the model.
    checkpoint_dir = os.path.join(workdir, 'checkpoints')
    ckpt = checkpoint.Checkpoint(checkpoint_dir, max_to_keep=2)
    state = ckpt.restore_or_initialize(state)
    initial_step = int(state.step) + 1

    # Log overview of parameters.
    parameter_overview.log_parameter_overview(state.params)

    # Log metrics after initialization.
    logits = compute_logits(state, graph)
    metrics_after_init = compute_metrics(logits, labels, masks)
    metrics_after_init['epsilon'] = 0
    log_metrics(0, metrics_after_init, summary_writer, postfix='_after_init')

    # Train model.
    rng, train_rng = jax.random.split(rng)
    max_training_epsilon = get_max_training_epsilon(config)

    # Hooks called periodically during training.
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_training_steps, writer=summary_writer)
    profiler = periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
    hooks = [report_progress, profiler]

    for step in range(initial_step, config.num_training_steps):

        # Perform one step of training.
        with jax.profiler.StepTraceAnnotation('train', step_num=step):
            # Sample batch.
            step_rng = jax.random.fold_in(train_rng, step)
            indices = jax.random.choice(step_rng, num_training_nodes,
                                        (config.batch_size, ))

            # Compute gradients.
            if config.differentially_private_training:
                grads = compute_updates_for_dp(state, graph, train_labels,
                                               train_subgraphs, indices,
                                               config.adjacency_normalization)
            else:
                grads = compute_updates(state, graph, train_labels, indices)

            # Update parameters.
            state = update_model(state, grads)

        # Quick indication that training is happening.
        logging.log_first_n(logging.INFO, 'Finished training step %d.', 10,
                            step)
        for hook in hooks:
            hook(step)

        # Evaluate, if required.
        is_last_step = (step == config.num_training_steps - 1)
        if step % config.evaluate_every_steps == 0 or is_last_step:
            with report_progress.timed('eval'):
                # Check if privacy budget exhausted.
                training_epsilon = training_privacy_accountant(step + 1)
                if max_training_epsilon is not None and training_epsilon >= max_training_epsilon:
                    break

                # Compute metrics.
                logits = compute_logits(state, graph)
                metrics_during_training = compute_metrics(
                    logits, labels, masks)
                metrics_during_training['epsilon'] = training_epsilon
                log_metrics(step, metrics_during_training, summary_writer)

        # Checkpoint, if required.
        if step % config.checkpoint_every_steps == 0 or is_last_step:
            with report_progress.timed('checkpoint'):
                ckpt.save(state)

    return state
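
# Illustrative usage sketch (not from the original source): the loop above samples
# its batch indices with `jax.random.fold_in(train_rng, step)`. Folding the step
# index into a fixed key yields a deterministic, independent key per step, so
# restarting from a checkpoint replays the same sequence of batches.
import jax

train_rng = jax.random.PRNGKey(0)
for step in range(3):
  step_rng = jax.random.fold_in(train_rng, step)
  indices = jax.random.choice(step_rng, 10, (4,))  # same indices for a given step
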
Example #29
def train_and_evaluate(config, work_dir, try_checkpoint=True):
    """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    work_dir: Directory where the tensorboard summaries are written to.
    try_checkpoint: Whether to try to load a checkpoint (usually enabled;
        disabling is practical for debugging).

  Returns:
    The train state (which includes the `.params`).
  """
    # Init rng key.
    msg = f'Running with seed {config.seed}.'
    logging.info(msg)
    rng = jax.random.PRNGKey(config.seed)
    data_rng, rng = jax.random.split(rng)
    is_first_host = jax.process_index() == 0

    train_ds, test_ds, shape, num_classes = datasets.get_dataset(
        config, data_rng)

    # config.mask_shape = mask_shape
    config.data_shape = shape
    config.num_classes = num_classes

    writer = metric_writers.create_default_writer(
        work_dir, just_logging=jax.process_index() > 0)
    rng, init_rng = jax.random.split(rng)

    # Create output directory for saving samples.
    output_path = work_dir
    tf.io.gfile.makedirs(output_path)

    model, variables = model_setup(init_rng, config)

    # From now on we want different rng across hosts:
    rng = jax.random.fold_in(rng, jax.process_index())

    tx = optax.adam(config.learning_rate,
                    b1=0.9,
                    b2=config.beta2,
                    eps=1e-08,
                    eps_root=0.0)
    state = custom_train_state.TrainState.create(params=variables['params'],
                                                 tx=tx)

    if try_checkpoint:
        state, start_epoch = checkpoint.restore_from_path(work_dir, state)
        if start_epoch is None:
            start_epoch = 1
    else:
        # For debugging we start at zero, so we immediately do detailed eval.
        start_epoch = 0

    if is_first_host and start_epoch == 1:
        config_dict = dict(config)
        writer.write_hparams(config_dict)

    if is_first_host and start_epoch in (0, 1):
        # Dump config file to work dir for easy model loading.
        config_path = os.path.join(work_dir, 'config')
        with tf.io.gfile.GFile(config_path, 'wb') as fp:
            pickle.dump(config, fp)

    test_rng, train_rng = jax.random.split(rng)

    kl_tracker_train = util_fns.KLTracker(num_steps=model.num_steps)
    kl_history = []

    p_train_step = jax.pmap(functools.partial(train_step,
                                              model=model,
                                              config=config),
                            axis_name='batch',
                            in_axes=(None, 0, 0),
                            out_axes=(0, 0, None),
                            donate_argnums=(2, ))

    # Only the rng key axes are broadcast: the rng is the first argument
    # (in_axes=None) and the last return value (out_axes=None).
    p_eval_step = jax.pmap(functools.partial(eval_step, model=model),
                           axis_name='batch',
                           in_axes=(None, 0, 0),
                           out_axes=(0, None))

    # Replicate state.
    state = flax.jax_utils.replicate(state)

    with metric_writers.ensure_flushes(writer):
        for epoch in range(start_epoch, config.num_epochs + 1):
            # Train part.
            state, train_metrics, train_rng = train_epoch(
                p_train_step, state, train_ds, config.batch_size, epoch,
                train_rng, kl_tracker_train)

            # Val part.
            eval_metrics, test_rng = eval_model(p_eval_step, test_rng, state,
                                                test_ds, epoch)

            # Metric logging.
            if is_first_host:
                log_standard_metrics(writer, train_metrics, eval_metrics,
                                     epoch)

            kl_values = kl_tracker_train.get_kl_per_t()
            kl_history.append(np.array(kl_values))

            # Prune to avoid too much memory consumption.
            kl_history = kl_history[-50:]

            if epoch == 15 or epoch % config.detailed_eval_every == 0:
                if is_first_host:
                    loss_components_path = os.path.join(
                        work_dir, 'loss_components')
                    with tf.io.gfile.GFile(loss_components_path, 'wb') as fp:
                        pickle.dump(kl_history[-1], fp)

                test_rng = extensive_eval(config, test_rng, writer,
                                          output_path, model, state,
                                          kl_history, test_ds, epoch)

            # Save to checkpoint.
            if is_first_host and epoch % config.save_every == 0:
                # Save to epoch + 1 since current epoch has just been completed.
                logging.info('saving checkpoint')
                checkpoint.save_checkpoint(
                    work_dir,
                    state=flax.jax_utils.unreplicate(state),
                    step=epoch + 1,
                    keep=2)
                logging.info('finished saving checkpoint')

        return state
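
# Illustrative sketch (not from the original source): a minimal example of the
# broadcast pattern used for `p_eval_step` above. in_axes=None feeds the same rng
# key to every device, and out_axes=None returns a single (identical across
# devices) updated key instead of a stacked one, so the caller can keep threading
# one key through the loop.
import jax

def _step(rng, x):
  rng, sample_rng = jax.random.split(rng)
  return x + jax.random.normal(sample_rng, x.shape), rng

p_step = jax.pmap(_step, axis_name='batch', in_axes=(None, 0), out_axes=(0, None))
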
Example #30
def main(_):
    if FLAGS.config.precrop_iters > 0 and FLAGS.config.batching:
        raise ValueError(
            "'precrop_iters has no effect when 'batching' the dataset")
    assert FLAGS.config.down_factor > 0 and FLAGS.config.render_factor > 0

    logging.info("JAX host: %d / %d", jax.process_index(), jax.host_count())
    logging.info("JAX local devices: %r", jax.local_devices())

    platform.work_unit().set_task_status(
        f"host_id: {jax.process_index()}, host_count: {jax.host_count()}")
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.model_dir, "model_dir")

    os.makedirs(FLAGS.model_dir, exist_ok=True)
    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, rng_coarse, rng_fine, data_rng, step_rng = jax.random.split(rng, 5)
    rngs = common_utils.shard_prng_key(step_rng)

    ### Load dataset and data values
    datasets, counts, optics, render_datasets = get_dataset(
        FLAGS.data_dir,
        FLAGS.config,
        rng=data_rng,
        num_poses=FLAGS.config.num_poses)
    train_ds, val_ds, test_ds = datasets
    *_, test_items = counts
    hwf, r_hwf, near, far = optics
    render_ds, render_vdirs_ds, num_poses = render_datasets
    iter_render_ds = zip(range(num_poses), render_ds)
    iter_vdirs_ds = zip(range(num_poses), render_vdirs_ds)
    iter_test_ds = zip(range(test_items), test_ds)
    img_h, img_w, _ = hwf

    logging.info("Num poses: %d", num_poses)
    logging.info("Splits: train - %d, val - %d, test - %d", *counts)
    logging.info("Images: height %d, width %d, focal %.5f", *hwf)
    logging.info("Render: height %d, width %d, focal %.5f", *r_hwf)

    ### Init model parameters and optimizer
    initialized_ = functools.partial(initialized,
                                     model_config=FLAGS.config.model)
    pts_shape = (FLAGS.config.num_rand, FLAGS.config.num_samples, 3)
    views_shape = (FLAGS.config.num_rand, 3)
    model_coarse, params_coarse = initialized_(rng_coarse, pts_shape,
                                               views_shape)

    schedule_fn = optax.exponential_decay(
        init_value=FLAGS.config.learning_rate,
        transition_steps=FLAGS.config.lr_decay * 1000,
        decay_rate=FLAGS.config.decay_factor,
    )
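    # optax.exponential_decay gives lr(step) = learning_rate *
    # decay_factor ** (step / (lr_decay * 1000)), i.e. the rate is multiplied
    # by decay_factor every lr_decay * 1000 steps.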
    tx = optax.adam(learning_rate=schedule_fn)
    state = train_state.TrainState.create(apply_fn=(model_coarse.apply, None),
                                          params={"coarse": params_coarse},
                                          tx=tx)

    if FLAGS.config.num_importance > 0:
        pts_shape = (
            FLAGS.config.num_rand,
            FLAGS.config.num_importance + FLAGS.config.num_samples,
            3,
        )
        model_fine, params_fine = initialized_(rng_fine, pts_shape,
                                               views_shape)
        state = train_state.TrainState.create(
            apply_fn=(model_coarse.apply, model_fine.apply),
            params={
                "coarse": params_coarse,
                "fine": params_fine
            },
            tx=tx,
        )

    state = checkpoints.restore_checkpoint(FLAGS.model_dir, state)
    start_step = int(state.step)

    # cycle already seen examples if resuming from checkpoint
    # (only useful for ensuring deterministic dataset, slow for large start_step)
    # if start_step != 0:
    #     for _ in range(start_step):
    #         _ = next(train_ds)

    # parameter_overview.log_parameter_overview(state.optimizer_coarse.target)
    # if FLAGS.config.num_importance > 0:
    #     parameter_overview.log_parameter_overview(state.optimizer_fine.target)

    state = jax.device_put_replicated(state, jax.local_devices())

    ### Build "pmapped" functions for distributed training
    train_fn = functools.partial(train_step, near, far, FLAGS.config,
                                 schedule_fn)
    p_train_step = jax.pmap(
        train_fn,
        axis_name="batch",
        in_axes=(0, 0, None, 0),
        # donate_argnums=(0, 1, 2),
    )

    def render_fn(state, rays):
        step_fn = functools.partial(eval_step, FLAGS.config, near, far, state)
        return lax.map(step_fn, rays)

    p_eval_step = jax.pmap(
        render_fn,
        axis_name="batch",
        # in_axes=(0, 0, None),
        # donate_argnums=(0, 1))
    )

    # TODO: add hparams
    writer = metric_writers.create_default_writer(
        FLAGS.model_dir, just_logging=jax.process_index() > 0)
    logging.info("Starting training loop.")

    hooks = []
    profiler = periodic_actions.Profile(num_profile_steps=5,
                                        logdir=FLAGS.model_dir)
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=FLAGS.config.num_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [profiler, report_progress]
    train_metrics = []
    gen_video_ = functools.partial(gen_video, FLAGS.model_dir)

    for step in range(start_step, FLAGS.config.num_steps + 1):
        is_last_step = step == FLAGS.config.num_steps

        batch = next(train_ds)
        coords = None
        if not FLAGS.config.batching:
            coords = jnp.meshgrid(jnp.arange(img_h),
                                  jnp.arange(img_w),
                                  indexing="ij")
            if step < FLAGS.config.precrop_iters:
                dH = int(img_h // 2 * FLAGS.config.precrop_frac)
                dW = int(img_w // 2 * FLAGS.config.precrop_frac)
                coords = jnp.meshgrid(
                    jnp.arange(img_h // 2 - dH, img_h // 2 + dH),
                    jnp.arange(img_w // 2 - dW, img_w // 2 + dW),
                    indexing="ij",
                )
            coords = jnp.stack(coords, axis=-1).reshape([-1, 2])

        with jax.profiler.StepTraceAnnotation("train", step_num=step):
            state, metrics = p_train_step(batch, state, coords, rngs)
        train_metrics.append(metrics)

        logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                            step)
        _ = [h(step) for h in hooks]

        ### Write train summaries to TB
        if step % FLAGS.config.i_print == 0 or is_last_step:
            with report_progress.timed("training_metrics"):
                train_metrics = common_utils.get_metrics(train_metrics)
                train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
                summary = {f"train/{k}": v for k, v in train_summary.items()}
                writer.write_scalars(step, summary)
            train_metrics = []

        ### Eval a random validation image and plot it to TB
        if (step % FLAGS.config.i_img == 0 and step > 0) or is_last_step:
            with report_progress.timed("validation"):
                inputs = next(val_ds)
                rays, padding = prepare_render_data(inputs["rays"]._numpy())
                outputs = p_eval_step(state, rays)
                preds, preds_c, z_std = jax.tree_map(
                    lambda x: to_np(x, hwf, padding), outputs)
                loss = np.mean((preds["rgb"] - inputs["image"])**2)
                summary = {"val/loss": loss, "val/psnr": psnr_fn(loss)}
                writer.write_scalars(step, summary)

                summary = {
                    "val/rgb": to_rgb(preds["rgb"]),
                    "val/target": to_np(inputs["image"], hwf, padding),
                    "val/disp": disp_post(preds["disp"], FLAGS.config),
                    "val/acc": preds["acc"],
                }
                if FLAGS.config.num_importance > 0:
                    summary["val/rgb_c"] = to_rgb(preds_c["rgb"])
                    summary["val/disp_c"] = disp_post(preds_c["disp"],
                                                      FLAGS.config)
                    summary["val/z_std"] = z_std
                writer.write_images(step, summary)

        ### Render a video with test poses
        if step % FLAGS.config.i_video == 0 and step > 0:
            with report_progress.timed("video_render"):
                logging.info("Rendering video at step %d", step)
                rgb_list = []
                disp_list = []
                for idx, inputs in tqdm(iter_render_ds, desc="Rays render"):
                    rays, padding = prepare_render_data(inputs["rays"].numpy())
                    preds, *_ = p_eval_step(state, rays)
                    preds = jax.tree_map(lambda x: to_np(x, r_hwf, padding),
                                         preds)
                    rgb_list.append(preds["rgb"])
                    disp_list.append(preds["disp"])

                gen_video_(np.stack(rgb_list), "rgb", r_hwf, step)
                disp = np.stack(disp_list)
                gen_video_(disp_post(disp, FLAGS.config),
                           "disp",
                           r_hwf,
                           step,
                           ch=1)

                if FLAGS.config.use_viewdirs:
                    rgb_list = []
                    for idx, inputs in tqdm(iter_vdirs_ds,
                                            desc="Viewdirs render"):
                        rays, padding = prepare_render_data(
                            inputs["rays"].numpy())
                        preds, *_ = p_eval_step(state, rays)
                        rgb_list.append(to_np(preds["rgb"], r_hwf, padding))
                    gen_video_(np.stack(rgb_list), "rgb_still", r_hwf, step)

        ### Save images in the test set
        if step % FLAGS.config.i_testset == 0 and step > 0:
            with report_progress.timed("test_render"):
                logging.info("Rendering test set at step %d", step)
                test_losses = []
                for idx, inputs in tqdm(iter_test_ds, desc="Test render"):
                    rays, padding = prepare_render_data(inputs["rays"].numpy())
                    preds, *_ = p_eval_step(state, rays)
                    save_test_imgs(FLAGS.model_dir, preds["rgb"], r_hwf, step,
                                   idx)

                    if FLAGS.config.render_factor == 0:
                        loss = np.mean((preds["rgb"] - inputs["image"])**2.0)
                        test_losses.append(loss)
                if FLAGS.config.render_factor == 0:
                    loss = np.mean(test_losses)
                    summary = {"test/loss": loss, "test/psnr": psnr_fn(loss)}
                    writer.write_scalars(step, summary)
        writer.flush()

        ### Save ckpt
        if step % FLAGS.config.i_weights == 0 or is_last_step:
            with report_progress.timed("checkpoint"):
                save_checkpoint(state, FLAGS.model_dir)
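

# Illustrative sketch (not from the original source): `psnr_fn`, used for the
# val/test summaries above, is assumed to be the standard conversion from mean
# squared error (for images scaled to [0, 1]) to peak signal-to-noise ratio in dB.
import numpy as np

def psnr_fn(mse):
  return -10.0 * np.log10(mse)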