def __init__(
    self,
    spatial_dims: int,
    in_channels: int,
    out_channels: int,
    kernel_size: Union[Sequence[int], int],
    stride: Union[Sequence[int], int],
    num_groups: int,
    norm_name: str,
    is_prunable: bool = False,
):
    """Construct the block, then (re)assign its conv/act/norm submodules.

    The parent ``__init__`` runs first with the shared geometry arguments.
    Afterwards ``conv1``/``conv2``/``conv3`` are (re)created through
    ``get_conv_layer`` so that ``num_groups`` and ``is_prunable`` are honoured.
    ``lrelu`` and the three norm layers are (re)created as well.

    Args:
        spatial_dims: forwarded to every layer factory and to the parent.
        in_channels: input channel count of ``conv1`` and ``conv3``.
        out_channels: output channel count of all three convolutions and the
            channel count of the three norm layers.
        kernel_size: kernel size of ``conv1`` and ``conv2`` (``conv3`` uses 1).
        stride: stride of ``conv1`` and ``conv3`` (``conv2`` uses 1).
        num_groups: forwarded to ``get_conv_layer`` for every convolution.
        norm_name: normalization spec for ``get_norm_layer`` and the parent.
        is_prunable: forwarded to ``get_conv_layer`` for every convolution.
    """
    super(UnetResBlockEx, self).__init__(
        spatial_dims=spatial_dims,
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        norm_name=norm_name,
    )

    # Keyword options common to every convolution built below.
    shared_conv_opts = dict(conv_only=True, num_groups=num_groups, is_prunable=is_prunable)

    # Assignment order is kept as-is: nn.Module registers submodules in
    # assignment order, which affects parameter / state_dict ordering.
    self.conv1 = get_conv_layer(
        spatial_dims, in_channels, out_channels,
        kernel_size=kernel_size, stride=stride, **shared_conv_opts,
    )
    self.conv2 = get_conv_layer(
        spatial_dims, out_channels, out_channels,
        kernel_size=kernel_size, stride=1, **shared_conv_opts,
    )
    # 1x1 conv with the same stride as conv1; presumably the residual
    # projection used when ``downsample`` is True (forward() not visible here).
    self.conv3 = get_conv_layer(
        spatial_dims, in_channels, out_channels,
        kernel_size=1, stride=stride, **shared_conv_opts,
    )

    leaky_cfg = {"inplace": True, "negative_slope": 0.01}
    self.lrelu = get_act_layer(("leakyrelu", leaky_cfg))

    self.norm1 = get_norm_layer(spatial_dims, out_channels, norm_name)
    self.norm2 = get_norm_layer(spatial_dims, out_channels, norm_name)
    self.norm3 = get_norm_layer(spatial_dims, out_channels, norm_name)

    # downsample is set when the channel count changes or any stride
    # component differs from 1 (stride may be a scalar or a sequence).
    self.downsample = in_channels != out_channels
    if not np.all(np.atleast_1d(stride) == 1):
        self.downsample = True
def __init__(
    self,
    hidden_size: int,
    mlp_dim: int,
    dropout_rate: float = 0.0,
    act: Union[Tuple, str] = "GELU",
    dropout_mode="vit",
) -> None:
    """
    Args:
        hidden_size: dimension of hidden layer.
        mlp_dim: dimension of feedforward layer. If 0, `hidden_size` will be used.
        dropout_rate: fraction of the input units to drop.
        act: activation type and arguments. Defaults to GELU.
        dropout_mode: dropout mode, can be "vit" or "swin".
            "vit" mode uses two dropout instances as implemented in
            https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87
            "swin" corresponds to one instance as implemented in
            https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23

    Raises:
        ValueError: if ``dropout_rate`` is outside [0, 1] (NaN included), or if the
            resolved dropout mode is neither "vit" nor "swin".
    """
    super().__init__()

    # Written as a negated chained comparison so that NaN is also rejected.
    if not (0 <= dropout_rate <= 1):
        raise ValueError("dropout_rate should be between 0 and 1.")

    # Any falsy mlp_dim (e.g. 0) falls back to hidden_size.
    mlp_dim = mlp_dim or hidden_size

    self.linear1 = nn.Linear(hidden_size, mlp_dim)
    self.linear2 = nn.Linear(mlp_dim, hidden_size)
    self.fn = get_act_layer(act)
    self.drop1 = nn.Dropout(dropout_rate)

    mode = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE)
    if mode == "swin":
        # Swin style: both dropout positions share the single drop1 module.
        self.drop2 = self.drop1
    elif mode == "vit":
        # ViT style: a second, independent dropout module.
        self.drop2 = nn.Dropout(dropout_rate)
    else:
        raise ValueError(f"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}")
def test_act_layer(self, input_param, expected):
    """Build a layer via ``get_act_layer(**input_param)`` and compare its text form to ``expected``."""
    act_layer = get_act_layer(**input_param)
    # Render through an f-string, exactly as the comparison target was produced.
    rendered = f"{act_layer}"
    self.assertEqual(rendered, expected)