def create_model(with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter):
    """Build a ``Model`` with optional activation checkpointing and FSDP wrapping.

    Args:
        with_fsdp: When truthy, ``block1``, ``block2`` and ``head`` are each
            passed through ``wrap`` inside an ``enable_wrap(wrapper_cls=FSDP)``
            context.
        with_checkpoint: When truthy, ``block2`` is wrapped with
            ``checkpoint_wrapper``. ``maintain_forward_counter`` is ``True`` on
            the FSDP path and ``False`` otherwise.
        mixed_precision: Forwarded to ``enable_wrap`` as ``mixed_precision``.
        flatten: Forwarded to ``enable_wrap`` as ``flatten_parameters``.
        wrap_bn: On the FSDP path only, apply ``auto_wrap_bn`` to ``block1``
            and ``block2`` before anything else.
        fp32_reduce_scatter: Forwarded to ``enable_wrap``.

    Returns:
        The configured model.
    """
    model = Model()

    # Non-FSDP path: the only optional step is checkpointing block2.
    if not with_fsdp:
        if with_checkpoint:
            model.block2 = checkpoint_wrapper(model.block2, maintain_forward_counter=False)
        return model

    # FSDP path. Order matters: BN wrapping, then checkpointing, then FSDP.
    if wrap_bn:
        for attr in ("block1", "block2"):
            setattr(model, attr, auto_wrap_bn(getattr(model, attr), single_rank_pg=False))

    if with_checkpoint:
        model.block2 = checkpoint_wrapper(model.block2, maintain_forward_counter=True)

    wrap_settings = dict(
        wrapper_cls=FSDP,
        flatten_parameters=flatten,
        mixed_precision=mixed_precision,
        compute_dtype=torch.float32,
        fp32_reduce_scatter=fp32_reduce_scatter,
    )
    with enable_wrap(**wrap_settings):
        for attr in ("block1", "block2", "head"):
            setattr(model, attr, wrap(getattr(model, attr)))
    return model
def _create_model(
    with_model2,
    with_sync_bn,
    with_fsdp,
    with_checkpoint,
    mixed_precision,
    flatten,
    wrap_bn,
    fp32_reduce_scatter,
    bucket_cap_mb,
):
    """Build ``Model`` or ``Model2`` with optional SyncBN, checkpointing and FSDP.

    Steps, in order:
      1. Instantiate ``Model2`` when ``with_model2`` is truthy, else ``Model``.
      2. With ``with_sync_bn``, convert the model via
         ``nn.SyncBatchNorm.convert_sync_batchnorm`` and prepare a dedicated
         FSDP config that is handed to ``auto_wrap_bn``.
      3. With both ``with_fsdp`` and ``wrap_bn``, apply ``auto_wrap_bn`` to
         ``block1``, ``block2`` (and ``block3`` for ``Model2``).
      4. With ``with_checkpoint``, wrap ``block2`` (and ``block3`` for
         ``Model2``) in ``checkpoint_wrapper``.
      5. With ``with_fsdp``, wrap ``block1``, ``block2``, (``block3``,) and
         ``head`` through ``wrap`` inside ``enable_wrap(wrapper_cls=FSDP)``.

    ``mixed_precision``, ``flatten`` (as ``flatten_parameters``),
    ``fp32_reduce_scatter`` and ``bucket_cap_mb`` are forwarded to
    ``enable_wrap``.

    Returns:
        The configured model.
    """
    if with_model2:
        model = Model2()
    else:
        model = Model()

    bn_fsdp_config = None
    if with_sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        bn_fsdp_config = dict(
            mixed_precision=False,
            flatten_parameters=False,
            reshard_after_forward=False,
            bucket_cap_mb=0,
            force_input_to_fp32=True,  # SyncBN needs this.
        )

    # block3 is only handled when Model2 is in use.
    extra_blocks = ("block3",) if with_model2 else ()
    bn_targets = ("block1", "block2") + extra_blocks
    checkpoint_targets = ("block2",) + extra_blocks
    fsdp_targets = ("block1", "block2") + extra_blocks + ("head",)

    if with_fsdp and wrap_bn:
        for attr in bn_targets:
            wrapped = auto_wrap_bn(getattr(model, attr), single_rank_pg=False, fsdp_config=bn_fsdp_config)
            setattr(model, attr, wrapped)

    if with_checkpoint:
        for attr in checkpoint_targets:
            setattr(model, attr, checkpoint_wrapper(getattr(model, attr)))

    if with_fsdp:
        with enable_wrap(
            wrapper_cls=FSDP,
            flatten_parameters=flatten,
            mixed_precision=mixed_precision,
            compute_dtype=torch.float32,
            fp32_reduce_scatter=fp32_reduce_scatter,
            bucket_cap_mb=bucket_cap_mb,
        ):
            for attr in fsdp_targets:
                setattr(model, attr, wrap(getattr(model, attr)))

    return model
