def _checkpoint_model(
    self, task, train_phase_idx, mode_frequency, mode_num, mode="phase"
):
    """
    Checkpoint model. Can be called in 3 possible scenarios:
    1. If training becomes NaN, then we checkpoint the model to facilitate debugging
    2. After every N epochs (CHECKPOINT_FREQ), model state is checkpointed.
    3. If the user wants to checkpoint during the epoch (i.e. after every few
       training iterations), the model state is checkpointed.

    Args:
        task: Self-supervision task that holds information about training iteration,
              epoch number, etc.
        train_phase_idx (int): current training phase number. Starts from 0
        mode_frequency (int): mode can be "phase" or "iteration". Frequency of
                              checkpointing for the given mode
        mode_num (int): for the checkpointing mode (phase or iteration), the number
                        of the phase or iteration at which checkpointing is being done
    """
    phase_idx = task.phase_idx

    # num_train_phases = num_epochs * num_phases_per_epoch
    # For OSS use, num_train_phases will be equal to num_epochs
    num_train_phases = task.num_train_phases

    # check if we need to checkpoint this phase
    is_checkpointing_phase = is_checkpoint_phase(
        mode_num, mode_frequency, train_phase_idx, num_train_phases, mode
    )
    is_final_train_phase = (
        (train_phase_idx == (num_train_phases - 1)) and task.train and mode == "phase"
    )

    # handle checkpoint:
    if task.train and (is_final_train_phase or is_checkpointing_phase):
        # - if sharded state, consolidate the state
        # /!\ All the ranks have to participate
        if hasattr(task.optimizer, "consolidate_state_dict") and mode != "phase":
            logging.info(
                f"[{mode}: {mode_num}] Consolidating sharded state on all replicas"
            )
            task.optimizer.consolidate_state_dict()

        # Depending on whether we are in FSDP mode or not:
        # - save the checkpoint on the primary rank
        # - save the sharded checkpoint on all ranks
        if is_primary() or isinstance(task.base_model, FSDP):
            checkpoint_folder = task.checkpoint_folder
            logging.info(
                f"[{mode}: {mode_num}] Saving checkpoint to {checkpoint_folder}"
            )
            model_state_dict = task.get_classy_state()

            # phase_idx is already incremented at the beginning of the phase, but if
            # we are checkpointing at an iteration in the middle of a phase, we should
            # not save the incremented phase_idx as it would incorrectly assume that
            # the model already trained for that phase.
            if mode == "iteration":
                model_state_dict["phase_idx"] = model_state_dict["phase_idx"] - 1
                if task.train:
                    train_phase_idx = train_phase_idx - 1
                    model_state_dict["train_phase_idx"] = train_phase_idx
                restart_phase = phase_idx - 1
                restart_iteration = task.iteration
            # When loading from a phase checkpoint:
            else:
                restart_phase = phase_idx
                restart_iteration = task.iteration

            checkpoint_content = {
                "phase_idx": restart_phase,
                "iteration": restart_iteration,
                "loss": task.loss.state_dict(),
                "iteration_num": task.local_iteration_num,
                "train_phase_idx": train_phase_idx,
                "classy_state_dict": model_state_dict,
            }

            checkpoint_writer = CheckpointWriter(
                checkpoint_folder=checkpoint_folder,
                is_final_train_phase=is_final_train_phase,
                mode=mode,
                mode_num=mode_num,
                backend=task.config["CHECKPOINT"]["BACKEND"],
            )

            if isinstance(task.base_model, FSDP):
                _, rank = get_machine_local_and_dist_rank()
                checkpoint_writer.save_sharded_checkpoint(
                    content=checkpoint_content,
                    shard_rank=rank,
                    world_size=self.world_size,
                )
            else:
                checkpoint_writer.save_consolidated_checkpoint(checkpoint_content)
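
# --- Usage sketch (illustration only, not part of the hook above) -----------------
# A minimal sketch of how _checkpoint_model could be driven from phase-end and
# per-update callbacks: checkpoint every few phases with mode="phase", and
# optionally mid-phase with mode="iteration". The callback names, the
# task.train_phase_idx attribute, and the CHECKPOINT_FREQUENCY /
# CHECKPOINT_ITER_FREQUENCY config keys are assumptions for illustration; only
# task.config["CHECKPOINT"] and task.iteration appear in the code above.

def on_phase_end(self, task) -> None:
    # checkpoint once every CHECKPOINT_FREQUENCY training phases (assumed key)
    if task.train:
        train_phase_idx = task.train_phase_idx
        self._checkpoint_model(
            task,
            train_phase_idx,
            mode_frequency=task.config["CHECKPOINT"]["CHECKPOINT_FREQUENCY"],
            mode_num=train_phase_idx,
            mode="phase",
        )


def on_update(self, task) -> None:
    # optionally checkpoint in the middle of a phase, every N iterations (assumed key)
    iter_frequency = task.config["CHECKPOINT"].get("CHECKPOINT_ITER_FREQUENCY", -1)
    if iter_frequency > 0:
        self._checkpoint_model(
            task,
            task.train_phase_idx,
            mode_frequency=iter_frequency,
            mode_num=task.iteration,
            mode="iteration",
        )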
def _worker(gpu_id: int, sync_file: str, world_size: int):
    torch.manual_seed(0)
    os.environ["RANK"] = str(gpu_id)
    init_distributed_on_file(world_size=world_size, gpu_id=gpu_id, sync_file=sync_file)
    torch.backends.cudnn.deterministic = True

    config = TestCheckpointConversion._create_fsdp_model_config(with_fsdp=True)
    model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
    model = fsdp_wrapper(model, **config.MODEL.FSDP_CONFIG)
    optimizer = optim.SGD(model.parameters(), lr=1e-4)

    # Fake inputs
    num_iterations = 5
    batch_size = 3
    torch.manual_seed(gpu_id)
    fake_inputs = torch.randn(size=(num_iterations, batch_size, 3, 96, 96))
    fake_targets = torch.randn(size=(num_iterations, batch_size))

    # Fake training loop
    criterion = nn.MSELoss()
    for iteration in range(num_iterations):
        fake_input = fake_inputs[iteration].cuda(gpu_id)
        fake_target = fake_targets[iteration].cuda(gpu_id)
        output1, output2 = model(fake_input)[0]
        loss = criterion(output1.sum(axis=-1), fake_target) + criterion(
            output2.sum(axis=-1), fake_target
        )
        if gpu_id == 0:
            print(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Save a sharded checkpoint, one file per shard
    checkpoint_writer = CheckpointWriter(
        checkpoint_folder=".",
        is_final_train_phase=True,
        mode="iteration",
        mode_num=0,
        backend="disk",
    )
    content = {
        "classy_state_dict": {
            "base_model": {
                "model": {"trunk": model.trunk.local_state_dict()},
                "meta": {"trunk": model.trunk.local_metadata_dict()},
            }
        }
    }
    checkpoint_writer.save_sharded_checkpoint(
        content, shard_rank=gpu_id, world_size=world_size
    )
    dist.barrier()
    print(os.listdir("."))

    # Convert the checkpoint to consolidated and sliced checkpoints
    if gpu_id == 0:
        CheckpointFormatConverter.sharded_to_consolidated_checkpoint(
            "checkpoint.torch", "checkpoint_conso.torch"
        )
        CheckpointFormatConverter.sharded_to_sliced_checkpoint(
            "checkpoint.torch", "checkpoint_sliced.torch"
        )
    dist.barrier()
    print(os.listdir("."))

    # Now create models initialized from the previous checkpoints and compare them
    fake_test_input = torch.randn(size=(1, 3, 96, 96)).cuda(gpu_id)

    shard_cp = CheckpointLoader.load_and_broadcast_init_weights(
        "checkpoint.torch", device=torch.device("cpu")
    )
    shard_model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
    shard_model = fsdp_wrapper(shard_model, **config.MODEL.FSDP_CONFIG)
    shard_model.init_model_from_weights_params_file(config, shard_cp)

    conso_cp = CheckpointLoader.load_and_broadcast_init_weights(
        "checkpoint_conso.torch", device=torch.device("cpu")
    )
    conso_model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
    conso_model = fsdp_wrapper(conso_model, **config.MODEL.FSDP_CONFIG)
    conso_model.init_model_from_weights_params_file(config, conso_cp)

    slice_cp = CheckpointLoader.load_and_broadcast_init_weights(
        "checkpoint_sliced.torch", device=torch.device("cpu")
    )
    slice_model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
    slice_model = fsdp_wrapper(slice_model, **config.MODEL.FSDP_CONFIG)
    slice_model.init_model_from_weights_params_file(config, slice_cp)

    # Verify that the models are equivalent
    if gpu_id == 0:
        slice_state_dict = slice_model.local_state_dict()
        conso_state_dict = conso_model.local_state_dict()
        assert set(slice_state_dict.keys()) == set(conso_state_dict.keys())
        for k in slice_state_dict.keys():
            slice_val = slice_state_dict[k]
            conso_val = conso_state_dict[k]
            assert torch.allclose(
                slice_val, conso_val
            ), f"Difference for key {k}: {slice_val} VS {conso_val}"
    dist.barrier()

    with torch.no_grad():
        ref_out = model.trunk(fake_test_input)[0]
        shard_out = shard_model.trunk(fake_test_input)[0]
        conso_out = conso_model.trunk(fake_test_input)[0]
        slice_out = slice_model.trunk(fake_test_input)[0]
        assert torch.allclose(ref_out, shard_out), f"{ref_out.sum()} vs {shard_out.sum()}"
        assert torch.allclose(ref_out, conso_out), f"{ref_out.sum()} vs {conso_out.sum()}"
        assert torch.allclose(ref_out, slice_out), f"{ref_out.sum()} vs {slice_out.sum()}"
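
# --- Launcher sketch (assumed harness, for illustration) --------------------------
# The test worker above expects one process per GPU plus a file-based rendezvous
# through init_distributed_on_file. A minimal way to drive it is shown below, using
# torch.multiprocessing.spawn and a temporary sync file; the function name
# run_checkpoint_conversion_test and the world_size of 2 are assumptions, and the
# project's real test harness may launch the worker differently.

import tempfile

import torch.multiprocessing as mp


def run_checkpoint_conversion_test(world_size: int = 2):
    with tempfile.NamedTemporaryFile(delete=False) as sync_file:
        # spawn one _worker process per GPU; each rank trains, saves its shard,
        # rank 0 converts the checkpoint, then every rank reloads and compares
        mp.spawn(
            _worker,
            args=(sync_file.name, world_size),
            nprocs=world_size,
        )


if __name__ == "__main__":
    run_checkpoint_conversion_test(world_size=2)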