def init_distributed_data_parallel_model(self):
    """
    Wrap the whole base model with FSDP for a distributed training run.

    Overrides ClassyVision's ``ClassificationTask`` method of the same name.
    Returns immediately when this is not a distributed training run.
    Otherwise it does three things:

    - resets ``_is_root`` to ``None`` on every FSDP submodule already present
      in the trunk and heads;
    - wraps ``self.base_model`` with ``fsdp_wrapper`` using
      ``MODEL.FSDP_CONFIG``;
    - uses the wrapped model as both ``base_model`` and ``distributed_model``.
    """
    if not is_distributed_training_run():
        return

    assert get_cuda_device_index() > -1, "Distributed training not setup correctly"

    # TODO (Min): a checkpoint can be loaded, but doing so sets the trunk's
    # _is_root flag to True, so it is reset to None below.
    # The head's weights are also only partially loaded from the checkpoint:
    # the checkpoint is dumped after the head is wrapped but loaded before it
    # is wrapped.
    # For very big models the checkpoint logic needs to be reworked, because a
    # single node does not have enough memory to load the entire model. The
    # local_state_dict() API should then be used to load checkpoint shards.
    for sub_model in (self.base_model.trunk, self.base_model.heads):
        for module in sub_model.modules():
            if isinstance(module, FSDP):
                module._is_root = None

    # Wrap the whole model. base_model itself is replaced because it is the
    # object used when a checkpoint is taken.
    fsdp_config = self.config["MODEL"]["FSDP_CONFIG"]
    self.base_model = fsdp_wrapper(self.base_model, **fsdp_config)
    self.distributed_model = self.base_model
    assert is_valid_fsdp_model(self.distributed_model), "FSDP is not setup correctly"
def RegNetFSDP(model_config: AttrDict, model_name: str):
    """
    Build the RegNet trunk on the GPU and wrap the entire trunk with FSDP.

    The whole trunk is wrapped here because the checkpoint has to be loaded
    before the wrapping done in train_fsdp_task.py happens.

    Args:
        model_config: model config; ``FSDP_CONFIG`` is forwarded to
            ``fsdp_wrapper``.
        model_name: forwarded to ``_RegNetFSDP``.
    """
    trunk = _RegNetFSDP(model_config, model_name)
    trunk = trunk.cuda()
    return fsdp_wrapper(trunk, **model_config.FSDP_CONFIG)
def FSDPLinearEvalMLP(
    model_config: AttrDict,
    in_channels: int,
    dims: List[int],
    use_bn: bool = False,
    use_relu: bool = False,
):
    """
    Build a ``LinearEvalMLP`` and return it wrapped for FSDP.

    The MLP first goes through ``fsdp_auto_wrap_bn``. The result is then
    wrapped with ``fsdp_wrapper`` using ``model_config.FSDP_CONFIG``. All
    other arguments are forwarded positionally to ``LinearEvalMLP``.
    """
    bn_wrapped = fsdp_auto_wrap_bn(
        LinearEvalMLP(model_config, in_channels, dims, use_bn, use_relu)
    )
    return fsdp_wrapper(bn_wrapped, **model_config.FSDP_CONFIG)
def SwavPrototypesHeadFSDP(
    model_config: AttrDict,
    dims: List[int],
    use_bn: bool,
    num_clusters: int,
    use_bias: bool = True,
    return_embeddings: bool = True,
    skip_last_bn: bool = True,
    normalize_feats: bool = True,
):
    """
    Build a SwAV prototypes head wrapped with FSDP, keeping prototypes in fp32.

    Each ``prototypes{j}`` layer (``j`` in ``range(head.nmb_heads)``) gets its
    own FSDP wrapper. That wrapper uses ``flatten_parameters=False``,
    ``mixed_precision=False``, ``fp32_reduce_scatter=False`` and
    ``compute_dtype=torch.float32``. The whole head is then wrapped with
    ``model_config.FSDP_CONFIG``.

    Full precision for the prototypes matters for convergence. The layer is
    "normalized" in the update hook, so its weights must stay in full
    precision. Its output goes into the loss and is used for clustering, so it
    needs full precision as well.
    """
    head = fsdp_auto_wrap_bn(
        SwAVPrototypesHead(
            model_config=model_config,
            dims=dims,
            use_bn=use_bn,
            num_clusters=num_clusters,
            use_bias=use_bias,
            return_embeddings=return_embeddings,
            skip_last_bn=skip_last_bn,
            normalize_feats=normalize_feats,
        )
    )

    # Start from the model-wide FSDP config and override only what is needed
    # to keep the prototypes in full precision.
    prototypes_config = model_config.FSDP_CONFIG.copy()
    for option, value in (
        ("flatten_parameters", False),
        ("mixed_precision", False),
        ("fp32_reduce_scatter", False),
        ("compute_dtype", torch.float32),
    ):
        prototypes_config[option] = value

    for j in range(head.nmb_heads):
        attr_name = f"prototypes{j}"
        setattr(head, attr_name, fsdp_wrapper(getattr(head, attr_name), **prototypes_config))

    return fsdp_wrapper(head, **model_config.FSDP_CONFIG)
