def test_load_state_dict(self):
  """Checks load_state_dict() reads a saved checkpoint and fails on empty dirs."""
  # TemporaryDirectory cleans up after itself; tempfile.mkdtemp() left the
  # directories behind on every run.
  with tempfile.TemporaryDirectory() as base_dir:
    state = TrainState(step=1)
    ckpt = checkpoint.Checkpoint(base_dir)
    # Initializes (writes the initial checkpoint into base_dir).
    state = ckpt.restore_or_initialize(state)
    # Load via load_state_dict().
    flax_dict = checkpoint.load_state_dict(base_dir)
    self.assertEqual(flax_dict, dict(step=1))

  with tempfile.TemporaryDirectory() as empty_dir:
    # assertRaisesRegexp is a deprecated alias that was removed in
    # Python 3.12; assertRaisesRegex is the supported name.
    with self.assertRaisesRegex(FileNotFoundError, r"^No checkpoint found"):
      checkpoint.load_state_dict(empty_dir)
def _load_pretrained_weights(config, init_params, init_state):
  """Merges params/state from `config.pretrained_checkpoint` into the inits.

  Only the "model_params" and "model_state" entries of the checkpoint are
  used; the step and the optimizer state stored in it are ignored.

  Args:
    config: Configuration; reads "pretrained_checkpoint" and
      "pretrained_prefix" (both default to "").
    init_params: Freshly initialized (nested) model parameters.
    init_state: Freshly initialized model state.

  Returns:
    A tuple (init_params, init_state). The inputs are returned unchanged when
    no pretrained checkpoint is configured.
  """
  pretrained_path = config.get("pretrained_checkpoint", "")
  if not pretrained_path:
    return init_params, init_state

  logging.info("Load pretrained weights from '%s'", pretrained_path)
  state_dict = checkpoint.load_state_dict(pretrained_path)
  flatten_model_params = utils.flatten_dict(state_dict["model_params"], sep="/")
  model_state = state_dict["model_state"]

  # A prefix can be used to replace only a subpart of the network (e.g the
  # encoder). Prepend the prefix (if any) to model parameters and states.
  # Flattened param keys look like "a/b/c" (prefix goes in front with a
  # trailing "/"), while state keys are joined with a leading "/".
  prefix = config.get("pretrained_prefix", "")
  if prefix:
    flatten_model_params = utils.add_prefix_to_dict_keys(
        flatten_model_params, f"{prefix}/")
    model_state = utils.add_prefix_to_dict_keys(model_state, f"/{prefix}")

  # Merge the params/state from the checkpoint into the initial params/state.
  flatten_init_params = utils.flatten_dict(init_params, sep="/")
  flatten_init_params, ignored_params = utils.override_dict(
      flatten_init_params, flatten_model_params)
  init_params = utils.unflatten_dict(flatten_init_params, delimiter="/")
  init_state, _ = utils.override_dict(init_state, model_state)

  if ignored_params:
    # The message reports a fraction of the *checkpoint's* parameters, so the
    # denominator is the checkpoint parameter count (previously the model's).
    # NOTE(review): assumes override_dict reports keys of its second argument
    # that have no counterpart in the first — confirm in utils.
    logging.warning(
        "%d/%d parameters from the pretrained checkpoint "
        "were ignored: %s", len(ignored_params), len(flatten_model_params),
        ignored_params)
  return init_params, init_state


def _linearly_decrease_sigma(state, step, num_train_steps):
  """Sets the perturbed top-k sigma multiplier to `1 - step / num_train_steps`.

  The multiplier is stored in the model state under "/PatchNet_0" when that
  entry exists, otherwise under "/".

  Returns:
    `state` with an updated `model_state`.
  """
  model_state = state.model_state.as_dict()
  net_str = "/PatchNet_0" if "/PatchNet_0" in model_state else "/"
  progress = step / num_train_steps
  previous_mult = model_state[net_str]["sigma_mutiplier"]
  # Broadcast the scalar to the shape of the stored (replicated) value.
  model_state[net_str]["sigma_mutiplier"] = (
      (1. - progress) + jnp.zeros_like(previous_mult))
  return state.replace(model_state=nn.Collection(model_state))


