def rgb_tf_efficientnet_b3_ns(num_classes=4,
                              pretrained=True,
                              dropout=0.1,
                              need_embedding=False):
    """Create a ``TimmRgbModel`` around a headless tf_efficientnet_b3_ns encoder.

    Args:
        num_classes: Number of output classes, forwarded to ``TimmRgbModel``.
        pretrained: Forwarded to timm's ``tf_efficientnet_b3_ns`` constructor.
        dropout: Dropout rate, forwarded to ``TimmRgbModel``.
        need_embedding: Forwarded to ``TimmRgbModel`` unchanged.

    Returns:
        A ``TimmRgbModel`` wrapping the encoder, which is built with
        ``drop_path_rate=0.2`` and has its ``classifier`` attribute removed.
    """
    backbone = efficientnet.tf_efficientnet_b3_ns(
        pretrained=pretrained,
        drop_path_rate=0.2,
    )
    # Strip timm's own classification head; presumably TimmRgbModel
    # provides its own head on top of the encoder features.
    delattr(backbone, "classifier")

    wrapper_kwargs = dict(
        num_classes=num_classes,
        dropout=dropout,
        need_embedding=need_embedding,
    )
    return TimmRgbModel(backbone, **wrapper_kwargs)
def hpf_b3_fixed_gap(num_classes, dropout=0, pretrained=True):
    """Build an ``HPFNetGAP`` on tf_efficientnet_b3_ns with a fixed high-pass stem.

    The encoder's ``conv_stem`` is replaced by a non-trainable ``HPF3`` filter
    bank (stride 2) followed by a 1x1 convolution from 30 to 40 channels. The
    encoder's ``classifier`` attribute is removed before the encoder is wrapped.

    Args:
        num_classes: Number of output classes, forwarded to ``HPFNetGAP``.
        dropout: Dropout rate, forwarded to ``HPFNetGAP``.
        pretrained: Whether timm loads pretrained encoder weights. The default
            is ``True`` so that default calls keep loading weights, matching
            the earlier behavior in which this flag was ignored.

    Returns:
        An ``HPFNetGAP`` wrapping the modified encoder. It is given the
        encoder's ``default_cfg`` mean/std for input normalisation.
    """
    from timm.models import efficientnet

    # Honour the caller's flag instead of always loading pretrained weights.
    encoder = efficientnet.tf_efficientnet_b3_ns(pretrained=pretrained, drop_path_rate=0.1)
    # NOTE(review): assumes HPF3(stride=2) emits 30 channels and the rest of the
    # b3 encoder expects 40 stem channels — confirm against HPF3 / timm config.
    encoder.conv_stem = nn.Sequential(
        HPF3(trainable_hpf=False, stride=2),
        nn.Conv2d(30, 40, kernel_size=1),
    )
    del encoder.classifier

    return HPFNetGAP(
        encoder,
        num_classes=num_classes,
        dropout=dropout,
        mean=encoder.default_cfg["mean"],
        std=encoder.default_cfg["std"],
    )
    def __init__(self, pretrained=True, layers=[1, 2, 3, 4], act_layer=Swish, no_stride=False):
        """Wrap a features-only tf_efficientnet_b3_ns encoder.

        Args:
            pretrained: Forwarded to timm's ``tf_efficientnet_b3_ns``.
            layers: Passed as the third argument to the base-class
                ``__init__``. Presumably these are the indices of the feature
                maps to expose; confirm against the base class.
                NOTE(review): this is a mutable list default. It is safe only
                if the base class never mutates it.
            act_layer: Activation layer forwarded to the timm constructor.
            no_stride: If True, the depthwise convs of ``blocks[3][0]`` and
                ``blocks[5][0]`` are set to stride 1 / dilation 2. The last
                two reported strides then become 8 instead of 16 and 32.
        """
        from timm.models.efficientnet import tf_efficientnet_b3_ns

        # features_only=True asks timm for intermediate feature maps rather
        # than classification logits.
        encoder = tf_efficientnet_b3_ns(
            pretrained=pretrained, features_only=True, act_layer=act_layer, drop_path_rate=0.1
        )
        # Output stride of each of the five feature maps handed to the base class.
        strides = [2, 4, 8, 16, 32]
        if no_stride:
            # Presumably these first blocks carry the stride-2 downsampling of
            # their stages. Swapping stride for dilation keeps spatial
            # resolution at 1/8 while still enlarging the receptive field.
            # NOTE(review): these attributes are mutated after construction.
            # That only takes effect if the conv reads stride/dilation (and
            # derives its padding) at forward time, as timm's TF-"same" padded
            # convs presumably do. Confirm.
            encoder.blocks[5][0].conv_dw.stride = (1, 1)
            encoder.blocks[5][0].conv_dw.dilation = (2, 2)

            encoder.blocks[3][0].conv_dw.stride = (1, 1)
            encoder.blocks[3][0].conv_dw.dilation = (2, 2)
            # Keep the reported strides in sync with the removed downsampling.
            strides[3] = 8
            strides[4] = 8
        # Channel counts of the five feature maps. Presumably these match timm's
        # feature_info for b3; confirm if the encoder config changes.
        super().__init__([24, 32, 48, 136, 384], strides, layers)
        self.encoder = encoder