def create_model(with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter):
    """Build a ``Model`` and optionally apply checkpointing and FSDP wrapping.

    Args:
        with_fsdp: When true, wrap ``block1``, ``block2`` and ``head`` via ``wrap``
            inside an ``enable_wrap`` context that uses ``FSDP``.
        with_checkpoint: When true, pass ``block2`` through ``checkpoint_wrapper``.
            ``maintain_forward_counter`` is True on the FSDP path, False otherwise.
        mixed_precision: Forwarded to ``enable_wrap`` as ``mixed_precision``.
        flatten: Forwarded to ``enable_wrap`` as ``flatten_parameters``.
        wrap_bn: On the FSDP path only, pass ``block1`` and ``block2`` through
            ``auto_wrap_bn(..., single_rank_pg=False)`` first.
        fp32_reduce_scatter: Forwarded to ``enable_wrap`` unchanged.

    Returns:
        The (possibly wrapped) model.
    """
    model = Model()

    # Non-FSDP path: at most a checkpoint wrapper around block2.
    if not with_fsdp:
        if with_checkpoint:
            model.block2 = checkpoint_wrapper(model.block2, maintain_forward_counter=False)
        return model

    # FSDP path: BN wrapping first, then checkpointing, then the FSDP wrap.
    if wrap_bn:
        for attr in ("block1", "block2"):
            setattr(model, attr, auto_wrap_bn(getattr(model, attr), single_rank_pg=False))

    if with_checkpoint:
        model.block2 = checkpoint_wrapper(model.block2, maintain_forward_counter=True)

    wrap_settings = {
        "wrapper_cls": FSDP,
        "flatten_parameters": flatten,
        "mixed_precision": mixed_precision,
        "compute_dtype": torch.float32,
        "fp32_reduce_scatter": fp32_reduce_scatter,
    }
    with enable_wrap(**wrap_settings):
        for attr in ("block1", "block2", "head"):
            setattr(model, attr, wrap(getattr(model, attr)))

    return model
def _create_model(
    with_model2,
    with_sync_bn,
    with_fsdp,
    with_checkpoint,
    mixed_precision,
    flatten,
    wrap_bn,
    fp32_reduce_scatter,
    bucket_cap_mb,
):
    """Build ``Model`` or ``Model2`` and apply the requested wrappers.

    The steps run in a fixed order: optional SyncBatchNorm conversion, optional
    ``auto_wrap_bn`` on the trunk blocks (FSDP path only), optional
    ``checkpoint_wrapper`` on ``block2`` (and ``block3`` for ``Model2``), and
    finally optional FSDP wrapping of the trunk blocks and ``head``.

    Args:
        with_model2: Use ``Model2`` (which also has ``block3``) instead of ``Model``.
        with_sync_bn: Convert the model with ``nn.SyncBatchNorm.convert_sync_batchnorm``
            and pass a dedicated ``fsdp_config`` to ``auto_wrap_bn``.
        with_fsdp: Wrap the blocks and head via ``wrap`` under ``enable_wrap``.
        with_checkpoint: Apply ``checkpoint_wrapper`` to ``block2`` (and ``block3``).
        mixed_precision: Forwarded to ``enable_wrap`` as ``mixed_precision``.
        flatten: Forwarded to ``enable_wrap`` as ``flatten_parameters``.
        wrap_bn: Together with ``with_fsdp``, enables the ``auto_wrap_bn`` step.
        fp32_reduce_scatter: Forwarded to ``enable_wrap`` unchanged.
        bucket_cap_mb: Forwarded to ``enable_wrap`` unchanged.

    Returns:
        The (possibly converted and wrapped) model.
    """
    model = Model2() if with_model2 else Model()

    fsdp_config = None
    if with_sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        fsdp_config = {
            "mixed_precision": False,
            "flatten_parameters": False,
            "reshard_after_forward": False,
            "bucket_cap_mb": 0,
            "force_input_to_fp32": True,  # SyncBN needs this.
        }

    # block3 only exists on Model2.
    trunk_names = ["block1", "block2"]
    if with_model2:
        trunk_names.append("block3")

    if with_fsdp and wrap_bn:
        for attr in trunk_names:
            bn_wrapped = auto_wrap_bn(getattr(model, attr), single_rank_pg=False, fsdp_config=fsdp_config)
            setattr(model, attr, bn_wrapped)

    if with_checkpoint:
        # block1 is never checkpointed; only block2 onwards.
        for attr in trunk_names[1:]:
            setattr(model, attr, checkpoint_wrapper(getattr(model, attr)))

    if with_fsdp:
        with enable_wrap(
            wrapper_cls=FSDP,
            flatten_parameters=flatten,
            mixed_precision=mixed_precision,
            compute_dtype=torch.float32,
            fp32_reduce_scatter=fp32_reduce_scatter,
            bucket_cap_mb=bucket_cap_mb,
        ):
            for attr in trunk_names + ["head"]:
                setattr(model, attr, wrap(getattr(model, attr)))

    return model