def __init__(
    self,
    model_config: AttrDict,
    in_channels: int,
    dims: List[int],
    use_bn: bool = False,
    use_relu: bool = False,
):
    """
    Build a ``LinearEvalMLP`` and store it, FSDP-wrapped, as ``self.mlp``.

    The MLP first goes through ``fsdp_auto_wrap_bn``. The result is then
    wrapped with ``fsdp_wrapper`` using ``model_config.FSDP_CONFIG``. All
    other arguments are forwarded positionally to ``LinearEvalMLP``.
    """
    super().__init__()
    bn_wrapped_mlp = fsdp_auto_wrap_bn(
        LinearEvalMLP(model_config, in_channels, dims, use_bn, use_relu)
    )
    self.mlp = fsdp_wrapper(bn_wrapped_mlp, **model_config.FSDP_CONFIG)
def _pretraining_worker(
    gpu_id: int,
    with_fsdp: bool,
    with_activation_checkpointing: bool,
    with_larc: bool,
    sync_file: str,
    result_file: str,
):
    """
    Per-GPU worker: train a small SwAV pretraining model and record its losses.

    Steps:

    - joins a 2-process group through ``init_distributed_on_file``;
    - builds the model from ``TestRegnetFSDP._create_pretraining_config``;
    - wraps it with FSDP (``with_fsdp``) or with DDP;
    - runs 5 iterations with the configured VISSL optimizer;
    - on rank 0, pickles the list of losses to ``result_file``.
    """
    init_distributed_on_file(world_size=2, gpu_id=gpu_id, sync_file=sync_file)
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True

    # Fixed inputs and a dummy target, identical on every rank.
    batch = torch.randn(size=(8, 3, 224, 224)).cuda()
    target = torch.tensor(0.0).cuda()

    # A fake model built from SwAV blocks.
    config = TestRegnetFSDP._create_pretraining_config(
        with_fsdp, with_activation_checkpointing, with_larc=with_larc
    )
    model = build_model(config["MODEL"], config["OPTIMIZER"]).cuda()
    model = (
        fsdp_wrapper(model, **config.MODEL.FSDP_CONFIG)
        if with_fsdp
        else DistributedDataParallel(model, device_ids=[gpu_id])
    )
    criterion = SwAVLoss(loss_config=config["LOSS"]["swav_loss"])
    optimizer = build_optimizer(config["OPTIMIZER"])
    optimizer.set_param_groups(model.parameters())

    num_iterations = 5
    is_main_rank = gpu_id == 0
    losses = []
    for step in range(num_iterations):
        loss = criterion(model(batch)[0], target)
        if is_main_rank:
            losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        # Discard the prototypes' gradients during the first three iterations.
        if step < 3:
            for param_name, param in model.named_parameters():
                if "prototypes" in param_name:
                    param.grad = None
        # True division already yields a float.
        optimizer.step(where=step / num_iterations)

    # Rank 0 stores its losses so several configurations can be compared.
    if is_main_rank:
        with open(result_file, "wb") as f:
            pickle.dump(losses, f)
