Code Example #1
File: linear_eval_mlp.py Project: zlapp/vissl
def FSDPLinearEvalMLP(
    model_config: AttrDict,
    in_channels: int,
    dims: List[int],
    use_bn: bool = False,
    use_relu: bool = False,
):
    mlp = LinearEvalMLP(model_config, in_channels, dims, use_bn, use_relu)
    mlp = fsdp_auto_wrap_bn(mlp)
    return fsdp_wrapper(mlp, **model_config.FSDP_CONFIG)
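These snippets share one three-step pattern: build the plain module, call fsdp_auto_wrap_bn so every BatchNorm layer is wrapped in its own FSDP instance and kept out of the flattened parameter shard of the surrounding module, then shard the whole module with fsdp_wrapper, passing the keyword arguments stored under model_config.FSDP_CONFIG. A minimal usage sketch, assuming a fairscale-style FSDP config and an already-initialized torch.distributed process group; the config values are hypothetical and elide the fields LinearEvalMLP itself reads:

# hypothetical config; real runs populate this from the vissl YAML config
model_config = AttrDict(
    {"FSDP_CONFIG": {"mixed_precision": True, "flatten_parameters": True}}
)
# illustrative sizes: a 2048-d trunk feature evaluated on 1000 classes
head = FSDPLinearEvalMLP(model_config, in_channels=2048, dims=[2048, 1000])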
Code Example #2
File: linear_eval_mlp.py Project: worosom/vissl
 def __init__(
     self,
     model_config: AttrDict,
     in_channels: int,
     dims: List[int],
     use_bn: bool = False,
     use_relu: bool = False,
 ):
     super().__init__()
     mlp = LinearEvalMLP(model_config, in_channels, dims, use_bn, use_relu)
     mlp = fsdp_auto_wrap_bn(mlp)
     self.mlp = fsdp_wrapper(mlp, **model_config.FSDP_CONFIG)
Code Example #3
File: regnet_fsdp.py Project: QuentinDuval/vissl
 def create_block(
     self,
     width_in: int,
     width_out: int,
     stride: int,
     params: Union[RegNetParams, AnyNetParams],
     bottleneck_multiplier: float,
     group_width: int = 1,
 ):
     block = super().create_block(width_in, width_out, stride, params,
                                  bottleneck_multiplier, group_width)
     block = fsdp_auto_wrap_bn(block)
     if self.fsdp_config.AUTO_WRAP_THRESHOLD > 0:
         block = auto_wrap_big_layers(block, self.fsdp_config)
     block = fsdp_wrapper(module=block, **self.fsdp_config)
     return block
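The extra branch here is a size-based pre-wrapping step: when AUTO_WRAP_THRESHOLD is positive, sub-layers above the threshold receive their own FSDP instance before the block itself is wrapped, so their shards can be gathered and freed independently of the rest of the block. A rough sketch of the idea, not vissl's actual auto_wrap_big_layers, assuming the threshold is compared against a layer's parameter count:

import torch.nn as nn

def wrap_big_layers_sketch(module: nn.Module, threshold: int, wrap):
    """Recursively give children whose parameter count exceeds `threshold`
    their own wrapper (e.g. FSDP); a sketch, not vissl's implementation."""
    for name, child in module.named_children():
        wrap_big_layers_sketch(child, threshold, wrap)
        if sum(p.numel() for p in child.parameters()) > threshold:
            setattr(module, name, wrap(child))
    return module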
Code Example #4
def SwavPrototypesHeadFSDP(
    model_config: AttrDict,
    dims: List[int],
    use_bn: bool,
    num_clusters: int,
    use_bias: bool = True,
    return_embeddings: bool = True,
    skip_last_bn: bool = True,
    normalize_feats: bool = True,
):
    """
    SwAV head specific FSDP wrapping: we keep the full precision for the prototypes

    This is important for convergence:
    Since we "normalize" this layer in the update hook, we need to keep its
    weights in full precision. It is output is going into the loss and used
    for clustering, so we need to have that in full precision as well.
    """

    head = SwAVPrototypesHead(
        model_config=model_config,
        dims=dims,
        use_bn=use_bn,
        num_clusters=num_clusters,
        use_bias=use_bias,
        return_embeddings=return_embeddings,
        skip_last_bn=skip_last_bn,
        normalize_feats=normalize_feats,
    )
    head = fsdp_auto_wrap_bn(head)

    prototypes_fp32_fsdp_config = model_config.FSDP_CONFIG.copy()
    prototypes_fp32_fsdp_config["flatten_parameters"] = False
    prototypes_fp32_fsdp_config["mixed_precision"] = False
    prototypes_fp32_fsdp_config["fp32_reduce_scatter"] = False
    prototypes_fp32_fsdp_config["compute_dtype"] = torch.float32
    for j in range(head.nmb_heads):
        module = getattr(head, "prototypes" + str(j))
        module = fsdp_wrapper(module, **prototypes_fp32_fsdp_config)
        setattr(head, "prototypes" + str(j), module)

    return fsdp_wrapper(head, **model_config.FSDP_CONFIG)
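The per-prototype override works because nested FSDP instances manage their own parameters: when the head is wrapped at the end, children that are already FSDP modules keep their full-precision settings instead of being flattened into the outer mixed-precision shard. The same nesting can be written directly against fairscale; a self-contained sketch with made-up layer sizes, requiring an initialized torch.distributed process group:

import torch
import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP

# keep the sensitive layer in full precision ...
prototypes = FSDP(
    nn.Linear(128, 3000, bias=False),
    flatten_parameters=False,
    mixed_precision=False,
    fp32_reduce_scatter=False,
    compute_dtype=torch.float32,
)
# ... while the enclosing model still trains in mixed precision
model = FSDP(nn.Sequential(nn.Linear(128, 128), prototypes), mixed_precision=True)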
Code Example #5
def MLP_FSDP(
    model_config: AttrDict,
    dims: List[int],
    use_bn: bool = False,
    use_relu: bool = False,
    use_dropout: bool = False,
    use_bias: bool = True,
    skip_last_layer_relu_bn: bool = True,
):
    mlp = MLP(
        model_config,
        dims,
        use_bn,
        use_relu,
        use_dropout,
        use_bias,
        skip_last_layer_relu_bn,
    )
    mlp = fsdp_auto_wrap_bn(mlp)
    return fsdp_wrapper(mlp, **model_config.FSDP_CONFIG)
Code Example #6
File: mlp.py Project: worosom/vissl
 def __init__(
     self,
     model_config: AttrDict,
     dims: List[int],
     use_bn: bool = False,
     use_relu: bool = False,
     use_dropout: bool = False,
     use_bias: bool = True,
     skip_last_layer_relu_bn: bool = True,
 ):
     super().__init__()
     mlp = MLP(
         model_config,
         dims,
         use_bn,
         use_relu,
         use_dropout,
         use_bias,
         skip_last_layer_relu_bn,
     )
     mlp = fsdp_auto_wrap_bn(mlp)
     self.mlp = fsdp_wrapper(mlp, **model_config.FSDP_CONFIG)
Code Example #7
File: regnet_fsdp.py Project: QuentinDuval/vissl
 def create_stem(self, params: Union[RegNetParams, AnyNetParams]):
     stem = super().create_stem(params)
     stem = fsdp_auto_wrap_bn(stem)
     return stem
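Note that create_stem, unlike create_block above, stops after fsdp_auto_wrap_bn and does not wrap the stem in its own FSDP instance; presumably the stem's remaining parameters are sharded later, when the enclosing trunk is wrapped at the root.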