def test_load_state_dict(self):
    """Checks `checkpoint.load_state_dict` on a saved and an empty directory."""
    base_dir = tempfile.mkdtemp()
    state = TrainState(step=1)
    ckpt = checkpoint.Checkpoint(base_dir)
    # A fresh directory has nothing to restore, so this initializes the
    # checkpoint from `state`.
    state = ckpt.restore_or_initialize(state)
    # Load the raw (un-typed) state dict via load_state_dict().
    flax_dict = checkpoint.load_state_dict(base_dir)
    self.assertEqual(flax_dict, dict(step=1))
    # An empty directory must raise. Use assertRaisesRegex: the
    # assertRaisesRegexp alias was deprecated and removed in Python 3.12.
    with self.assertRaisesRegex(FileNotFoundError, r"^No checkpoint found"):
        checkpoint.load_state_dict(tempfile.mkdtemp())
def training_loop(
    *,
    module,
    rng,
    train_ds,
    eval_ds,
    loss_fn,
    optimizer,
    train_metrics_dict,
    eval_metrics_dict,
    stats_aggregators,
    config,
    workdir,
):
    """Runs a training and evaluation loop.

    Args:
      module: The module that should be trained.
      rng: A jax pseudo-random number generator key.
      train_ds: Dataset used for training.
      eval_ds: Dataset used for evaluation.
      loss_fn: Loss function to use for training.
      optimizer: Optax optimizer to use for training.
      train_metrics_dict: Collection of metrics to be collected during training.
      eval_metrics_dict: Collection of metrics to be collected during evaluation.
      stats_aggregators: Dictionary of statistics aggregator functions to be run
        on the first evaluation batch. These functions ingest the stats returned
        by the model and output a Dict[str, image/scalar] that will be logged.
      config: Configuration to use. When `config.do_eval_only` is set,
        `config.num_train_steps` is overwritten so that exactly one evaluation
        is run at the initial (possibly restored) step.
      workdir: Working directory for checkpoints and TF summaries. If this
        contains checkpoint training will be resumed from the latest checkpoint.

    Raises:
      RuntimeError: If `config.debug` is set and a per-step training metric is
        NaN or inf. A checkpoint and gradient/update histograms are written
        before raising.
      ValueError: If an aggregated training metric is NaN or inf at a logging
        step.

    Returns:
      Training state, replicated across local devices.
    """
    rng, model_rng = jax.random.split(rng)
    input_shape = tuple(train_ds.element_spec["image"].shape[1:])
    model, init_params, init_state = create_model(module, input_shape,
                                                  model_rng)
    parameter_overview.log_parameter_overview(model.params)

    # Load a pretrained model parameters and state. Ignore the step and the
    # optimizer state in the checkpoint.
    pretrained_path = config.get("pretrained_checkpoint", "")
    if pretrained_path:
        logging.info("Load pretrained weights from '%s'", pretrained_path)
        state_dict = checkpoint.load_state_dict(pretrained_path)
        flatten_model_params = utils.flatten_dict(state_dict["model_params"],
                                                  sep="/")
        model_state = state_dict["model_state"]

        # A prefix can be used to replace only a subpart of the network (e.g the
        # encoder). Prepend the prefix (if any) to model parameters and states.
        # Flattened param keys look like "a/b" (prefix goes first, then "/"),
        # while state keys start with "/" (so the prefix gets a leading "/").
        prefix = config.get("pretrained_prefix", "")
        if prefix:
            flatten_model_params = utils.add_prefix_to_dict_keys(
                flatten_model_params, f"{prefix}/")
            model_state = utils.add_prefix_to_dict_keys(
                model_state, f"/{prefix}")

        # Merge the params/state from the checkpoint into the initial params/state.
        flatten_init_params = utils.flatten_dict(init_params, sep="/")
        flatten_init_params, ignored_params = utils.override_dict(
            flatten_init_params, flatten_model_params)
        init_params = utils.unflatten_dict(flatten_init_params, delimiter="/")
        init_state, _ = utils.override_dict(init_state, model_state)

        if ignored_params:
            # The ratio is "ignored checkpoint params / all checkpoint params",
            # so the denominator is the size of the *checkpoint* params.
            logging.warning(
                "%d/%d parameters from the pretrained checkpoint "
                "were ignored: %s", len(ignored_params),
                len(flatten_model_params), ignored_params)

    optimizer_state = optimizer.init(init_params)

    state = TrainState(step=1,
                       model_params=init_params,
                       model_state=init_state,
                       optimizer_state=optimizer_state)  # type: ignore
    # Do not keep a copy of the initial model.
    del init_params, init_state, optimizer_state

    train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
    checkpoint_dir = os.path.join(workdir, "checkpoints")

    ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
    state = ckpt.restore_or_initialize(state)
    initial_step = int(state.step)
    # Replicate our parameters.
    state = flax.jax_utils.replicate(state)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.host_id() > 0)
    step_timer = utils.StepTimer(batch_size=config.batch_size,
                                 initial_step=initial_step)

    # Write config to the summary files. This makes the hyperparameters available
    # in TensorBoard and makes comparison of runs with tensorboard/ easier.
    if initial_step == 1:
        writer.write_hparams(utils.flatten_dict(config.to_dict()))

    # Generate per-device PRNG keys for the training loop.
    rng, train_rng = jax.random.split(rng)
    train_rngs = jax.random.split(train_rng, jax.local_device_count())

    # Generate per-device PRNG keys for model evaluation.
    rng, eval_rng = jax.random.split(rng)
    eval_rngs = jax.random.split(eval_rng, jax.local_device_count())

    logging.info("Starting training loop at step %d.", initial_step)
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    train_metrics = utils.Means()

    do_eval_only = config.get("do_eval_only", False)
    if do_eval_only:
        # Run exactly one (evaluation-only) iteration at the current step. Using
        # a fixed value of 1 here would make the loop below empty whenever a
        # checkpoint was restored from `workdir` (initial_step > 1).
        config.num_train_steps = initial_step

    debug_enabled = config.get("debug", False)
    previous_grads = grads = None
    previous_updates = updates = None
    previous_state = None
    for step in range(initial_step, config.num_train_steps + 1):
        is_last_step = step == config.num_train_steps
        if debug_enabled:
            previous_grads = grads
            previous_updates = updates
            previous_state = state

        # Skip the training if only do the eval.
        if not do_eval_only:
            # Use ._numpy() to avoid copy.
            batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))  # pylint: disable=protected-access
            state, grads, updates, metrics, training_stats, train_rngs = train_step(
                state, batch, module, loss_fn, optimizer, train_metrics_dict,
                train_rngs)
            train_metrics.append(flax.jax_utils.unreplicate(metrics))

            # Update topk temperature with linearly decreasing schedule if enabled.
            if (config.get("linear_decrease_perturbed_sigma", False) and
                    config.get("selection_method", "") == "perturbed-topk"):

                model_state = state.model_state.as_dict()

                if "/PatchNet_0" in model_state:
                    net_str = "/PatchNet_0"
                else:
                    net_str = "/"

                progress = step / config.num_train_steps
                sigma_multiplier = 1. - progress
                previous_mult = model_state[net_str]["sigma_mutiplier"]
                # Broadcast the scalar to the stored value's shape/dtype.
                sigma_multiplier = sigma_multiplier + jnp.zeros_like(
                    previous_mult)
                model_state[net_str]["sigma_mutiplier"] = sigma_multiplier
                state = state.replace(model_state=nn.Collection(model_state))

            if debug_enabled:
                if utils.has_any_inf_or_nan(metrics):
                    # Save checkpoint
                    if previous_state is not None:
                        ckpt.save(flax.jax_utils.unreplicate(previous_state))
                    ckpt.save(flax.jax_utils.unreplicate(state))

                    # Log gradients and updates.
                    if previous_grads is not None or previous_updates is not None:
                        write_gradient_histogram(writer,
                                                 step,
                                                 grads=previous_grads,
                                                 updates=previous_updates)
                    write_gradient_histogram(writer,
                                             step + 1,
                                             grads=grads,
                                             updates=updates)

                    raise RuntimeError(
                        "A training metric took an invalid value: "
                        f"{metrics}.")

            logging.log_first_n(logging.INFO, "Finished training step %d.", 3,
                                step)
            report_progress(step)

            if step % config.log_loss_every_steps == 0 or is_last_step:
                results = train_metrics.result()
                writer.write_scalars(step, results)
                writer.write_scalars(step, step_timer.get_and_reset(step))
                if utils.has_any_inf_or_nan(results):
                    raise ValueError(
                        "A training metric took an invalid value.")
                train_metrics.reset()

        if (step % config.checkpoint_every_steps == 0 or is_last_step):
            with step_timer.paused():
                ckpt.save(flax.jax_utils.unreplicate(state))

        # Evaluation
        if step % config.eval_every_steps == 0 or is_last_step:
            with step_timer.paused():
                eval_metrics, first_batch_stats, eval_rngs = evaluate(
                    state, module, eval_ds, eval_metrics_dict, eval_rngs)

            if jax.host_id() == 0:
                log_histograms = config.get("log_histograms", False)
                log_images = config.get("log_images", True)
                # Log the last gradients and updates histograms.
                if not do_eval_only:
                    write_stats_results(writer,
                                        step,
                                        training_stats,
                                        stats_aggregators,
                                        prefix="train/",
                                        log_images=log_images)
                    if log_histograms:
                        write_gradient_histogram(writer,
                                                 step,
                                                 grads=grads,
                                                 updates=updates)

                write_stats_results(writer,
                                    step,
                                    first_batch_stats,
                                    stats_aggregators,
                                    prefix="eval/",
                                    log_images=log_images)

                # write patch representation histograms
                if (log_histograms and first_batch_stats
                        and "patch_representations" in first_batch_stats):
                    patch_representations = first_batch_stats[
                        "patch_representations"]
                    writer.write_histograms(
                        step, {"patch_representations": patch_representations})

                if eval_metrics:
                    writer.write_scalars(step, eval_metrics)

    writer.flush()
    return state