def create_block(
    self,
    width_in: int,
    width_out: int,
    stride: int,
    params: Union[RegNetParams, AnyNetParams],
    bottleneck_multiplier: float,
    group_width: int = 1,
):
    """
    Create a block with the parent factory and wrap it for FSDP.

    The block is passed through ``auto_wrap_bn(..., single_rank_pg=False)``.
    It is then wrapped with ``fsdp_wrapper`` using ``self.fsdp_config``.
    """
    base_block = super().create_block(
        width_in, width_out, stride, params, bottleneck_multiplier, group_width
    )
    bn_wrapped = auto_wrap_bn(base_block, single_rank_pg=False)
    return fsdp_wrapper(module=bn_wrapped, **self.fsdp_config)
def __init__(
    self,
    model_config: AttrDict,
    dims: List[int],
    use_bn: bool = False,
    use_relu: bool = False,
    use_dropout: bool = False,
    use_bias: bool = True,
    skip_last_layer_relu_bn: bool = True,
):
    """
    Build an ``MLP`` and store it, FSDP-wrapped, as ``self.mlp``.

    The MLP first goes through ``auto_wrap_bn(..., single_rank_pg=False)``.
    The result is then wrapped with ``fsdp_wrapper`` using
    ``model_config.FSDP_CONFIG``. All other arguments are forwarded
    positionally to ``MLP``.
    """
    super().__init__()
    self.mlp = fsdp_wrapper(
        auto_wrap_bn(
            MLP(
                model_config,
                dims,
                use_bn,
                use_relu,
                use_dropout,
                use_bias,
                skip_last_layer_relu_bn,
            ),
            single_rank_pg=False,
        ),
        **model_config.FSDP_CONFIG,
    )
def create_block(
    self,
    width_in: int,
    width_out: int,
    stride: int,
    params: Union[RegNetParams, AnyNetParams],
    bottleneck_multiplier: float,
    group_width: int = 1,
):
    """
    Create a block with the parent factory and wrap it for FSDP.

    The block is passed through ``fsdp_auto_wrap_bn``. If
    ``self.fsdp_config.AUTO_WRAP_THRESHOLD`` is positive, it also goes through
    ``auto_wrap_big_layers``. Finally it is wrapped with ``fsdp_wrapper``
    using ``self.fsdp_config``.
    """
    block = fsdp_auto_wrap_bn(
        super().create_block(
            width_in, width_out, stride, params, bottleneck_multiplier, group_width
        )
    )
    if self.fsdp_config.AUTO_WRAP_THRESHOLD > 0:
        block = auto_wrap_big_layers(block, self.fsdp_config)
    return fsdp_wrapper(module=block, **self.fsdp_config)
def _distributed_worker(gpu_id: int, with_fsdp: bool, sync_file: str, result_file: str):
    """
    Per-GPU worker: train a small SwAV model with plain SGD and record losses.

    Steps:

    - joins a 2-process NCCL group that rendezvouses through ``sync_file``;
    - builds the model from ``TestRegnetFSDP._create_config(with_fsdp)``;
    - wraps it with FSDP (``with_fsdp``) or with DDP;
    - runs 5 iterations;
    - on rank 0, pickles the list of losses to ``result_file``.
    """
    torch.cuda.set_device(gpu_id)
    dist.init_process_group(
        backend="nccl",
        init_method="file://" + sync_file,
        world_size=2,
        rank=gpu_id,
    )

    # Fixed inputs, identical on every rank.
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(8, 3, 224, 224)).cuda()

    # A fake model built from SwAV blocks.
    config = TestRegnetFSDP._create_config(with_fsdp)
    model = build_model(config["MODEL"], config["OPTIMIZER"]).cuda()
    if with_fsdp:
        model = fsdp_wrapper(model, **config.MODEL.FSDP_CONFIG)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])
    criterion = SwAVLoss(loss_config=config["LOSS"]["swav_loss"])
    optimizer = optim.SGD(model.parameters(), lr=1e-2)

    is_main_rank = gpu_id == 0
    losses = []
    for step in range(5):
        outputs = model(batch)
        loss = criterion(outputs[0], torch.tensor(0.0).cuda())
        if is_main_rank:
            losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        # Discard the prototypes' gradients during the first three iterations.
        if step < 3:
            for param_name, param in model.named_parameters():
                if "prototypes" in param_name:
                    param.grad = None
        optimizer.step()

    # Rank 0 stores its losses so several configurations can be compared.
    if is_main_rank:
        with open(result_file, "wb") as f:
            pickle.dump(losses, f)