def fsdp_wrapper(module, **kwargs):
    """Custom wrapper that applies activation checkpointing and FSDP together.

    The module is first passed through ``checkpoint_wrapper`` and the result is
    then wrapped with ``FSDP``.

    Args:
        module: The module to wrap.
        **kwargs: Optional keyword arguments forwarded to ``FSDP``. With none
            given, the result is ``FSDP(checkpoint_wrapper(module))``.

    Returns:
        The checkpoint-wrapped module, wrapped in ``FSDP``.
    """
    # Accepting **kwargs lets this function serve as the ``wrapper_cls`` of an
    # ``enable_wrap`` context that also carries FSDP settings.
    # NOTE(review): relies on ``wrap`` forwarding those settings to wrapper_cls
    # as keyword arguments — confirm against the fairscale version in use.
    return FSDP(checkpoint_wrapper(module), **kwargs)
def create_block(
    self,
    width_in: int,
    width_out: int,
    stride: int,
    params: Union[RegNetParams, AnyNetParams],
    bottleneck_multiplier: float,
    group_width: int = 1,
):
    """Create a block via the parent factory and wrap it for FSDP.

    The parent's block is passed through ``auto_wrap_bn(single_rank_pg=False)``
    and then through ``wrap`` inside an ``enable_wrap`` context that uses
    ``fsdp_wrapper`` as the wrapper class and ``self.fsdp_config`` as settings.

    Returns:
        The wrapped block.
    """
    base_block = super().create_block(width_in, width_out, stride, params, bottleneck_multiplier, group_width)
    bn_wrapped = auto_wrap_bn(base_block, single_rank_pg=False)

    wrap_context = enable_wrap(wrapper_cls=fsdp_wrapper, **self.fsdp_config)
    with wrap_context:
        return wrap(bn_wrapped)
def __init__(
    self,
    model_config,
    width_in: int,
    width_out: int,
    stride: int,
    depth: int,
    block_constructor: nn.Module,
    activation: nn.Module,
    bot_mul: float,
    group_width: int,
    params: "RegNetParams",
    stage_index: int = 0,
):
    """Build ``depth`` wrapped blocks and register them as sub-modules.

    Args:
        model_config: Config object; its ``FSDP_CONFIG`` supplies the settings
            passed to ``enable_wrap`` (on top of ``wrapper_cls=fsdp_wrapper``).
        width_in: Input width of the first block; later blocks take ``width_out``.
        width_out: Output width of every block.
        stride: Stride of the first block; later blocks use 1.
        depth: Number of blocks to create.
        block_constructor: Callable used to build each block.
        activation: Passed through to ``block_constructor``.
        bot_mul: Passed through to ``block_constructor``.
        group_width: Passed through to ``block_constructor``.
        params: Source of ``bn_epsilon``, ``bn_momentum`` and ``se_ratio``.
        stage_index: Used in the registered module names ``block{stage_index}-{i}``.
    """
    super().__init__()
    self.stage_depth = 0

    wrap_settings = {"wrapper_cls": fsdp_wrapper}
    wrap_settings.update(model_config.FSDP_CONFIG)

    for i in range(depth):
        # Only the first block of the stage changes width and applies the stride.
        is_first = i == 0
        block_width_in = width_in if is_first else width_out
        block_stride = stride if is_first else 1

        stage_block = block_constructor(
            block_width_in,
            width_out,
            block_stride,
            params.bn_epsilon,
            params.bn_momentum,
            activation,
            bot_mul,
            group_width,
            params.se_ratio,
        )
        # Shard-as-we-build of FSDP needs the block on cuda for its
        # dist.all_gather() call.
        stage_block = stage_block.cuda()

        # Weights must be initialised before wrapping and sharding.
        init_weights(stage_block)

        # Wrapping with fsdp+checkpoint is what performs the sharding.
        stage_block = auto_wrap_bn(stage_block)
        with enable_wrap(**wrap_settings):
            stage_block = wrap(stage_block)

        self.stage_depth += stage_block.depth
        self.add_module(f"block{stage_index}-{i}", stage_block)