def create_model(with_fsdp, with_checkpoint, model_hidden_dim, fsdp_config):
    """Build ``Model(model_hidden_dim)`` with optional checkpointing and FSDP.

    Args:
        with_fsdp: When truthy, ``stem`` and ``blocks`` go through
            ``auto_wrap_bn`` first, and ``stem``, ``blocks`` and ``head`` are
            finally converted with ``to_fsdp``.
        with_checkpoint: When truthy, ``blocks`` is wrapped with
            ``checkpoint_wrapper`` (after BN wrapping, before ``to_fsdp``).
        model_hidden_dim: Passed to the ``Model`` constructor.
        fsdp_config: Second argument of every ``to_fsdp`` call.

    Returns:
        The configured model.
    """
    model = Model(model_hidden_dim)

    # Stage 1 (FSDP only): wrap batch-norm layers.
    if with_fsdp:
        model.stem = auto_wrap_bn(model.stem, single_rank_pg=False)
        model.blocks = auto_wrap_bn(model.blocks, single_rank_pg=False)

    # Stage 2: activation checkpointing is applied the same way on both paths.
    if with_checkpoint:
        model.blocks = checkpoint_wrapper(model.blocks)

    # Stage 3 (FSDP only): shard each top-level component.
    if with_fsdp:
        model.stem = to_fsdp(model.stem, fsdp_config)
        model.blocks = to_fsdp(model.blocks, fsdp_config)
        model.head = to_fsdp(model.head, fsdp_config)

    return model
def create_model(with_fsdp, with_checkpoint):
    """Build a ``Model`` with optional activation checkpointing and FSDP.

    Args:
        with_fsdp: When truthy, apply ``auto_wrap_bn`` to ``stem`` and
            ``blocks``, then ``to_fsdp`` to ``stem``, ``blocks`` and ``head``.
        with_checkpoint: When truthy, wrap ``blocks`` with
            ``checkpoint_wrapper`` (before ``to_fsdp`` on the FSDP path).

    Returns:
        The configured model.
    """
    model = Model()

    if not with_fsdp:
        # Plain model: checkpointing is the only optional transformation.
        if with_checkpoint:
            model.blocks = checkpoint_wrapper(model.blocks)
        return model

    for attr in ("stem", "blocks"):
        setattr(model, attr, auto_wrap_bn(getattr(model, attr), single_rank_pg=False))

    if with_checkpoint:
        model.blocks = checkpoint_wrapper(model.blocks)

    for attr in ("stem", "blocks", "head"):
        setattr(model, attr, to_fsdp(getattr(model, attr)))

    return model
def _build_block(self, dpr: float, norm_layer) -> nn.Module:
    """Create one ``Block`` and apply the configured activation checkpointing.

    Args:
        dpr: Value passed to ``Block`` as ``drop_path``.
        norm_layer: Value passed to ``Block`` as ``norm_layer``.

    Returns:
        The block. Its ``mlp`` is wrapped with ``checkpoint_wrapper`` when
        ``trunk_config.CHECKPOINT_MLP`` is truthy, and the whole block is
        wrapped when ``trunk_config.CHECKPOINT_BLOCK`` is truthy.
    """
    block_kwargs = dict(
        dim=self.embed_dim,
        num_heads=self.num_heads,
        mlp_ratio=self.mlp_ratio,
        qkv_bias=self.qkv_bias,
        qk_scale=self.qk_scale,
        drop=self.drop_rate,
        attn_drop=self.attn_drop_rate,
        drop_path=dpr,
        norm_layer=norm_layer,
    )
    built = Block(**block_kwargs)

    checkpoint_cfg = self.trunk_config
    # The MLP is wrapped first so that the (optional) outer wrapper encloses it.
    if checkpoint_cfg.CHECKPOINT_MLP:
        built.mlp = checkpoint_wrapper(built.mlp)
    if checkpoint_cfg.CHECKPOINT_BLOCK:
        built = checkpoint_wrapper(built)
    return built