def convert_nest(checkpoint_path, arch):
    """Converts a Flax NesT checkpoint into a PyTorch state dict.

    Args:
        checkpoint_path: Path to a checkpoint directory, i.e. a directory
            containing the 4 checkpoint files found in each folder of
            https://console.cloud.google.com/storage/browser/gresearch/nest-checkpoints
        arch: One of 'nest_base', 'nest_small' or 'nest_tiny'. It is needed to
            look up the number of transformer layers per level (`arch_depths`).

    Returns:
        A dict of `torch.Tensor`s that can be used with
        `torch.nn.Module.load_state_dict`.

    Raises:
        ValueError: If `arch` is not a supported architecture.

    Hint: Follow timm.models.bak.nest.Nest.__init__ and
    https://github.com/google-research/nested-transformer/blob/main/models/nest_net.py
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    supported_archs = ('nest_base', 'nest_small', 'nest_tiny')
    if arch not in supported_archs:
        raise ValueError(f"Your `arch` is not supported: {arch!r}")

    flax_dict = checkpoint.load_state_dict(
        checkpoint_path)['optimizer']['target']
    state_dict = {}

    # Patch embedding. Conv kernels are permuted (3, 2, 0, 1), i.e. from Flax's
    # (H, W, in, out) layout to torch's (out, in, H, W).
    patch_conv = flax_dict['PatchEmbedding_0']['Conv_0']
    state_dict['patch_embed.proj.weight'] = torch.tensor(
        patch_conv['kernel']).permute(3, 2, 0, 1)
    state_dict['patch_embed.proj.bias'] = torch.tensor(patch_conv['bias'])

    # Positional embeddings: the i-th such key (in dict order) maps to level i.
    posemb_keys = [
        k for k in flax_dict.keys() if k.startswith('PositionEmbedding')
    ]
    for i, k in enumerate(posemb_keys):
        state_dict[f'levels.{i}.pos_embed'] = torch.tensor(
            flax_dict[k]['pos_embedding'])

    # Transformer encoders. Flax numbers the encoder blocks globally across
    # levels, torch numbers them per level.
    depths = arch_depths[arch]
    for level in range(len(depths)):
        for layer in range(depths[level]):
            global_layer_ix = sum(depths[:level]) + layer
            block = flax_dict[f'EncoderNDBlock_{global_layer_ix}']
            attn = block['MultiHeadAttention_0']
            prefix = f'levels.{level}.transformer_encoder.{layer}'
            # Norms
            for i in range(2):
                state_dict[f'{prefix}.norm{i+1}.weight'] = torch.tensor(
                    block[f'LayerNorm_{i}']['scale'])
                state_dict[f'{prefix}.norm{i+1}.bias'] = torch.tensor(
                    block[f'LayerNorm_{i}']['bias'])
            # Attention qkv: Flax has separate q and fused kv projections.
            # Split kv into k and v along the last axis, stack them after q
            # along axis 1, then flatten and transpose to torch's
            # (out_features, in_features) layout.
            w_q = attn['DenseGeneral_0']['kernel']
            w_kv = attn['DenseGeneral_1']['kernel']
            w_kv = np.concatenate(np.split(w_kv, 2, -1), 1)
            w_qkv = np.concatenate([w_q, w_kv], 1)
            state_dict[f'{prefix}.attn.qkv.weight'] = torch.tensor(
                w_qkv).flatten(1).permute(1, 0)
            b_q = attn['DenseGeneral_0']['bias']
            b_kv = attn['DenseGeneral_1']['bias']
            # Same split/stack for the biases (one axis fewer than the kernels).
            b_kv = np.concatenate(np.split(b_kv, 2, -1), 0)
            b_qkv = np.concatenate([b_q, b_kv], 0)
            state_dict[f'{prefix}.attn.qkv.bias'] = torch.tensor(
                b_qkv).reshape(-1)
            # Attention proj
            w_proj = torch.tensor(attn['proj_kernel']).permute(2, 1, 0).flatten(1)
            state_dict[f'{prefix}.attn.proj.weight'] = w_proj
            state_dict[f'{prefix}.attn.proj.bias'] = torch.tensor(attn['bias'])
            # MLP: Dense kernels are (in, out); torch Linear wants (out, in).
            for i in range(2):
                dense = block['MlpBlock_0'][f'Dense_{i}']
                state_dict[f'{prefix}.mlp.fc{i+1}.weight'] = torch.tensor(
                    dense['kernel']).permute(1, 0)
                state_dict[f'{prefix}.mlp.fc{i+1}.bias'] = torch.tensor(
                    dense['bias'])

    # Block aggregations (ConvPool). Level 0 has none, so the Flax index is
    # shifted by one.
    for level in range(1, len(depths)):
        conv_pool = flax_dict[f'ConvPool_{level-1}']
        # Convs
        state_dict[f'levels.{level}.pool.conv.weight'] = torch.tensor(
            conv_pool['Conv_0']['kernel']).permute(3, 2, 0, 1)
        state_dict[f'levels.{level}.pool.conv.bias'] = torch.tensor(
            conv_pool['Conv_0']['bias'])
        # Norms
        state_dict[f'levels.{level}.pool.norm.weight'] = torch.tensor(
            conv_pool['LayerNorm_0']['scale'])
        state_dict[f'levels.{level}.pool.norm.bias'] = torch.tensor(
            conv_pool['LayerNorm_0']['bias'])

    # Final norm
    state_dict['norm.weight'] = torch.tensor(flax_dict['LayerNorm_0']['scale'])
    state_dict['norm.bias'] = torch.tensor(flax_dict['LayerNorm_0']['bias'])

    # Classifier
    state_dict['head.weight'] = torch.tensor(
        flax_dict['Dense_0']['kernel']).permute(1, 0)
    state_dict['head.bias'] = torch.tensor(flax_dict['Dense_0']['bias'])

    return state_dict