def create_block(
    self,
    width_in: int,
    width_out: int,
    stride: int,
    params: Union[RegNetParams, AnyNetParams],
    bottleneck_multiplier: float,
    group_width: int = 1,
):
    """Create a block via the parent factory, optionally checkpoint it, and wrap it.

    Steps, in order: parent ``create_block``; ``auto_wrap_bn(single_rank_pg=False)``;
    ``checkpoint_wrapper(offload_to_cpu=False)`` when
    ``self.use_activation_checkpointing`` is set; then ``wrap`` inside an
    ``enable_wrap`` context using ``fsdp_wrapper`` and ``self.fsdp_config``.

    Returns:
        The wrapped block.
    """
    result = super().create_block(
        width_in, width_out, stride, params, bottleneck_multiplier, group_width
    )
    result = auto_wrap_bn(result, single_rank_pg=False)

    if self.use_activation_checkpointing:
        # TODO - make this configurable
        result = checkpoint_wrapper(result, offload_to_cpu=False)

    wrap_context = enable_wrap(wrapper_cls=fsdp_wrapper, **self.fsdp_config)
    with wrap_context:
        return wrap(result)
def create_regnet_feature_blocks(factory: RegnetBlocksFactory, model_config):
    """Build the stem and stages of a RegNet/AnyNet trunk as named feature blocks.

    The trunk parameters are read from ``model_config.TRUNK.REGNET``. If that
    config has a ``"name"`` key it must be ``"anynet"`` and ``AnyNetParams`` is
    used; otherwise ``RegNetParams`` is used.

    Args:
        factory: Block factory used to create the stem and passed to each
            ``AnyStage``. When it is a ``RegnetFSDPBlocksFactory``, each stage is
            optionally checkpoint-wrapped and then wrapped via ``wrap``.
        model_config: Config providing ``INPUT_TYPE``, ``TRUNK.REGNET``,
            ``ACTIVATION_CHECKPOINTING`` and ``FSDP_CONFIG``.

    Returns:
        A tuple ``(feature_blocks, trunk_depth)``: an ``nn.ModuleDict`` with keys
        ``conv1``, ``res2``, ``res3``, ..., ``avgpool``, ``flatten``, and the sum
        of the ``stage_depth`` of all stages.
    """
    assert model_config.INPUT_TYPE in ["rgb", "bgr"], "Input type not supported"
    trunk_config = model_config.TRUNK.REGNET
    if "name" in trunk_config:
        assert (trunk_config["name"] == "anynet"
                ), "Please use AnyNetParams or specify RegNetParams dictionary"

    if "name" in trunk_config and trunk_config["name"] == "anynet":
        # Explicit per-stage description of the network.
        params = AnyNetParams(
            depths=trunk_config["depths"],
            widths=trunk_config["widths"],
            group_widths=trunk_config["group_widths"],
            bottleneck_multipliers=trunk_config["bottleneck_multipliers"],
            strides=trunk_config["strides"],
            stem_type=StemType[trunk_config.get("stem_type", "simple_stem_in").upper()],
            stem_width=trunk_config.get("stem_width", 32),
            block_type=BlockType[trunk_config.get(
                "block_type", "res_bottleneck_block").upper()],
            activation=ActivationType[trunk_config.get("activation", "relu").upper()],
            use_se=trunk_config.get("use_se", True),
            se_ratio=trunk_config.get("se_ratio", 0.25),
            bn_epsilon=trunk_config.get("bn_epsilon", 1e-05),
            bn_momentum=trunk_config.get("bn_momentum", 0.1),
        )
    else:
        # RegNet parametrisation (depth, w_0, w_a, w_m, group_width).
        params = RegNetParams(
            depth=trunk_config["depth"],
            w_0=trunk_config["w_0"],
            w_a=trunk_config["w_a"],
            w_m=trunk_config["w_m"],
            group_width=trunk_config["group_width"],
            bottleneck_multiplier=trunk_config.get("bottleneck_multiplier", 1.0),
            stem_type=StemType[trunk_config.get("stem_type", "simple_stem_in").upper()],
            stem_width=trunk_config.get("stem_width", 32),
            block_type=BlockType[trunk_config.get(
                "block_type", "res_bottleneck_block").upper()],
            activation=ActivationType[trunk_config.get("activation", "relu").upper()],
            use_se=trunk_config.get("use_se", True),
            se_ratio=trunk_config.get("se_ratio", 0.25),
            bn_epsilon=trunk_config.get("bn_epsilon", 1e-05),
            bn_momentum=trunk_config.get("bn_momentum", 0.1),
        )

    # Ad hoc stem
    #
    # Important: do NOT retain modules in self.stem or self.trunk_output. It may
    # seem to be harmless, but it appears that autograd will result in computing
    # grads in different order. Different ordering can cause deterministic OOM,
    # even when the peak memory otherwise is only 24GB out of 32GB.
    #
    # When debugging this, it is not enough to just dump the total module
    # params. You need to diff the module string representations.
    stem = factory.create_stem(params)

    # Instantiate all the AnyNet blocks in the trunk
    current_width, trunk_depth, blocks = params.stem_width, 0, []
    for i, (width_out, stride, depth, group_width, bottleneck_multiplier) in enumerate(params.get_expanded_params()):
        new_stage = AnyStage(
            factory=factory,
            width_in=current_width,
            width_out=width_out,
            stride=stride,
            depth=depth,
            group_width=group_width,
            bottleneck_multiplier=bottleneck_multiplier,
            params=params,
            stage_index=i + 1,
        )

        # Only the FSDP factory gets stage-level checkpointing and wrapping.
        if isinstance(factory, RegnetFSDPBlocksFactory):
            if model_config.ACTIVATION_CHECKPOINTING.USE_ACTIVATION_CHECKPOINTING:
                logging.info("Using activation checkpointing")
                new_stage = checkpoint_wrapper(new_stage, offload_to_cpu=False)
            with enable_wrap(wrapper_cls=fsdp_wrapper, **model_config.FSDP_CONFIG):
                new_stage = wrap(new_stage)

        blocks.append((
            f"block{i + 1}",
            new_stage,
        ))
        trunk_depth += blocks[-1][1].stage_depth
        # The next stage consumes this stage's output width.
        current_width = width_out

    trunk_output = nn.Sequential(OrderedDict(blocks))

    ################################################################################

    # Now map the models to the structure we want to expose for SSL tasks
    # The upstream RegNet model is made of :
    # - `stem`
    # - n x blocks in trunk_output, named `block1, block2, ..`

    # We're only interested in the stem and successive blocks
    # everything else is not picked up on purpose
    feature_blocks: List[Tuple[str, nn.Module]] = [("conv1", stem)]
    for k, v in trunk_output.named_children():
        assert k.startswith("block"), f"Unexpected layer name {k}"
        # feature_blocks already holds the stem, so the first stage becomes "res2".
        block_index = len(feature_blocks) + 1
        feature_blocks.append((f"res{block_index}", v))
    feature_blocks.append(("avgpool", nn.AdaptiveAvgPool2d((1, 1))))
    feature_blocks.append(("flatten", Flatten(1)))

    return nn.ModuleDict(feature_blocks), trunk_depth
