Exemple #1
0
def FSDPLinearEvalMLP(
    model_config: AttrDict,
    in_channels: int,
    dims: List[int],
    use_bn: bool = False,
    use_relu: bool = False,
):
    """
    Build a LinearEvalMLP head and return it wrapped for FSDP.

    The head is constructed from the given arguments, passed through
    ``fsdp_auto_wrap_bn`` and finally handed to ``fsdp_wrapper`` with
    ``model_config.FSDP_CONFIG`` expanded as keyword arguments.

    Args:
        model_config: configuration forwarded to ``LinearEvalMLP``; its
            ``FSDP_CONFIG`` entry supplies the wrapper keyword arguments.
        in_channels: forwarded unchanged to ``LinearEvalMLP``.
        dims: forwarded unchanged to ``LinearEvalMLP``.
        use_bn: forwarded unchanged to ``LinearEvalMLP``.
        use_relu: forwarded unchanged to ``LinearEvalMLP``.

    Returns:
        Whatever ``fsdp_wrapper`` returns for the prepared head.
    """
    eval_head = fsdp_auto_wrap_bn(
        LinearEvalMLP(model_config, in_channels, dims, use_bn, use_relu)
    )
    return fsdp_wrapper(eval_head, **model_config.FSDP_CONFIG)
Exemple #2
0
 def __init__(
     self,
     model_config: AttrDict,
     in_channels: int,
     dims: List[int],
     use_bn: bool = False,
     use_relu: bool = False,
 ):
     """
     Create the FSDP-wrapped linear evaluation head and keep it as ``self.mlp``.

     Args:
         model_config: configuration forwarded to ``LinearEvalMLP``; its
             ``FSDP_CONFIG`` entry supplies the wrapper keyword arguments.
         in_channels: forwarded unchanged to ``LinearEvalMLP``.
         dims: forwarded unchanged to ``LinearEvalMLP``.
         use_bn: forwarded unchanged to ``LinearEvalMLP``.
         use_relu: forwarded unchanged to ``LinearEvalMLP``.
     """
     super().__init__()
     eval_head = fsdp_auto_wrap_bn(
         LinearEvalMLP(model_config, in_channels, dims, use_bn, use_relu)
     )
     # The wrapper receives the FSDP settings as keyword arguments.
     self.mlp = fsdp_wrapper(eval_head, **model_config.FSDP_CONFIG)
Exemple #3
0
 def create_block(
     self,
     width_in: int,
     width_out: int,
     stride: int,
     params: Union[RegNetParams, AnyNetParams],
     bottleneck_multiplier: float,
     group_width: int = 1,
 ):
     """
     Create a block through the parent class and wrap it for FSDP.

     The parent-built block is passed through ``fsdp_auto_wrap_bn``; when
     ``self.fsdp_config.AUTO_WRAP_THRESHOLD`` is positive it additionally
     goes through ``auto_wrap_big_layers``. The result is then wrapped by
     ``fsdp_wrapper`` using ``self.fsdp_config`` as keyword arguments.

     All positional arguments are forwarded unchanged to the parent's
     ``create_block``.

     Returns:
         Whatever ``fsdp_wrapper`` returns for the prepared block.
     """
     wrapped = fsdp_auto_wrap_bn(
         super().create_block(
             width_in,
             width_out,
             stride,
             params,
             bottleneck_multiplier,
             group_width,
         )
     )
     auto_wrap_enabled = self.fsdp_config.AUTO_WRAP_THRESHOLD > 0
     if auto_wrap_enabled:
         wrapped = auto_wrap_big_layers(wrapped, self.fsdp_config)
     return fsdp_wrapper(module=wrapped, **self.fsdp_config)