def MLP_FSDP(
    model_config: AttrDict,
    dims: List[int],
    use_bn: bool = False,
    use_relu: bool = False,
    use_dropout: bool = False,
    use_bias: bool = True,
    skip_last_layer_relu_bn: bool = True,
):
    """
    Build an ``MLP`` and return it wrapped for FSDP.

    The MLP first goes through ``fsdp_auto_wrap_bn``. The result is then
    wrapped with ``fsdp_wrapper`` using ``model_config.FSDP_CONFIG``. All
    other arguments are forwarded positionally to ``MLP``.
    """
    bn_wrapped = fsdp_auto_wrap_bn(
        MLP(
            model_config,
            dims,
            use_bn,
            use_relu,
            use_dropout,
            use_bias,
            skip_last_layer_relu_bn,
        )
    )
    return fsdp_wrapper(bn_wrapped, **model_config.FSDP_CONFIG)
def create_any_stage(
    self,
    width_in: int,
    width_out: int,
    stride: int,
    depth: int,
    group_width: int,
    bottleneck_multiplier: float,
    params: Union[RegNetParams, AnyNetParams],
    stage_index: int = 0,
    checkpoints: Union[List[int], None] = None,
):
    """
    Build an ``AnyStage`` whose blocks are split into FSDP-wrapped groups.

    Args:
        width_in: input width of the first block of the stage.
        width_out: output width of every block. It is also the input width of
            every block except the first.
        stride: stride of the first block; the other blocks use stride 1.
        depth: number of blocks in the stage when ``checkpoints`` is empty.
        group_width, bottleneck_multiplier, params: forwarded to
            ``self.create_block``.
        stage_index: used in the module names ``block{stage_index}-{i}`` and
            ``block{stage_index}-part{k}``.
        checkpoints: sorted, exclusive end indices of the block groups. When
            non-empty, each group is wrapped with ``checkpoint_wrapper``
            before ``fsdp_wrapper``. When omitted, ``None``, ``0`` or empty,
            the stage is a single group of ``depth`` blocks and
            ``checkpoint_wrapper`` is not applied.

    Returns:
        The assembled ``AnyStage``. Its ``stage_depth`` accumulates the
        ``depth`` of every created block.
    """
    # The previous default (``0``) made ``sorted(checkpoints)`` raise a
    # TypeError whenever the argument was omitted. Normalizing to a list makes
    # the default usable, treats an explicit ``0`` as "no checkpoints", and
    # lets tuples pass the sortedness check (sorted() always returns a list).
    checkpoints = list(checkpoints) if checkpoints else []
    assert sorted(checkpoints) == checkpoints, "Checkpoint indices should be sorted"

    with_checkpointing = len(checkpoints) > 0
    # NOTE(review): when checkpoints are given, `depth` is ignored and blocks
    # at indices >= checkpoints[-1] are never created. This presumably relies
    # on callers ending the list with `depth` -- confirm at the call sites.
    block_delimiters = checkpoints if with_checkpointing else [depth]

    any_stage = AnyStage()
    prev_depth = 0
    for block_group_index, next_depth in enumerate(block_delimiters):
        block_group = nn.Sequential()
        for i in range(prev_depth, next_depth):
            # Only the very first block of the stage changes width and stride.
            block = self.create_block(
                width_in=width_in if i == 0 else width_out,
                width_out=width_out,
                stride=stride if i == 0 else 1,
                params=params,
                group_width=group_width,
                bottleneck_multiplier=bottleneck_multiplier,
            )
            any_stage.stage_depth += block.depth
            block_group.add_module(f"block{stage_index}-{i}", block)
        prev_depth = next_depth
        if with_checkpointing:
            block_group = checkpoint_wrapper(block_group)
        block_group = fsdp_wrapper(block_group, **self.fsdp_config)
        any_stage.add_module(f"block{stage_index}-part{block_group_index}", block_group)
    return any_stage