def training_loop(
    *,
    module,
    rng,
    train_ds,
    eval_ds,
    loss_fn,
    optimizer,
    train_metrics_dict,
    eval_metrics_dict,
    stats_aggregators,
    config,
    workdir,
):
  """Runs a training and evaluation loop.

  Args:
    module: The module that should be trained.
    rng: A jax pseudo-random number generator key.
    train_ds: Dataset used for training.
    eval_ds: Dataset used for evaluation.
    loss_fn: Loss function to use for training.
    optimizer: Optax optimizer to use for training.
    train_metrics_dict: Collection of metrics to be collected during training.
    eval_metrics_dict: Collection of metrics to be collected during
      evaluation.
    stats_aggregators: Dictionary of statistics aggregator functions to be run
      on the first evaluation batch. These functions ingest the stats returned
      by the model and output a Dict[str, image/scalar] that will be logged.
    config: Configuration to use. Note: when `do_eval_only` is set, this
      function overwrites `config.num_train_steps` with 1.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.

  Raises:
    RuntimeError: If `config.debug` is enabled and a per-step training metric
      is NaN or inf (checkpoints and gradient histograms are written first).
    ValueError: If an aggregated training metric, written every
      `config.log_loss_every_steps` steps, is NaN or inf.

  Returns:
    Training state, still replicated across local devices.
  """
  rng, model_rng = jax.random.split(rng)
  input_shape = tuple(train_ds.element_spec["image"].shape[1:])
  model, init_params, init_state = create_model(module, input_shape, model_rng)
  parameter_overview.log_parameter_overview(model.params)

  # Load a pretrained model parameters and state. Ignore the step and the
  # optimizer state in the checkpoint.
  init_params, init_state = _load_pretrained_weights(
      config, init_params, init_state)

  optimizer_state = optimizer.init(init_params)

  state = TrainState(step=1, model_params=init_params, model_state=init_state,
                     optimizer_state=optimizer_state)  # type: ignore
  # Do not keep a copy of the initial model.
  del init_params, init_state, optimizer_state

  train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
  checkpoint_dir = os.path.join(workdir, "checkpoints")

  ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
  state = ckpt.restore_or_initialize(state)
  initial_step = int(state.step)
  # Replicate our parameters.
  state = flax.jax_utils.replicate(state)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)
  step_timer = utils.StepTimer(
      batch_size=config.batch_size, initial_step=initial_step)

  # Write config to the summary files. This makes the hyperparameters available
  # in TensorBoard and makes comparison of runs with tensorboard/ easier.
  if initial_step == 1:
    writer.write_hparams(utils.flatten_dict(config.to_dict()))

  # Generate per-device PRNG keys for the training loop.
  rng, train_rng = jax.random.split(rng)
  train_rngs = jax.random.split(train_rng, jax.local_device_count())

  # Generate per-device PRNG keys for model evaluation.
  rng, eval_rng = jax.random.split(rng)
  eval_rngs = jax.random.split(eval_rng, jax.local_device_count())

  logging.info("Starting training loop at step %d.", initial_step)
  do_eval_only = config.get("do_eval_only", False)
  if do_eval_only:
    # Done before ReportProgress is built so its progress/ETA reflect the
    # single step that actually runs. NOTE(review): when resuming from a
    # checkpoint with step > 1 the loop range below is empty and no
    # evaluation happens in this mode.
    config.num_train_steps = 1
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  train_metrics = utils.Means()
  debug_enabled = config.get("debug", False)
  previous_grads = grads = None
  previous_updates = updates = None
  previous_state = None
  for step in range(initial_step, config.num_train_steps + 1):
    is_last_step = step == config.num_train_steps
    if debug_enabled:
      # Keep the previous step's values so they can be dumped on NaN/inf.
      previous_grads = grads
      previous_updates = updates
      previous_state = state

    # Skip the training if only do the eval.
    if not do_eval_only:
      # Use ._numpy() to avoid copy.
      batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))  # pylint: disable=protected-access
      state, grads, updates, metrics, training_stats, train_rngs = train_step(
          state, batch, module, loss_fn, optimizer, train_metrics_dict,
          train_rngs)
      train_metrics.append(flax.jax_utils.unreplicate(metrics))

      # Update topk temperature with linearly decreasing schedule if enabled.
      if (config.get("linear_decrease_perturbed_sigma", False) and
          config.get("selection_method", "") == "perturbed-topk"):
        state = _linearly_decrease_sigma(state, step, config.num_train_steps)

      if debug_enabled and utils.has_any_inf_or_nan(metrics):
        # Save checkpoint of the last good state and of the offending one.
        if previous_state is not None:
          ckpt.save(flax.jax_utils.unreplicate(previous_state))
        ckpt.save(flax.jax_utils.unreplicate(state))

        # Log gradients and updates. Explicit None checks: the truth value of
        # an array pytree is ambiguous.
        if previous_grads is not None or previous_updates is not None:
          write_gradient_histogram(writer, step,
                                   grads=previous_grads,
                                   updates=previous_updates)
        write_gradient_histogram(writer, step + 1,
                                 grads=grads, updates=updates)

        raise RuntimeError("A training metric took an invalid value: "
                           f"{metrics}.")

    logging.log_first_n(logging.INFO, "Finished training step %d.", 3, step)
    report_progress(step)

    if step % config.log_loss_every_steps == 0 or is_last_step:
      results = train_metrics.result()
      writer.write_scalars(step, results)
      writer.write_scalars(step, step_timer.get_and_reset(step))
      if utils.has_any_inf_or_nan(results):
        raise ValueError("A training metric took an invalid value.")
      train_metrics.reset()

    if step % config.checkpoint_every_steps == 0 or is_last_step:
      with step_timer.paused():
        ckpt.save(flax.jax_utils.unreplicate(state))

    # Evaluation
    if step % config.eval_every_steps == 0 or is_last_step:
      with step_timer.paused():
        eval_metrics, first_batch_stats, eval_rngs = evaluate(
            state, module, eval_ds, eval_metrics_dict, eval_rngs)

      if jax.host_id() == 0:
        log_histograms = config.get("log_histograms", False)
        log_images = config.get("log_images", True)
        # Log the last gradients and updates histograms.
        if not do_eval_only:
          write_stats_results(writer, step, training_stats, stats_aggregators,
                              prefix="train/", log_images=log_images)
          if log_histograms:
            write_gradient_histogram(writer, step, grads=grads,
                                     updates=updates)

        write_stats_results(writer, step, first_batch_stats,
                            stats_aggregators, prefix="eval/",
                            log_images=log_images)

        # write patch representation histograms
        if (log_histograms and first_batch_stats and
            "patch_representations" in first_batch_stats):
          patch_representations = first_batch_stats["patch_representations"]
          writer.write_histograms(
              step, {"patch_representations": patch_representations})

        if eval_metrics:
          writer.write_scalars(step, eval_metrics)

      writer.flush()
  return state