def SwavPrototypesHeadFSDP(
    model_config: AttrDict,
    dims: List[int],
    use_bn: bool,
    num_clusters: int,
    use_bias: bool = True,
    return_embeddings: bool = True,
    skip_last_bn: bool = True,
    normalize_feats: bool = True,
):
    """
    SwAV head with dedicated FSDP wrapping that keeps the prototypes in
    full precision.

    This matters for convergence: the prototypes layer is "normalized" in
    the update hook, so its weights must stay in full precision. Its output
    feeds the loss and is used for clustering, so it must be in full
    precision as well.

    Each ``prototypes<j>`` sub-module of the head is wrapped separately with
    a copy of ``model_config.FSDP_CONFIG`` in which ``flatten_parameters``,
    ``mixed_precision`` and ``fp32_reduce_scatter`` are disabled and
    ``compute_dtype`` is ``torch.float32``. The head as a whole is then
    wrapped with the unmodified ``model_config.FSDP_CONFIG``.

    All arguments are forwarded by keyword to ``SwAVPrototypesHead``.

    Returns:
        Whatever ``fsdp_wrapper`` returns for the prepared head.
    """
    swav_head = fsdp_auto_wrap_bn(
        SwAVPrototypesHead(
            model_config=model_config,
            dims=dims,
            use_bn=use_bn,
            num_clusters=num_clusters,
            use_bias=use_bias,
            return_embeddings=return_embeddings,
            skip_last_bn=skip_last_bn,
            normalize_feats=normalize_feats,
        )
    )

    # Full-precision overrides, applied on a copy so that the configuration
    # used for the outer wrap is the one stored on model_config.
    fp32_overrides = {
        "flatten_parameters": False,
        "mixed_precision": False,
        "fp32_reduce_scatter": False,
        "compute_dtype": torch.float32,
    }
    fp32_fsdp_config = model_config.FSDP_CONFIG.copy()
    for option, value in fp32_overrides.items():
        fp32_fsdp_config[option] = value

    # Replace every prototypes layer by its full-precision wrapped version.
    for head_idx in range(swav_head.nmb_heads):
        attr_name = "prototypes" + str(head_idx)
        wrapped_prototypes = fsdp_wrapper(
            getattr(swav_head, attr_name), **fp32_fsdp_config
        )
        setattr(swav_head, attr_name, wrapped_prototypes)

    return fsdp_wrapper(swav_head, **model_config.FSDP_CONFIG)
Exemple #5
0
def MLP_FSDP(
    model_config: AttrDict,
    dims: List[int],
    use_bn: bool = False,
    use_relu: bool = False,
    use_dropout: bool = False,
    use_bias: bool = True,
    skip_last_layer_relu_bn: bool = True,
):
    """
    Build an MLP head and return it wrapped for FSDP.

    The head is constructed from the given arguments, passed through
    ``fsdp_auto_wrap_bn`` and finally handed to ``fsdp_wrapper`` with
    ``model_config.FSDP_CONFIG`` expanded as keyword arguments.

    Args:
        model_config: configuration forwarded to ``MLP``; its
            ``FSDP_CONFIG`` entry supplies the wrapper keyword arguments.
        dims: forwarded unchanged to ``MLP``.
        use_bn: forwarded unchanged to ``MLP``.
        use_relu: forwarded unchanged to ``MLP``.
        use_dropout: forwarded unchanged to ``MLP``.
        use_bias: forwarded unchanged to ``MLP``.
        skip_last_layer_relu_bn: forwarded unchanged to ``MLP``.

    Returns:
        Whatever ``fsdp_wrapper`` returns for the prepared head.
    """
    mlp_args = (
        model_config,
        dims,
        use_bn,
        use_relu,
        use_dropout,
        use_bias,
        skip_last_layer_relu_bn,
    )
    head = fsdp_auto_wrap_bn(MLP(*mlp_args))
    return fsdp_wrapper(head, **model_config.FSDP_CONFIG)
Exemple #6
0
 def __init__(
     self,
     model_config: AttrDict,
     dims: List[int],
     use_bn: bool = False,
     use_relu: bool = False,
     use_dropout: bool = False,
     use_bias: bool = True,
     skip_last_layer_relu_bn: bool = True,
 ):
     """
     Create the FSDP-wrapped MLP head and keep it as ``self.mlp``.

     Args:
         model_config: configuration forwarded to ``MLP``; its
             ``FSDP_CONFIG`` entry supplies the wrapper keyword arguments.
         dims: forwarded unchanged to ``MLP``.
         use_bn: forwarded unchanged to ``MLP``.
         use_relu: forwarded unchanged to ``MLP``.
         use_dropout: forwarded unchanged to ``MLP``.
         use_bias: forwarded unchanged to ``MLP``.
         skip_last_layer_relu_bn: forwarded unchanged to ``MLP``.
     """
     super().__init__()
     mlp_args = (
         model_config,
         dims,
         use_bn,
         use_relu,
         use_dropout,
         use_bias,
         skip_last_layer_relu_bn,
     )
     head = fsdp_auto_wrap_bn(MLP(*mlp_args))
     # The wrapper receives the FSDP settings as keyword arguments.
     self.mlp = fsdp_wrapper(head, **model_config.FSDP_CONFIG)
Exemple #7
0
 def create_stem(self, params: Union[RegNetParams, AnyNetParams]):
     """
     Create the stem through the parent class and pass it through
     ``fsdp_auto_wrap_bn`` before returning it.

     Args:
         params: forwarded unchanged to the parent's ``create_stem``.

     Returns:
         The result of ``fsdp_auto_wrap_bn`` applied to the parent-built stem.
     """
     return fsdp_auto_wrap_bn(super().create_stem(params))