def create_any_stage(
    self,
    width_in: int,
    width_out: int,
    stride: int,
    depth: int,
    group_width: int,
    bottleneck_multiplier: float,
    params: Union[RegNetParams, AnyNetParams],
    group_delimiters: List[int],
    group_checkpoint: List[bool],
    stage_index: int = 0,
):
    """
    Build an ``AnyStage`` whose blocks are split into FSDP-wrapped groups.

    Args:
        group_delimiters: exclusive end index of each block group. Group
            ``k`` holds blocks ``group_delimiters[k-1]`` (or 0) up to
            ``group_delimiters[k]``.
        group_checkpoint: per-group flag. When true, that group is wrapped
            with ``checkpoint_wrapper`` before ``fsdp_wrapper``.
        stage_index: used in the module names ``block{stage_index}-{i}`` and
            ``block{stage_index}-part{k}``.
        The remaining arguments configure the blocks built by
        ``self.create_block``. ``depth`` is not read here.

    Returns:
        The assembled ``AnyStage``. Its ``stage_depth`` accumulates the
        ``depth`` of every created block.
    """
    assert len(group_delimiters) == len(group_checkpoint)

    any_stage = AnyStage()
    group_start = 0
    for group_index, group_end in enumerate(group_delimiters):
        block_group = nn.Sequential()
        for i in range(group_start, group_end):
            # Only the very first block of the stage changes width and stride.
            is_first_block = i == 0
            block = self.create_block(
                width_in=width_in if is_first_block else width_out,
                width_out=width_out,
                stride=stride if is_first_block else 1,
                params=params,
                group_width=group_width,
                bottleneck_multiplier=bottleneck_multiplier,
            )
            any_stage.stage_depth += block.depth
            block_group.add_module(f"block{stage_index}-{i}", block)
        group_start = group_end

        if group_checkpoint[group_index]:
            block_group = checkpoint_wrapper(block_group)
        block_group = fsdp_wrapper(block_group, **self.fsdp_config)
        any_stage.add_module(f"block{stage_index}-part{group_index}", block_group)
    return any_stage
def _regnet_shared_param_kwargs(trunk_config) -> dict:
    """
    Parse the options shared by ``AnyNetParams`` and ``RegNetParams``.

    Every option is optional in ``trunk_config`` and falls back to the default
    written below. Enum-valued options are looked up by their upper-cased
    name.
    """
    return dict(
        stem_type=StemType[trunk_config.get("stem_type", "simple_stem_in").upper()],
        stem_width=trunk_config.get("stem_width", 32),
        block_type=BlockType[trunk_config.get("block_type", "res_bottleneck_block").upper()],
        activation=ActivationType[trunk_config.get("activation", "relu").upper()],
        use_se=trunk_config.get("use_se", True),
        se_ratio=trunk_config.get("se_ratio", 0.25),
        bn_epsilon=trunk_config.get("bn_epsilon", 1e-05),
        bn_momentum=trunk_config.get("bn_momentum", 0.1),
    )