def create_block(
    self,
    width_in: int,
    width_out: int,
    stride: int,
    params: Union[RegNetParams, AnyNetParams],
    bottleneck_multiplier: float,
    group_width: int = 1,
):
    """Create a block via the parent factory and wrap it for FSDP.

    The block returned by ``super().create_block`` is passed through
    ``auto_wrap_bn``, optionally through ``checkpoint_wrapper`` (when
    ``self.use_activation_checkpointing`` is truthy), and finally through
    ``wrap`` inside an ``enable_wrap`` context configured with
    ``fsdp_wrapper`` and ``self.fsdp_config``.

    Args:
        width_in, width_out, stride, params, bottleneck_multiplier,
        group_width: Forwarded positionally to the parent ``create_block``.

    Returns:
        The wrapped block.
    """
    base_block = super().create_block(
        width_in, width_out, stride, params, bottleneck_multiplier, group_width
    )
    prepared = auto_wrap_bn(base_block, single_rank_pg=False)

    if self.use_activation_checkpointing:
        # TODO - make this configurable
        prepared = checkpoint_wrapper(prepared, offload_to_cpu=False)

    with enable_wrap(wrapper_cls=fsdp_wrapper, **self.fsdp_config):
        return wrap(prepared)
def create_any_stage(
    self,
    width_in: int,
    width_out: int,
    stride: int,
    depth: int,
    group_width: int,
    bottleneck_multiplier: float,
    params: Union[RegNetParams, AnyNetParams],
    stage_index: int = 0,
    checkpoints: Union[List[int], None] = None,
):
    """Build an ``AnyStage`` whose blocks are grouped, then FSDP-wrapped.

    Blocks are created with ``self.create_block`` and collected into
    ``nn.Sequential`` groups. Each group is wrapped with ``fsdp_wrapper``
    (using ``self.fsdp_config``) and, when checkpoint indices were given,
    with ``checkpoint_wrapper`` first.

    Args:
        width_in: Input width of the first block (index 0); later blocks use
            ``width_out`` as their input width.
        width_out: Output width of every block.
        stride: Stride of the first block (index 0); later blocks use 1.
        depth: Number of blocks when no checkpoint indices are given.
        group_width: Forwarded to ``create_block``.
        bottleneck_multiplier: Forwarded to ``create_block``.
        params: Forwarded to ``create_block``.
        stage_index: Used only to name the sub-modules.
        checkpoints: Sorted list of exclusive end indices, one per block
            group. When omitted or empty, a single non-checkpointed group of
            ``depth`` blocks is built. Note that only blocks below the last
            index are created.

    Returns:
        The populated ``AnyStage``; its ``stage_depth`` is incremented by the
        ``depth`` attribute of every created block.

    Raises:
        AssertionError: If ``checkpoints`` is not sorted in ascending order.
    """
    # The previous default (``0``) crashed in ``sorted``/``len``. Treat any
    # falsy value (None, 0, empty list) as "no checkpoint boundaries".
    if not checkpoints:
        checkpoints = []

    assert sorted(checkpoints) == checkpoints, "Checkpoint indices should be sorted"

    with_checkpointing = len(checkpoints) > 0
    block_delimiters = checkpoints if with_checkpointing else [depth]

    any_stage = AnyStage()
    prev_depth = 0
    for block_group_index, next_depth in enumerate(block_delimiters):
        block_group = nn.Sequential()
        for i in range(prev_depth, next_depth):
            # Only the very first block of the stage changes width and stride.
            block = self.create_block(
                width_in=width_in if i == 0 else width_out,
                width_out=width_out,
                stride=stride if i == 0 else 1,
                params=params,
                group_width=group_width,
                bottleneck_multiplier=bottleneck_multiplier,
            )
            any_stage.stage_depth += block.depth
            block_group.add_module(f"block{stage_index}-{i}", block)
        prev_depth = next_depth

        # Checkpointing goes inside the FSDP wrapper.
        if with_checkpointing:
            block_group = checkpoint_wrapper(block_group)
        block_group = fsdp_wrapper(block_group, **self.fsdp_config)
        any_stage.add_module(f"block{stage_index}-part{block_group_index}", block_group)
    return any_stage
