def FSDPLinearEvalMLP(
    model_config: AttrDict,
    in_channels: int,
    dims: List[int],
    use_bn: bool = False,
    use_relu: bool = False,
):
    mlp = LinearEvalMLP(model_config, in_channels, dims, use_bn, use_relu)
    mlp = fsdp_auto_wrap_bn(mlp)
    return fsdp_wrapper(mlp, **model_config.FSDP_CONFIG)
def __init__(
    self,
    model_config: AttrDict,
    in_channels: int,
    dims: List[int],
    use_bn: bool = False,
    use_relu: bool = False,
):
    super().__init__()
    mlp = LinearEvalMLP(model_config, in_channels, dims, use_bn, use_relu)
    mlp = fsdp_auto_wrap_bn(mlp)
    self.mlp = fsdp_wrapper(mlp, **model_config.FSDP_CONFIG)
def create_block(
    self,
    width_in: int,
    width_out: int,
    stride: int,
    params: Union[RegNetParams, AnyNetParams],
    bottleneck_multiplier: float,
    group_width: int = 1,
):
    block = super().create_block(
        width_in, width_out, stride, params, bottleneck_multiplier, group_width
    )
    # Wrap the BN layers of the block in their own FSDP units first.
    block = fsdp_auto_wrap_bn(block)
    # Optionally wrap large sub-layers in their own FSDP units when a size
    # threshold is configured, then wrap the block itself.
    if self.fsdp_config.AUTO_WRAP_THRESHOLD > 0:
        block = auto_wrap_big_layers(block, self.fsdp_config)
    block = fsdp_wrapper(module=block, **self.fsdp_config)
    return block
def SwavPrototypesHeadFSDP(
    model_config: AttrDict,
    dims: List[int],
    use_bn: bool,
    num_clusters: int,
    use_bias: bool = True,
    return_embeddings: bool = True,
    skip_last_bn: bool = True,
    normalize_feats: bool = True,
):
    """
    SwAV head specific FSDP wrapping: we keep full precision for the prototypes.

    This is important for convergence: since we "normalize" this layer in the
    update hook, we need to keep its weights in full precision. Its output goes
    into the loss and is used for clustering, so we need that in full precision
    as well.
    """
    head = SwAVPrototypesHead(
        model_config=model_config,
        dims=dims,
        use_bn=use_bn,
        num_clusters=num_clusters,
        use_bias=use_bias,
        return_embeddings=return_embeddings,
        skip_last_bn=skip_last_bn,
        normalize_feats=normalize_feats,
    )
    head = fsdp_auto_wrap_bn(head)

    # Wrap each prototypes module with a full precision, non-flattened FSDP config.
    prototypes_fp32_fsdp_config = model_config.FSDP_CONFIG.copy()
    prototypes_fp32_fsdp_config["flatten_parameters"] = False
    prototypes_fp32_fsdp_config["mixed_precision"] = False
    prototypes_fp32_fsdp_config["fp32_reduce_scatter"] = False
    prototypes_fp32_fsdp_config["compute_dtype"] = torch.float32
    for j in range(head.nmb_heads):
        module = getattr(head, "prototypes" + str(j))
        module = fsdp_wrapper(module, **prototypes_fp32_fsdp_config)
        setattr(head, "prototypes" + str(j), module)

    # The rest of the head uses the standard FSDP config.
    return fsdp_wrapper(head, **model_config.FSDP_CONFIG)
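# Illustrative sketch, not part of the original source: the wrapping above keeps the
# SwAV prototypes in full precision inside an otherwise mixed precision FSDP model.
# Assuming fsdp_wrapper ultimately builds FairScale's FullyShardedDataParallel, the
# same nesting pattern can be written directly against FairScale as below. The helper
# name `_wrap_head_with_fp32_prototypes`, and the assumption that `head` is an
# nn.Sequential whose last child plays the role of the prototypes layer, are
# hypothetical; a distributed process group must be initialized before calling it.
def _wrap_head_with_fp32_prototypes(head):
    import torch
    from fairscale.nn import FullyShardedDataParallel as FSDP

    # Inner wrap: keep the prototypes-like layer unflattened and in fp32.
    head[-1] = FSDP(
        head[-1],
        flatten_parameters=False,
        mixed_precision=False,
        fp32_reduce_scatter=False,
        compute_dtype=torch.float32,
    )
    # Outer wrap: the rest of the head runs with the standard mixed precision settings.
    return FSDP(head, mixed_precision=True, flatten_parameters=True)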
def MLP_FSDP(
    model_config: AttrDict,
    dims: List[int],
    use_bn: bool = False,
    use_relu: bool = False,
    use_dropout: bool = False,
    use_bias: bool = True,
    skip_last_layer_relu_bn: bool = True,
):
    mlp = MLP(
        model_config,
        dims,
        use_bn,
        use_relu,
        use_dropout,
        use_bias,
        skip_last_layer_relu_bn,
    )
    mlp = fsdp_auto_wrap_bn(mlp)
    return fsdp_wrapper(mlp, **model_config.FSDP_CONFIG)
def __init__(
    self,
    model_config: AttrDict,
    dims: List[int],
    use_bn: bool = False,
    use_relu: bool = False,
    use_dropout: bool = False,
    use_bias: bool = True,
    skip_last_layer_relu_bn: bool = True,
):
    super().__init__()
    mlp = MLP(
        model_config,
        dims,
        use_bn,
        use_relu,
        use_dropout,
        use_bias,
        skip_last_layer_relu_bn,
    )
    mlp = fsdp_auto_wrap_bn(mlp)
    self.mlp = fsdp_wrapper(mlp, **model_config.FSDP_CONFIG)
def create_stem(self, params: Union[RegNetParams, AnyNetParams]):
    stem = super().create_stem(params)
    stem = fsdp_auto_wrap_bn(stem)
    return stem