def create_regnet_feature_blocks(factory: RegnetBlocksFactory, model_config):
    """
    Build the RegNet trunk as a dictionary of named feature blocks.

    Args:
        factory: creates the stem and is passed to every ``AnyStage``. When it
            is a ``RegnetFSDPBlocksFactory``, each stage is also wrapped with
            FSDP, and first with ``checkpoint_wrapper`` when
            ``ACTIVATION_CHECKPOINTING.USE_ACTIVATION_CHECKPOINTING`` is set.
        model_config: model config. Reads ``INPUT_TYPE``, ``TRUNK.REGNET``,
            ``ACTIVATION_CHECKPOINTING`` and ``FSDP_CONFIG``.

    Returns:
        A tuple ``(feature_blocks, trunk_depth)``:

        - ``feature_blocks`` is an ``nn.ModuleDict`` with keys ``conv1`` (the
          stem), ``res2``, ``res3``, ... (one per stage), ``avgpool`` and
          ``flatten``;
        - ``trunk_depth`` is the sum of the stages' ``stage_depth``.
    """
    assert model_config.INPUT_TYPE in ["rgb", "bgr"], "Input type not supported"
    trunk_config = model_config.TRUNK.REGNET
    if "name" in trunk_config:
        assert (
            trunk_config["name"] == "anynet"
        ), "Please use AnyNetParams or specify RegNetParams dictionary"

    # Architecture-specific keys are read first, then the shared options, which
    # keeps the original lookup order.
    if "name" in trunk_config and trunk_config["name"] == "anynet":
        params = AnyNetParams(
            depths=trunk_config["depths"],
            widths=trunk_config["widths"],
            group_widths=trunk_config["group_widths"],
            bottleneck_multipliers=trunk_config["bottleneck_multipliers"],
            strides=trunk_config["strides"],
            **_regnet_shared_param_kwargs(trunk_config),
        )
    else:
        params = RegNetParams(
            depth=trunk_config["depth"],
            w_0=trunk_config["w_0"],
            w_a=trunk_config["w_a"],
            w_m=trunk_config["w_m"],
            group_width=trunk_config["group_width"],
            bottleneck_multiplier=trunk_config.get("bottleneck_multiplier", 1.0),
            **_regnet_shared_param_kwargs(trunk_config),
        )

    # Ad hoc stem
    #
    # Important: do NOT retain modules in self.stem or self.trunk_output. It may
    # seem to be harmless, but it appears that autograd will result in computing
    # grads in different order. Different ordering can cause deterministic OOM,
    # even when the peak memory otherwise is only 24GB out of 32GB.
    #
    # When debugging this, it is not enough to just dump the total module
    # params. You need to diff the module string representations.
    stem = factory.create_stem(params)

    # Instantiate all the AnyNet blocks in the trunk
    current_width, trunk_depth, blocks = params.stem_width, 0, []
    for i, (width_out, stride, depth, group_width, bottleneck_multiplier) in enumerate(
        params.get_expanded_params()
    ):
        new_stage = AnyStage(
            factory=factory,
            width_in=current_width,
            width_out=width_out,
            stride=stride,
            depth=depth,
            group_width=group_width,
            bottleneck_multiplier=bottleneck_multiplier,
            params=params,
            stage_index=i + 1,
        )
        if isinstance(factory, RegnetFSDPBlocksFactory):
            if model_config.ACTIVATION_CHECKPOINTING.USE_ACTIVATION_CHECKPOINTING:
                logging.info("Using activation checkpointing")
                new_stage = checkpoint_wrapper(new_stage, offload_to_cpu=False)
            new_stage = fsdp_wrapper(module=new_stage, **model_config.FSDP_CONFIG)

        blocks.append((f"block{i + 1}", new_stage))
        trunk_depth += blocks[-1][1].stage_depth
        current_width = width_out

    trunk_output = nn.Sequential(OrderedDict(blocks))

    # Map the model to the structure exposed to SSL tasks. The upstream RegNet
    # model is made of:
    # - `stem`
    # - n x blocks in trunk_output, named `block1, block2, ..`
    # Only the stem and the successive blocks are kept; everything else is
    # deliberately left out.
    feature_blocks: List[Tuple[str, nn.Module]] = [("conv1", stem)]
    for k, v in trunk_output.named_children():
        assert k.startswith("block"), f"Unexpected layer name {k}"
        # The stem already occupies one slot, so the first stage becomes "res2".
        block_index = len(feature_blocks) + 1
        feature_blocks.append((f"res{block_index}", v))
    feature_blocks.append(("avgpool", nn.AdaptiveAvgPool2d((1, 1))))
    feature_blocks.append(("flatten", Flatten(1)))

    return nn.ModuleDict(feature_blocks), trunk_depth