def create_any_stage(
    self,
    width_in: int,
    width_out: int,
    stride: int,
    depth: int,
    group_width: int,
    bottleneck_multiplier: float,
    params: Union[RegNetParams, AnyNetParams],
    group_delimiters: List[int],
    group_checkpoint: List[bool],
    stage_index: int = 0,
):
    """Build an ``AnyStage`` from explicitly delimited, FSDP-wrapped groups.

    Args:
        width_in: Input width of the block at index 0; other blocks take
            ``width_out`` as input width.
        width_out: Output width of every block.
        stride: Stride of the block at index 0; other blocks use 1.
        depth: Not read by this implementation; the blocks created are
            determined by ``group_delimiters``.
        group_width: Forwarded to ``create_block``.
        bottleneck_multiplier: Forwarded to ``create_block``.
        params: Forwarded to ``create_block``.
        group_delimiters: Exclusive end index of each block group; group
            ``g`` holds the blocks from the previous delimiter (0 for the
            first group) up to ``group_delimiters[g]``.
        group_checkpoint: One flag per group; when truthy the group is
            wrapped with ``checkpoint_wrapper`` before ``fsdp_wrapper``.
        stage_index: Used only to name the sub-modules.

    Returns:
        The populated ``AnyStage``; its ``stage_depth`` is incremented by the
        ``depth`` attribute of every created block.
    """
    assert len(group_delimiters) == len(group_checkpoint)

    any_stage = AnyStage()
    group_start = 0
    for group_index, group_end in enumerate(group_delimiters):
        block_group = nn.Sequential()

        for i in range(group_start, group_end):
            # The first block of the stage adapts width and stride; every
            # later block keeps the output width and a unit stride.
            if i == 0:
                block_width_in, block_stride = width_in, stride
            else:
                block_width_in, block_stride = width_out, 1

            block = self.create_block(
                width_in=block_width_in,
                width_out=width_out,
                stride=block_stride,
                params=params,
                group_width=group_width,
                bottleneck_multiplier=bottleneck_multiplier,
            )
            any_stage.stage_depth += block.depth
            block_group.add_module(f"block{stage_index}-{i}", block)

        group_start = group_end

        if group_checkpoint[group_index]:
            block_group = checkpoint_wrapper(block_group)
        block_group = fsdp_wrapper(block_group, **self.fsdp_config)

        any_stage.add_module(f"block{stage_index}-part{group_index}", block_group)

    return any_stage