def wrap_module(self, module: torch.nn.Module) -> torch.nn.Module:
    """Wrap ``module`` using ``_FSDP`` configured with ``self._fsdp_kwargs``.

    Args:
        module: The module to wrap.

    Returns:
        The result of calling ``wrap(module)`` inside an ``enable_wrap`` context
        whose wrapper class is ``_FSDP``.
    """
    wrap_context = enable_wrap(wrapper_cls=_FSDP, **self._fsdp_kwargs)
    with wrap_context:
        return wrap(module)
def main(local_rank, *args):
    """Train a small nested FSDP + checkpointing model, then reload it unwrapped.

    Each process joins an NCCL process group, builds a stack of linear layers
    with nested ``wrap``/``checkpoint_wrapper`` calls, trains it for 1000 steps
    on random data, and finally loads the resulting state dict into an
    equivalent model built without the FSDP wrappers.

    Args:
        local_rank: Rank of this process; also used as the CUDA device index.
        *args: Accepted but unused.
    """
    torch.backends.cudnn.benchmark = True

    # Rendezvous settings are fixed: exactly 8 processes are expected.
    world_size = 8
    init_method = "tcp://0.0.0.0:9999"
    torch.distributed.init_process_group(
        backend="nccl", rank=local_rank, world_size=world_size, init_method=init_method
    )
    print(f"[Train]: Time = {get_time_string()}, Initialized Dist Process for Rank = {local_rank}")

    device = torch.device(f"cuda:{local_rank}")  # Unique only on individual node.
    torch.cuda.set_device(device)

    fsdp_params = dict(
        mixed_precision=True,
        flatten_parameters=True,
        bucket_cap_mb=25,
        reshard_after_forward=False,
        fp32_reduce_scatter=False,
        cpu_offload=False,
        move_grads_to_cpu=False,
        process_group=torch.distributed.group.WORLD,
    )

    with enable_wrap(wrapper_cls=FullyShardedDDP, **fsdp_params):
        # An inner wrapped+checkpointed Linear sits inside an outer
        # wrapped+checkpointed Sequential.
        nn_model = nn.Sequential(
            nn.Linear(200, 200),
            wrap(
                checkpoint_wrapper(
                    nn.Sequential(
                        nn.Linear(200, 200),
                        nn.Linear(200, 200),
                        wrap(checkpoint_wrapper(nn.Linear(200, 200), offload_to_cpu=True)),
                        checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
                        nn.Linear(200, 200),
                    ),
                    offload_to_cpu=True,
                )
            ),
            checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
            nn.LayerNorm(200, eps=1e-7),
            nn.Linear(200, 64),
        ).cuda()
        model = FullyShardedDDP(nn_model, **fsdp_params)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=1e-4, eps=1e-7, weight_decay=1e-2, betas=(0.9, 0.99)
    )

    for i in range(1000):
        # Gradients are cleared at the start of every step, so no separate
        # clearing is needed before the loop.
        optimizer.zero_grad(set_to_none=True)
        fake_inputs = torch.randn(32, 200, device=device)
        fake_labels = torch.randn(32, 64, device=device)
        outputs = model(fake_inputs)
        loss = ((outputs - fake_labels) ** 2).mean()
        loss.backward()
        model.clip_grad_norm_(1.0)
        optimizer.step()
        if i % 100 == 0:
            print(f"Loss = {loss.item()}, rank = {local_rank}")

    state_dict = model.state_dict()

    # Same layer layout as above, but without the FSDP wrappers (and without the
    # checkpoint wrappers around the Linear/Sequential modules).
    nn_model = nn.Sequential(
        nn.Linear(200, 200),
        nn.Sequential(
            nn.Linear(200, 200),
            nn.Linear(200, 200),
            nn.Linear(200, 200),
            checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
            nn.Linear(200, 200),
        ),
        checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
        nn.LayerNorm(200, eps=1e-7),
        nn.Linear(200, 64),
    ).cuda()
    nn_model.load_state_dict(state_dict)
    print(f"[Train]: Time = {get_time_string()}, Trainable Params = {numel(nn_model) / 1_000_000}")