Example #1
0
 def __init__(self,
              spatial_dims: int,
              in_channels: int,
              out_channels: int,
              kernel_size: Union[Sequence[int], int],
              stride: Union[Sequence[int], int],
              num_groups: int,
              norm_name: str,
              is_prunable: bool = False):
     """Build a residual UNet block whose convolutions take extra grouping/pruning options.

     Args:
         spatial_dims: number of spatial dimensions, forwarded to every layer factory.
         in_channels: number of channels entering the block.
         out_channels: number of channels produced by every convolution.
         kernel_size: kernel size of conv1 and conv2 (conv3 always uses 1).
         stride: stride of conv1 and conv3; conv2 always uses stride 1.
         num_groups: forwarded to ``get_conv_layer`` for conv1, conv2 and conv3.
         norm_name: normalization type passed to ``get_norm_layer``.
         is_prunable: forwarded to ``get_conv_layer`` for conv1, conv2 and conv3.
     """
     super().__init__(
         spatial_dims=spatial_dims,
         in_channels=in_channels,
         out_channels=out_channels,
         kernel_size=kernel_size,
         stride=stride,
         norm_name=norm_name,
     )
     # NOTE(review): if the parent constructor already creates conv1..norm3,
     # those layers are built and then replaced below — confirm this is intended.

     # Options shared by all three convolutions.
     conv_options = {
         "conv_only": True,
         "num_groups": num_groups,
         "is_prunable": is_prunable,
     }
     # Construction order is kept as conv1 -> conv2 -> conv3 so weight
     # initialization consumes the RNG in the same sequence.
     self.conv1 = get_conv_layer(spatial_dims, in_channels, out_channels,
                                 kernel_size=kernel_size, stride=stride,
                                 **conv_options)
     self.conv2 = get_conv_layer(spatial_dims, out_channels, out_channels,
                                 kernel_size=kernel_size, stride=1,
                                 **conv_options)
     # 1x1 convolution from in_channels to out_channels with conv1's stride.
     self.conv3 = get_conv_layer(spatial_dims, in_channels, out_channels,
                                 kernel_size=1, stride=stride,
                                 **conv_options)
     self.lrelu = get_act_layer(("leakyrelu", {"inplace": True, "negative_slope": 0.01}))
     # NOTE(review): num_groups is forwarded to the convolutions only, not to
     # get_norm_layer — confirm this is intended for group normalization.
     self.norm1, self.norm2, self.norm3 = [
         get_norm_layer(spatial_dims, out_channels, norm_name) for _ in range(3)
     ]
     # Flag is True when the channel count changes or any stride entry differs
     # from 1; presumably read by forward() to decide whether the residual goes
     # through conv3/norm3 — confirm.
     unit_stride = np.all(np.atleast_1d(stride) == 1)
     self.downsample = in_channels != out_channels if unit_stride else True
Example #2
0
    def __init__(
        self,
        hidden_size: int,
        mlp_dim: int,
        dropout_rate: float = 0.0,
        act: Union[Tuple, str] = "GELU",
        dropout_mode: str = "vit",
    ) -> None:
        """
        Construct the layers of a two-layer MLP block.

        Args:
            hidden_size: dimension of hidden layer.
            mlp_dim: dimension of feedforward layer. If 0 (or any other falsy value such
                as None), `hidden_size` will be used.
            dropout_rate: fraction of the input units to drop; must lie in [0, 1].
            act: activation type and arguments. Defaults to GELU.
            dropout_mode: dropout mode, can be "vit" or "swin".
                "vit" mode uses two dropout instances as implemented in
                https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87
                "swin" corresponds to one instance as implemented in
                https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23

        Raises:
            ValueError: when `dropout_rate` is outside [0, 1], or when `dropout_mode`
                resolves to a mode other than "vit" or "swin". Values missing from
                `SUPPORTED_DROPOUT_MODE` are rejected earlier by `look_up_option`.
        """

        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        # Resolve and validate the dropout mode before building any layer, so an
        # invalid argument fails fast without allocating/initializing weights.
        dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE)
        if dropout_opt not in ("vit", "swin"):
            # Defensive guard: SUPPORTED_DROPOUT_MODE is defined elsewhere and could
            # list modes this block does not implement.
            raise ValueError(f"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}")

        mlp_dim = mlp_dim or hidden_size
        self.linear1 = nn.Linear(hidden_size, mlp_dim)
        self.linear2 = nn.Linear(mlp_dim, hidden_size)
        self.fn = get_act_layer(act)
        self.drop1 = nn.Dropout(dropout_rate)
        # "vit": a second, independent Dropout module; "swin": reuse drop1.
        self.drop2 = nn.Dropout(dropout_rate) if dropout_opt == "vit" else self.drop1
Example #3
0
 def test_act_layer(self, input_param, expected):
     """Build a layer via ``get_act_layer(**input_param)`` and compare its formatted text to ``expected``."""
     self.assertEqual(f"{get_act_layer(**input_param)}", expected)