Example #1
 def __init__(self, input_size, feature_size):
     super().__init__()
     self.input_layer = nn.Linear(input_size, feature_size, bias=False)
     self.hidden_layer_1 = nn.Linear(16, 16, bias=False)
     self.hidden_layer_2 = nn.Linear(128, 128, bias=False)
     #self.hidden_layer_3 = nn.Linear(32, 128, bias=False)
     self.output_layer = nn.Linear(16, feature_size, bias=False)
     self.prelu = nn.PReLU(1, 0.25)
     self.silu = nn.SiLU()
     self.elu = nn.ELU()
     self.tanshrink = nn.Tanhshrink()
Example #2
def get_activation_fn(activation_type, inplace=True):

    if activation_type == 'relu':
        relu = nn.ReLU(inplace=inplace)
    elif activation_type == "swish":
        relu = nn.SiLU(inplace=inplace)
    elif activation_type == "hardswish":
        relu = nn.Hardswish(inplace=inplace)
    else:
        raise ValueError(f'Unknown activation type: {activation_type}')
    return relu
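A minimal usage sketch for the factory above (the torch import and the function itself are assumed to be in scope):

import torch

act = get_activation_fn('swish', inplace=False)  # returns nn.SiLU(inplace=False)
x = torch.randn(4)
y = act(x)                                       # element-wise x * sigmoid(x)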
Example #3
 def __init__(
         self,
         input_c: int,  # block input channel
         expand_c: int,  # block expand channel
         squeeze_factor: int = 4):
     super(SqueezeExcitation, self).__init__()
     squeeze_c = input_c // squeeze_factor
     self.fc1 = nn.Conv2d(expand_c, squeeze_c, 1)
     self.ac1 = nn.SiLU()  # alias Swish
     self.fc2 = nn.Conv2d(squeeze_c, expand_c, 1)
     self.ac2 = nn.Sigmoid()
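Only the layers are defined above; in this EfficientNet-style squeeze-and-excitation pattern the forward pass typically pools spatially, runs the two 1x1 convs, and rescales the input. A hedged sketch, assuming it is added as a method of the class above:

import torch.nn.functional as F

def forward(self, x):
    # x: (N, expand_c, H, W) -- the expanded feature map of the MBConv block
    scale = F.adaptive_avg_pool2d(x, output_size=(1, 1))  # squeeze: (N, expand_c, 1, 1)
    scale = self.ac1(self.fc1(scale))                     # reduce channels, SiLU
    scale = self.ac2(self.fc2(scale))                     # restore channels, sigmoid gate in [0, 1]
    return scale * x                                      # excite: channel-wise rescaling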
Example #4
def get_activation_func(activation_type):
    if activation_type == 'relu':
        return nn.ReLU()
    elif activation_type == 'sigmoid':
        return nn.Sigmoid()
    elif activation_type == 'swish':
        return nn.SiLU()
    elif activation_type == 'mish':
        return Mish()
    else:
        raise ValueError('activation has an unacceptable value')
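Mish is referenced but not defined in this snippet; a minimal stand-in consistent with the usual definition mish(x) = x * tanh(softplus(x)) could look like the sketch below (recent PyTorch versions also ship nn.Mish directly):

import torch
import torch.nn as nn
import torch.nn.functional as F

class Mish(nn.Module):
    """Mish activation: x * tanh(softplus(x))."""
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))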
Example #5
    def __init__(
        self,
        image_size,
        n_class,
        depths,
        dims,
        dim_head,
        n_heads,
        dim_ffs,
        window_size,
        halo_size,
        drop_ff=0,
        drop_attn=0,
        drop_path=0,
    ):
        super().__init__()

        self.depths = depths

        def make_block(i, in_dim, reduction):
            return self.make_block(
                depths[i],
                in_dim,
                dims[i],
                n_heads[i],
                dim_head,
                dim_ffs[i],
                window_size,
                halo_size,
                reduction,
                drop_ff,
                drop_attn,
                drop_path,
            )

        self.block1 = make_block(0, 3, 4)
        self.block2 = make_block(1, dims[0], 2)
        self.block3 = make_block(2, dims[1], 2)
        self.block4 = make_block(3, dims[2], 2)

        self.final_linear = nn.Sequential(
            nn.LayerNorm(dims[-1]),
            nn.Linear(dims[-1], dims[-1] * 2),
            nn.LayerNorm(dims[-1] * 2),
            nn.SiLU(inplace=True),
        )
        linear = nn.Linear(dims[-1] * 2, n_class)
        nn.init.normal_(linear.weight, std=0.01)
        nn.init.zeros_(linear.bias)
        self.classifier = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(1),
                                        linear)

        self.apply(self.init_weights)
Example #6
    def forward(self, xin):

        x = self.conv1(xin)
        x = nn.SiLU()(x)
        if self.level == 1:
            x = self.conv2(x)
            return x

        x = self.conv3(x)
        x = nn.SiLU()(x)
        if self.level <= 3:
            x = self.conv4(x)
            return x

        x = self.conv5(x)
        x = nn.SiLU()(x)
        if self.level <= 4:
            x = self.conv6(x)
            return x
        else:
            raise Exception('No level %d' % self.level)
Example #7
    def __init__(self, c1, reduction=16):
        super(ChannelAttentionModule, self).__init__()
        mid_channel = c1 // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.shared_MLP = nn.Sequential(
            nn.Linear(in_features=c1, out_features=mid_channel),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Linear(in_features=mid_channel, out_features=c1))
        # self.sigmoid = nn.Sigmoid()
        self.act = nn.SiLU()
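The forward pass is not shown; in the usual CBAM-style channel attention, the average- and max-pooled descriptors go through the shared MLP, are summed, and gate the input channels. A hedged sketch for the class above (note that with nn.SiLU the gate is not bounded to [0, 1], unlike the commented-out Sigmoid):

def forward(self, x):
    n, c, _, _ = x.shape
    avg = self.shared_MLP(self.avg_pool(x).view(n, c))  # (N, C)
    mx = self.shared_MLP(self.max_pool(x).view(n, c))   # (N, C)
    attn = self.act(avg + mx).view(n, c, 1, 1)          # per-channel attention weights
    return x * attn                                     # rescale the input channels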
Example #8
 def __init__(self, dim=100, r=2.):
     super().__init__()
     intermediate = int(dim * r)
     self.attn = nn.Sequential(
         nn.Linear(dim, intermediate),
         nn.Dropout(0.01),
         nn.LayerNorm(intermediate),
         nn.SiLU(),
         nn.Linear(intermediate, 1),
         nn.Softmax(1),
     )
     self.attn.apply(init_weights)
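A hedged usage sketch, assuming pool is an instance of the module above built with dim=100 (init_weights is defined elsewhere in that codebase): for an input of shape (batch, seq_len, dim), the block emits one weight per position via the softmax over dim 1, giving attention pooling over the sequence.

import torch

x = torch.randn(8, 50, 100)        # (batch, seq_len, dim)
weights = pool.attn(x)             # (8, 50, 1), normalized over the 50 positions
pooled = (weights * x).sum(dim=1)  # (8, 100) attention-weighted average of the sequence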
Example #9
 def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
     super().__init__()
     self.conv = nn.Conv2d(c1,
                           c2,
                           k,
                           s,
                           autopad(k, p),
                           groups=g,
                           bias=False)
     self.bn = nn.BatchNorm2d(c2)
     self.act = nn.SiLU() if act is True else (
         act if isinstance(act, nn.Module) else nn.Identity())
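autopad is not shown here; in YOLOv5-style code it returns 'same' padding when p is None. A minimal sketch of such a helper:

def autopad(k, p=None):
    """For stride 1, return the padding that preserves the spatial size ('same' padding)."""
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
    return p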
Example #10
 def __init__(
     self,
     input_c: int,  # block input channel; this "input" is the input channel of the MBConv block
     expand_c: int,  # block expand channel; after the 1x1 expansion conv, the DW (depthwise) conv does not change the channel count
     #* MBConv block = 1x1 expansion conv + DW conv + SE module + 1x1 reduction conv + Dropout
     squeeze_factor: int = 4):
     super(SqueezeExcitation, self).__init__()
     squeeze_c = input_c // squeeze_factor  # note the detail here: it is input_c // 4, not expand_c // 4
     self.fc1 = nn.Conv2d(expand_c, squeeze_c, 1)
     self.ac1 = nn.SiLU()  # alias Swish
     self.fc2 = nn.Conv2d(squeeze_c, expand_c, 1)
     self.ac2 = nn.Sigmoid()
Example #11
def down_conv(in_channels: int, out_channels: int, kernel_size: int,
              stride: int, padding: int) -> nn.Module:
    """Down convolutions."""
    return nn.Sequential(
        nn.Conv2d(in_channels,
                  out_channels,
                  kernel_size=kernel_size,
                  stride=stride,
                  padding=padding),
        nn.GroupNorm(num_groups=1, num_channels=out_channels),
        nn.SiLU(),
    )
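A brief usage sketch: with kernel 3, stride 2, padding 1, each block halves the spatial resolution, and GroupNorm with num_groups=1 normalizes each sample over all channels and positions.

import torch
import torch.nn as nn

encoder = nn.Sequential(
    down_conv(3, 32, kernel_size=3, stride=2, padding=1),
    down_conv(32, 64, kernel_size=3, stride=2, padding=1),
)
x = torch.randn(1, 3, 64, 64)
print(encoder(x).shape)  # torch.Size([1, 64, 16, 16])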
Example #12
    def __init__(self,
                 in_node_nf,
                 hidden_nf,
                 out_node_nf,
                 in_edge_nf=0,
                 device='cpu',
                 act_fn=nn.SiLU(),
                 n_layers=4,
                 residual=True,
                 attention=False,
                 normalize=False,
                 tanh=False):
        '''

        :param in_node_nf: Number of features for 'h' at the input
        :param hidden_nf: Number of hidden features
        :param out_node_nf: Number of features for 'h' at the output
        :param in_edge_nf: Number of features for the edge features
        :param device: Device (e.g. 'cpu', 'cuda:0',...)
        :param act_fn: Non-linearity
        :param n_layers: Number of layer for the EGNN
        :param residual: Use residual connections; we recommend not changing this one
        :param attention: Whether using attention or not
        :param normalize: Normalizes the coordinates messages such that:
                    instead of: x^{l+1}_i = x^{l}_i + Σ(x_i - x_j)phi_x(m_ij)
                    we get:     x^{l+1}_i = x^{l}_i + Σ(x_i - x_j)phi_x(m_ij)/||x_i - x_j||
                    We noticed it may help with stability or generalization in follow-up work.
                    We did not use it in our paper.
        :param tanh: Sets a tanh activation function at the output of phi_x(m_ij), i.e. it bounds the output of
                        phi_x(m_ij), which improves stability but may decrease accuracy.
                        We did not use it in our paper.
        '''

        super(EGNN, self).__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        self.embedding_in = nn.Linear(in_node_nf, self.hidden_nf)
        self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
        for i in range(0, n_layers):
            self.add_module(
                "gcl_%d" % i,
                E_GCL(self.hidden_nf,
                      self.hidden_nf,
                      self.hidden_nf,
                      edges_in_d=in_edge_nf,
                      act_fn=act_fn,
                      residual=residual,
                      attention=attention,
                      normalize=normalize,
                      tanh=tanh))
        self.to(self.device)
Example #13
 def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups=1):
     super().__init__()
     self.cnn = nn.Conv2d(
         in_channels,
         out_channels,
         kernel_size,
         stride,
         padding,
         groups=groups,
         bias=False,
     )
     self.bn = nn.BatchNorm2d(out_channels)
     self.silu = nn.SiLU()
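A hedged forward sketch for the conv-BN-SiLU block above, using the usual ordering:

def forward(self, x):
    return self.silu(self.bn(self.cnn(x)))  # conv -> batch norm -> SiLU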
Example #14
 def __init__(self,
              c1,
              c2,
              k=1,
              s=1,
              p=None,
              g=1,
              act=True):  # ch_in, ch_out, kernel, stride, padding, groups
     super(DOConvComp, self).__init__()
     self.conv = DOConv2d(c1, c2, k, s, autopad(k, p), bias=False)
     self.bn = nn.BatchNorm2d(c2)
     self.act = nn.SiLU() if act is True else (
         act if isinstance(act, nn.Module) else nn.Identity())
Example #15
 def _block(in_features, out_features):
     return nn.Sequential(
         OrderedDict([
             ("conv1",
              nn.Conv2d(in_channels=in_features,
                        out_channels=out_features,
                        kernel_size=3,
                        padding=1,
                        bias=False)),
             ("norm1", nn.BatchNorm2d(num_features=out_features)),
             #("relu1", nn.ReLU(inplace=True)),
             ("swish1", nn.SiLU(inplace=True)),
             ("conv2",
              nn.Conv2d(in_channels=out_features,
                        out_channels=out_features,
                        kernel_size=3,
                        padding=1,
                        bias=False)),
             ("norm2", nn.BatchNorm2d(num_features=out_features)),
             #("relu2", nn.ReLU(inplace=True))
             ("swish2", nn.SiLU(inplace=True))
         ]))
Example #16
def get_act(config):
    """Get activation functions from the config file."""

    if config.model.nonlinearity.lower() == 'elu':
        return nn.ELU()
    elif config.model.nonlinearity.lower() == 'relu':
        return nn.ReLU()
    elif config.model.nonlinearity.lower() == 'lrelu':
        return nn.LeakyReLU(negative_slope=0.2)
    elif config.model.nonlinearity.lower() == 'swish':
        return nn.SiLU()
    else:
        raise NotImplementedError('activation function does not exist!')
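A short usage sketch for get_act, faking the config object with SimpleNamespace purely for illustration:

from types import SimpleNamespace
import torch.nn as nn

config = SimpleNamespace(model=SimpleNamespace(nonlinearity='swish'))
act = get_act(config)
assert isinstance(act, nn.SiLU)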
Example #17
 def __init__(self,
              in_ch,
              out_ch,
              kernel=3,
              stride=1,
              padding=1,
              pool=False):  # ch_in, ch_out, kernel, stride, padding, groups
     super(ConvBlock, self).__init__()
     self.conv = nn.Conv2d(in_ch, out_ch, kernel, stride, padding)
     self.bn = nn.BatchNorm2d(out_ch)
     self.act = nn.SiLU()
     self.Maxpool = nn.MaxPool2d(4)
     self.pool = pool
Example #18
 def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups=1):
     super(CNNBlock, self).__init__()
     self.cnn = nn.Conv2d(
         in_channels=in_channels,
         out_channels=out_channels,
         kernel_size=kernel_size,
         stride=stride,
         padding=padding,
         groups=groups,
         bias=False
     )
     self.bn = nn.BatchNorm2d(out_channels)
     self.silu = nn.SiLU() # SiLU = Sigmoid Linear Units
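As the comment notes, SiLU (the Sigmoid Linear Unit, also known as Swish) is simply silu(x) = x * sigmoid(x); a quick numerical check:

import torch
import torch.nn as nn

x = torch.randn(1000)
assert torch.allclose(nn.SiLU()(x), x * torch.sigmoid(x))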
Example #19
    def __init__(self, channel=512):
        super().__init__()
        self.sse = nn.Sequential(nn.AdaptiveAvgPool2d(1),
                                 nn.Conv2d(channel, channel, kernel_size=1),
                                 nn.Sigmoid())

        self.conv1x1 = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=1),
            nn.BatchNorm2d(channel))
        self.conv3x3 = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=3, padding=1),
            nn.BatchNorm2d(channel))
        self.silu = nn.SiLU()
Example #20
    def __init__(
        self,
        num_in_feats,
        num_out_feats,
        kernel_size,
        stride,
        output_size,
        reduction,  # reduction factor for squeeze excitation
        Norm,
        Conv,
        Pool,
    ):
        super().__init__()

        # depthwise conv -> batch norm -> swish
        # SE
        # conv -> batch norm

        self.add_module(
            "depthconv1",
            Conv(
                num_in_feats,
                num_in_feats,
                groups=num_in_feats,
                kernel_size=kernel_size,
                stride=stride,
                padding=int((kernel_size - 1) / 2),
                bias=False,
            ),
        )
        self.add_module("norm1", Norm(num_in_feats))
        self.add_module("silu", nn.SiLU(inplace=True))

        # ADD SQUEEZE EXCITATION; no change in dimensions
        self.add_module(
            "squeeze",
            _SqueezeExcitation(num_in_feats, reduction, output_size, Pool))

        self.add_module(
            "conv2",
            Conv(
                num_in_feats,
                num_out_feats,
                kernel_size=1,
                stride=1,
                padding=int((kernel_size - 1) / 2),
                bias=False,
            ),
        )

        self.add_module("norm2", Norm(num_out_feats))
Example #21
    def __init__(self,
                 model_size,
                 inner_size,
                 dropout=0.,
                 variational=False,
                 activation='relu',
                 n_languages=1,
                 rank=1,
                 use_multiplicative=False,
                 weight_drop=0.0,
                 mfw_activation='none',
                 glu=False,
                 no_bias=False):
        super().__init__()

        self.variational = variational
        self.dropout = dropout
        self.activation = activation
        self.n_languages = n_languages
        self.weight_drop = weight_drop
        self.glu = glu

        self.input_linear = MultilingualLinear(model_size,
                                               inner_size * (2 if glu else 1),
                                               n_languages,
                                               rank,
                                               use_multiplicative,
                                               weight_drop,
                                               mfw_activation=mfw_activation,
                                               no_bias=no_bias)
        self.output_linear = MultilingualLinear(inner_size,
                                                model_size,
                                                n_languages,
                                                rank,
                                                use_multiplicative,
                                                weight_drop,
                                                mfw_activation=mfw_activation,
                                                no_bias=no_bias)

        if self.activation == 'relu':
            self.act = nn.ReLU(inplace=True)
        elif self.activation == 'gelu':
            self.act = nn.GELU()
        elif self.activation in ['silu', 'swish']:
            self.act = nn.SiLU(inplace=True)

        if self.variational:
            from onmt.modules.dropout import variational_dropout
            self.dropout_function = variational_dropout
        else:
            self.dropout_function = F.dropout
Example #22
    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=64,
                 num_blocks=23,
                 growth_channels=32,
                 body_block=RRDB,
                 ):
        super(RRDBNet, self).__init__()
        self.num_blocks = num_blocks
        self.in_channels = in_channels
        self.mid_channels = mid_channels

        # The diffusion RRDB starts with a full resolution image and downsamples into a .25 working space
        self.input_block = ConvGnLelu(in_channels, mid_channels, kernel_size=7, stride=1, activation=True, norm=False, bias=True)
        self.down1 = ConvGnLelu(mid_channels, mid_channels, kernel_size=3, stride=2, activation=True, norm=False, bias=True)
        self.down2 = ConvGnLelu(mid_channels, mid_channels, kernel_size=3, stride=2, activation=True, norm=False, bias=True)

        # Guided diffusion uses a time embedding.
        time_embed_dim = mid_channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(mid_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        self.body = make_layer(
            body_block,
            num_blocks,
            mid_channels=mid_channels,
            growth_channels=growth_channels)

        self.conv_body = nn.Conv2d(self.mid_channels, self.mid_channels, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(self.mid_channels, self.mid_channels, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(self.mid_channels*2, self.mid_channels, 3, 1, 1)
        self.conv_up3 = None
        self.conv_hr = nn.Conv2d(self.mid_channels*2, self.mid_channels, 3, 1, 1)
        self.conv_last = nn.Conv2d(self.mid_channels, out_channels, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.normalize = nn.GroupNorm(num_groups=8, num_channels=self.mid_channels)

        for m in [
            self.conv_body, self.conv_up1,
            self.conv_up2, self.conv_hr
        ]:
            if m is not None:
                default_init_weights(m, 1.0)
        default_init_weights(self.conv_last, 0)
Example #23
    def __init__(self, inp, oup, stride, expand_ratio, use_se):
        super(MBConv, self).__init__()
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.identity = stride == 1 and inp == oup
        if use_se:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(inplace=True),
                # dw
                nn.Conv2d(hidden_dim,
                          hidden_dim,
                          3,
                          stride,
                          1,
                          groups=hidden_dim,
                          bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(inplace=True),
                SELayer(inp, hidden_dim),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # fused
                nn.Conv2d(inp, hidden_dim, 3, stride, 1, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
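The residual path is the piece not shown above; a hedged forward sketch that adds the identity shortcut only when stride is 1 and the channel counts match, as tracked by self.identity:

def forward(self, x):
    if self.identity:
        return x + self.conv(x)  # inverted residual with identity shortcut
    return self.conv(x)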
Example #24
 def __init__(self, act_type, auto_optimize=True, **kwargs):
     super(Activation, self).__init__()
     if act_type == 'relu':
         self.act = nn.ReLU(inplace=True) if auto_optimize else nn.ReLU(
             **kwargs)
     elif act_type == 'relu6':
         self.act = nn.ReLU6(inplace=True) if auto_optimize else nn.ReLU6(
             **kwargs)
     elif act_type == 'h_swish':
         self.act = nn.Hardswish(
             inplace=True) if auto_optimize else nn.Hardswish(**kwargs)
     elif act_type == 'h_sigmoid':
         self.act = nn.Hardsigmoid(
             inplace=True) if auto_optimize else nn.Hardsigmoid(**kwargs)
     elif act_type == 'swish':
         self.act = nn.SiLU(inplace=True) if auto_optimize else nn.SiLU(
             **kwargs)
     elif act_type == 'gelu':
         self.act = nn.GELU()
     elif act_type == 'quick_gelu':
         self.act = QuickGELU()
     elif act_type == 'elu':
         self.act = (nn.ELU(inplace=True, **kwargs)
                     if auto_optimize else nn.ELU(**kwargs))
     elif act_type == 'mish':
         self.act = Mish()
     elif act_type == 'sigmoid':
         self.act = nn.Sigmoid()
     elif act_type == 'lrelu':
         self.act = (nn.LeakyReLU(inplace=True, **kwargs)
                     if auto_optimize else nn.LeakyReLU(**kwargs))
     elif act_type == 'prelu':
         self.act = nn.PReLU(**kwargs)
     else:
         raise NotImplementedError(
             '{} activation is not implemented.'.format(act_type))
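The wrapper above only builds self.act; its forward would simply delegate to it. A hedged sketch with hypothetical usage:

def forward(self, x):
    return self.act(x)

# act = Activation('swish')                                          # -> nn.SiLU(inplace=True)
# act = Activation('lrelu', auto_optimize=False, negative_slope=0.1) # -> nn.LeakyReLU(0.1)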
Example #25
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 bias=False,
                 bn=True,
                 act="relu"):
        """
        conv + bn + act
        :param in_channels:     number of input channels
        :param out_channels:    number of output channels
        :param kernel_size:     convolution kernel size
        :param stride:
        :param padding:
        :param bias:
        :param bn:              True or False
        :param act:             "relu", "leaky_0.1", "silu"
        """
        super(ConvBnAct, self).__init__()
        self.layers = nn.Sequential()
        self.layers.add_module(
            "conv",
            nn.Conv2d(in_channels,
                      out_channels,
                      kernel_size,
                      stride,
                      padding,
                      bias=bias))

        if bn:
            self.layers.add_module("bn", nn.BatchNorm2d(out_channels))
        if act == "linear":
            pass
        elif act == "relu":
            self.layers.add_module("act", nn.ReLU(True))
        elif act == "silu":
            self.layers.add_module("act", nn.SiLU(True))
        elif act == "hardwish":
            self.layers.add_module("act", nn.Hardswish(True))
        elif act.startswith("leaky"):
            # eg : leaky_0.1
            negative_slope = float(act.split('_')[-1])
            self.layers.add_module("act", nn.LeakyReLU(negative_slope, True))
        else:
            raise ValueError(
                " Activation function '{}' is not supported, you can add it here!"
                .format(act))
Example #26
    def __init__(self, data_dimension, hidden_layer_sizes):
        super(Critic, self).__init__()

        layer_sizes = [data_dimension] + hidden_layer_sizes + [1]

        layers = []

        for i in range(len(layer_sizes) - 1):
            in_size = layer_sizes[i]
            out_size = layer_sizes[i + 1]
            layers.append(nn.Linear(in_size, out_size))
            if i != len(layer_sizes) - 2:
                #                layers.append(nn.LeakyReLU(0.2))
                layers.append(nn.SiLU())
        self.network = nn.Sequential(*layers)
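A usage sketch with hypothetical sizes, showing the shape flow through the MLP critic above; SiLU sits between the hidden layers but not after the final scalar output:

import torch

critic = Critic(data_dimension=10, hidden_layer_sizes=[64, 64])
x = torch.randn(32, 10)     # a batch of 32 samples
scores = critic.network(x)  # (32, 1) unbounded critic scores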
Example #27
 def __init__(self,
              in_channels,
              out_channels,
              kernel_size,
              stride,
              padding,
              activation=True):
     super(ConvBnSiLU, self).__init__()
     self.conv = nn.Conv2d(in_channels,
                           out_channels,
                           kernel_size,
                           stride,
                           padding,
                           bias=False)
     self.bn = nn.BatchNorm2d(out_channels, 1e-3, 0.03)
     self.act = nn.SiLU(inplace=True) if activation else nn.Identity()
Example #28
 def __init__(self, c_in, channel, d_model, dropout=0.1, data='ETT'):
   super(TokenEmbedding, self).__init__()
   self.data = data
   if not data.startswith('janestreet'):
     padding = 1 if torch.__version__>='1.5.0' else 2
     self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model, 
                                 kernel_size=3, padding=padding, padding_mode='circular')
     for m in self.modules():
       if isinstance(m, nn.Conv1d):
         nn.init.kaiming_normal_(m.weight,mode='fan_in',nonlinearity='leaky_relu')
   else:
     self.dense1 = nn.Linear(c_in, d_model)
     self.bn1 = nn.BatchNorm1d(d_model)
     self.dense2 = nn.Linear(d_model, d_model)
     self.bn2 = nn.BatchNorm1d(d_model)
     self.act = nn.SiLU()
Example #29
 def __init__(self,
              c1,
              c2,
              n=1,
              shortcut=True,
              g=1,
              e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
     super().__init__()
     c_ = int(c2 * e)  # hidden channels
     self.cv1 = Conv(c1, c_, 1, 1)
     self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
     self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
     self.cv4 = Conv(2 * c_, c2, 1, 1)
     self.bn = nn.BatchNorm2d(2 * c_)  # applied to cat(cv2, cv3)
     self.act = nn.SiLU()
     self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0)
                              for _ in range(n)))
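For reference, a hedged sketch of the forward pass that usually accompanies this YOLOv5-style BottleneckCSP layout: the bottleneck chain and the cv2 shortcut are concatenated, batch-normalized, passed through SiLU, and fused by cv4.

import torch

def forward(self, x):
    y1 = self.cv3(self.m(self.cv1(x)))  # main branch: 1x1 conv, bottleneck chain, 1x1 conv
    y2 = self.cv2(x)                    # shortcut branch: plain 1x1 conv
    return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))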
Example #30
    def create_stem(self, params: Union[RegNetParams, AnyNetParams]):
        # get the activation
        silu = None if get_torch_version() < [1, 7] else nn.SiLU()
        activation = {
            ActivationType.RELU: nn.ReLU(params.relu_in_place),
            ActivationType.SILU: silu,
        }[params.activation]

        # create stem
        stem = {
            StemType.RES_STEM_CIFAR: ResStemCifar,
            StemType.RES_STEM_IN: ResStemIN,
            StemType.SIMPLE_STEM_IN: SimpleStemIN,
        }[params.stem_type](3, params.stem_width, params.bn_epsilon,
                            params.bn_momentum, activation)
        init_weights(stem)
        return stem