def convert_nest(checkpoint_path, arch):
    """Convert a Flax NesT checkpoint into a PyTorch state dict.

    Expects path to checkpoint which is a dir containing 4 files like in each of these folders
        - https://console.cloud.google.com/storage/browser/gresearch/nest-checkpoints
    `arch` is needed to look up the number of transformer blocks per level
    (module-level `arch_depths`): the Flax checkpoint numbers its
    `EncoderNDBlock_*` entries globally, while the PyTorch model indexes them
    per level.
    Returns a state dict that can be used with `torch.nn.Module.load_state_dict`

    Hint: Follow timm.models.bak.nest.Nest.__init__ and
    https://github.com/google-research/nested-transformer/blob/main/models/nest_net.py
    """
    assert arch in ['nest_base', 'nest_small', 'nest_tiny'], "Your `arch` is not supported"

    flax_dict = checkpoint.load_state_dict(checkpoint_path)['optimizer']['target']
    state_dict = {}

    # Patch embedding. Flax conv kernels are (H, W, C_in, C_out); torch
    # expects (C_out, C_in, H, W).
    patch_conv = flax_dict['PatchEmbedding_0']['Conv_0']
    state_dict['patch_embed.proj.weight'] = torch.tensor(patch_conv['kernel']).permute(3, 2, 0, 1)
    state_dict['patch_embed.proj.bias'] = torch.tensor(patch_conv['bias'])

    # Positional embeddings: one per level, taken in checkpoint key order.
    posemb_keys = [k for k in flax_dict if k.startswith('PositionEmbedding')]
    for i, k in enumerate(posemb_keys):
        state_dict[f'levels.{i}.pos_embed'] = torch.tensor(flax_dict[k]['pos_embedding'])

    # Transformer encoders
    depths = arch_depths[arch]
    for level, depth in enumerate(depths):
        for layer in range(depth):
            global_layer_ix = sum(depths[:level]) + layer
            block = flax_dict[f'EncoderNDBlock_{global_layer_ix}']
            attn = block['MultiHeadAttention_0']
            prefix = f'levels.{level}.transformer_encoder.{layer}'

            # Norms
            for i in range(2):
                norm = block[f'LayerNorm_{i}']
                state_dict[f'{prefix}.norm{i + 1}.weight'] = torch.tensor(norm['scale'])
                state_dict[f'{prefix}.norm{i + 1}.bias'] = torch.tensor(norm['bias'])

            # Attention qkv. Flax has separate q and fused kv projections; torch
            # wants one fused qkv Linear with weight of shape (out, in).
            w_q = attn['DenseGeneral_0']['kernel']
            w_kv = attn['DenseGeneral_1']['kernel']
            # Pay attention to dims here (maybe get pen and paper)
            w_kv = np.concatenate(np.split(w_kv, 2, -1), 1)
            w_qkv = np.concatenate([w_q, w_kv], 1)
            state_dict[f'{prefix}.attn.qkv.weight'] = torch.tensor(w_qkv).flatten(1).permute(1, 0)
            b_q = attn['DenseGeneral_0']['bias']
            b_kv = attn['DenseGeneral_1']['bias']
            # Pay attention to dims here (maybe get pen and paper)
            b_kv = np.concatenate(np.split(b_kv, 2, -1), 0)
            b_qkv = np.concatenate([b_q, b_kv], 0)
            state_dict[f'{prefix}.attn.qkv.bias'] = torch.tensor(b_qkv).reshape(-1)

            # Attention proj
            w_proj = torch.tensor(attn['proj_kernel']).permute(2, 1, 0).flatten(1)
            state_dict[f'{prefix}.attn.proj.weight'] = w_proj
            state_dict[f'{prefix}.attn.proj.bias'] = torch.tensor(attn['bias'])

            # MLP (Flax Dense kernels are (in, out); torch Linear is (out, in))
            for i in range(2):
                dense = block['MlpBlock_0'][f'Dense_{i}']
                state_dict[f'{prefix}.mlp.fc{i + 1}.weight'] = torch.tensor(dense['kernel']).permute(1, 0)
                state_dict[f'{prefix}.mlp.fc{i + 1}.bias'] = torch.tensor(dense['bias'])

    # Block aggregations (ConvPool): ConvPool_{k} feeds level k + 1.
    for level in range(1, len(depths)):
        pool = flax_dict[f'ConvPool_{level - 1}']
        # Convs
        state_dict[f'levels.{level}.pool.conv.weight'] = torch.tensor(
            pool['Conv_0']['kernel']).permute(3, 2, 0, 1)
        state_dict[f'levels.{level}.pool.conv.bias'] = torch.tensor(pool['Conv_0']['bias'])
        # Norms
        state_dict[f'levels.{level}.pool.norm.weight'] = torch.tensor(pool['LayerNorm_0']['scale'])
        state_dict[f'levels.{level}.pool.norm.bias'] = torch.tensor(pool['LayerNorm_0']['bias'])

    # Final norm
    state_dict['norm.weight'] = torch.tensor(flax_dict['LayerNorm_0']['scale'])
    state_dict['norm.bias'] = torch.tensor(flax_dict['LayerNorm_0']['bias'])

    # Classifier
    state_dict['head.weight'] = torch.tensor(flax_dict['Dense_0']['kernel']).permute(1, 0)
    state_dict['head.bias'] = torch.tensor(flax_dict['Dense_0']['bias'])

    return state_dict