def create_regnet_feature_blocks(factory: RegnetBlocksFactory, model_config):
    """Build the RegNet/AnyNet feature blocks described by ``model_config``.

    Args:
        factory: Block factory used for the stem and handed to each
            ``AnyStage``. When it is a ``RegnetFSDPBlocksFactory``, every
            stage is additionally wrapped with ``fsdp_wrapper`` (and with
            ``checkpoint_wrapper`` when activation checkpointing is enabled
            in ``model_config``).
        model_config: Configuration object; this function reads
            ``INPUT_TYPE``, ``TRUNK.REGNET``,
            ``ACTIVATION_CHECKPOINTING.USE_ACTIVATION_CHECKPOINTING`` and
            ``FSDP_CONFIG``.

    Returns:
        A tuple ``(feature_blocks, trunk_depth)`` where ``feature_blocks`` is
        an ``nn.ModuleDict`` with keys ``conv1``, ``res2`` ... ``resN``,
        ``avgpool`` and ``flatten``, and ``trunk_depth`` is the sum of the
        ``stage_depth`` of all stages.
    """
    assert model_config.INPUT_TYPE in ["rgb", "bgr"], "Input type not supported"

    trunk_config = model_config.TRUNK.REGNET
    has_name = "name" in trunk_config
    if has_name:
        assert (
            trunk_config["name"] == "anynet"
        ), "Please use AnyNetParams or specify RegNetParams dictionary"

    # Arguments specific to the chosen parameter class are evaluated first,
    # the shared ones afterwards (same order as the keyword arguments of the
    # constructor call).
    if has_name and trunk_config["name"] == "anynet":
        params_cls = AnyNetParams
        specific_kwargs = {
            "depths": trunk_config["depths"],
            "widths": trunk_config["widths"],
            "group_widths": trunk_config["group_widths"],
            "bottleneck_multipliers": trunk_config["bottleneck_multipliers"],
            "strides": trunk_config["strides"],
        }
    else:
        params_cls = RegNetParams
        specific_kwargs = {
            "depth": trunk_config["depth"],
            "w_0": trunk_config["w_0"],
            "w_a": trunk_config["w_a"],
            "w_m": trunk_config["w_m"],
            "group_width": trunk_config["group_width"],
            "bottleneck_multiplier": trunk_config.get("bottleneck_multiplier", 1.0),
        }

    shared_kwargs = {
        "stem_type": StemType[trunk_config.get("stem_type", "simple_stem_in").upper()],
        "stem_width": trunk_config.get("stem_width", 32),
        "block_type": BlockType[
            trunk_config.get("block_type", "res_bottleneck_block").upper()
        ],
        "activation": ActivationType[trunk_config.get("activation", "relu").upper()],
        "use_se": trunk_config.get("use_se", True),
        "se_ratio": trunk_config.get("se_ratio", 0.25),
        "bn_epsilon": trunk_config.get("bn_epsilon", 1e-05),
        "bn_momentum": trunk_config.get("bn_momentum", 0.1),
    }
    params = params_cls(**specific_kwargs, **shared_kwargs)

    # Ad hoc stem
    #
    # Important: do NOT retain modules in self.stem or self.trunk_output. It may
    # seem to be harmless, but it appears that autograd will result in computing
    # grads in different order. Different ordering can cause deterministic OOM,
    # even when the peak memory otherwise is only 24GB out of 32GB.
    #
    # When debugging this, it is not enough to just dump the total module
    # params. You need to diff the module string representations.
    stem = factory.create_stem(params)

    # Instantiate all the AnyNet blocks in the trunk
    wrap_for_fsdp = isinstance(factory, RegnetFSDPBlocksFactory)
    current_width = params.stem_width
    trunk_depth = 0
    blocks = []
    for i, stage_params in enumerate(params.get_expanded_params()):
        width_out, stride, depth, group_width, bottleneck_multiplier = stage_params
        new_stage = AnyStage(
            factory=factory,
            width_in=current_width,
            width_out=width_out,
            stride=stride,
            depth=depth,
            group_width=group_width,
            bottleneck_multiplier=bottleneck_multiplier,
            params=params,
            stage_index=i + 1,
        )

        if wrap_for_fsdp:
            if model_config.ACTIVATION_CHECKPOINTING.USE_ACTIVATION_CHECKPOINTING:
                logging.info("Using activation checkpointing")
                new_stage = checkpoint_wrapper(new_stage, offload_to_cpu=False)
            new_stage = fsdp_wrapper(module=new_stage, **model_config.FSDP_CONFIG)

        blocks.append((f"block{i + 1}", new_stage))
        # Read stage_depth from the (possibly wrapped) stage that was stored.
        trunk_depth += new_stage.stage_depth
        current_width = width_out

    trunk_output = nn.Sequential(OrderedDict(blocks))

    ################################################################################

    # Now map the models to the structure we want to expose for SSL tasks
    # The upstream RegNet model is made of :
    # - `stem`
    # - n x blocks in trunk_output, named `block1, block2, ..`

    # We're only interested in the stem and successive blocks
    # everything else is not picked up on purpose
    feature_blocks: List[Tuple[str, nn.Module]] = [("conv1", stem)]
    # "conv1" occupies the first slot, so the stages are numbered from res2.
    for block_index, (k, stage) in enumerate(trunk_output.named_children(), start=2):
        assert k.startswith("block"), f"Unexpected layer name {k}"
        feature_blocks.append((f"res{block_index}", stage))

    feature_blocks.extend(
        [
            ("avgpool", nn.AdaptiveAvgPool2d((1, 1))),
            ("flatten", Flatten(1)),
        ]
    )

    return nn.ModuleDict(feature_blocks), trunk_depth