def _worker(gpu_id: int, sync_file: str, world_size: int):
    """
    Per-GPU worker checking that FSDP checkpoint conversions preserve the trunk.

    Steps:

    1. Trains an FSDP model for a few iterations.
    2. Saves its trunk as a sharded checkpoint, one shard per rank.
    3. On rank 0, converts that checkpoint to consolidated
       (``checkpoint_conso.torch``) and sliced (``checkpoint_sliced.torch``)
       files.
    4. Rebuilds one model from each of the three files.
    5. Asserts that the consolidated and sliced models have matching local
       state dicts (checked on rank 0 only).
    6. Asserts that every rebuilt trunk produces the same output as the
       trained one.

    Files are written to and read from the current working directory.
    """
    torch.manual_seed(0)
    os.environ["RANK"] = str(gpu_id)
    init_distributed_on_file(world_size=world_size, gpu_id=gpu_id, sync_file=sync_file)
    torch.backends.cudnn.deterministic = True
    config = TestCheckpointConversion._create_fsdp_model_config(with_fsdp=True)

    def fresh_fsdp_model():
        # Build the test model on this rank's GPU and wrap it with FSDP.
        net = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        return fsdp_wrapper(net, **config.MODEL.FSDP_CONFIG)

    def fsdp_model_from(checkpoint_path):
        # Load the checkpoint on CPU, then build a fresh model and initialize it.
        weights = CheckpointLoader.load_and_broadcast_init_weights(
            checkpoint_path, device=torch.device("cpu")
        )
        net = fresh_fsdp_model()
        net.init_model_from_weights_params_file(config, weights)
        return net

    model = fresh_fsdp_model()
    optimizer = optim.SGD(model.parameters(), lr=1e-4)

    # Per-rank fake data: one slice of each tensor per iteration.
    num_iterations, batch_size = 5, 3
    torch.manual_seed(gpu_id)
    fake_inputs = torch.randn(size=(num_iterations, batch_size, 3, 96, 96))
    fake_targets = torch.randn(size=(num_iterations, batch_size))

    # Fake training loop: MSE between each summed output and the targets.
    criterion = nn.MSELoss()
    for input_slice, target_slice in zip(fake_inputs, fake_targets):
        fake_input = input_slice.cuda(gpu_id)
        fake_target = target_slice.cuda(gpu_id)
        output1, output2 = model(fake_input)[0]
        loss = criterion(output1.sum(axis=-1), fake_target) + criterion(
            output2.sum(axis=-1), fake_target
        )
        if gpu_id == 0:
            print(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Save the trunk as a sharded checkpoint: each rank writes its own shard.
    checkpoint_writer = CheckpointWriter(
        checkpoint_folder=".",
        is_final_train_phase=True,
        mode="iteration",
        mode_num=0,
        backend="disk",
    )
    trunk = model.trunk
    content = {
        "classy_state_dict": {
            "base_model": {
                "model": {"trunk": trunk.local_state_dict()},
                "meta": {"trunk": trunk.local_metadata_dict()},
            }
        }
    }
    checkpoint_writer.save_sharded_checkpoint(content, shard_rank=gpu_id, world_size=world_size)
    dist.barrier()
    print(os.listdir("."))

    # Rank 0 converts the sharded checkpoint; everyone waits for the files.
    if gpu_id == 0:
        CheckpointFormatConverter.sharded_to_consolidated_checkpoint(
            "checkpoint.torch", "checkpoint_conso.torch"
        )
        CheckpointFormatConverter.sharded_to_sliced_checkpoint(
            "checkpoint.torch", "checkpoint_sliced.torch"
        )
    dist.barrier()
    print(os.listdir("."))

    # Rebuild a model from each checkpoint format.
    fake_test_input = torch.randn(size=(1, 3, 96, 96)).cuda(gpu_id)
    shard_model = fsdp_model_from("checkpoint.torch")
    conso_model = fsdp_model_from("checkpoint_conso.torch")
    slice_model = fsdp_model_from("checkpoint_sliced.torch")

    # The consolidated and sliced models must hold identical local weights.
    if gpu_id == 0:
        slice_state_dict = slice_model.local_state_dict()
        conso_state_dict = conso_model.local_state_dict()
        assert set(slice_state_dict.keys()) == set(conso_state_dict.keys())
        for key, slice_val in slice_state_dict.items():
            conso_val = conso_state_dict[key]
            assert torch.allclose(
                slice_val, conso_val
            ), f"Difference for key {key}: {slice_val} VS {conso_val}"
    dist.barrier()

    # Run every forward pass before asserting, so all ranks take part in the
    # same sequence of FSDP forwards.
    with torch.no_grad():
        ref_out = model.trunk(fake_test_input)[0]
        shard_out = shard_model.trunk(fake_test_input)[0]
        conso_out = conso_model.trunk(fake_test_input)[0]
        slice_out = slice_model.trunk(fake_test_input)[0]
    for other_out in (shard_out, conso_out, slice_out):
        assert torch.allclose(ref_out, other_out), f"{ref_out.sum()} vs {other_out.sum()}"