def __init__(self, pretrained=True, layers=None, act_layer: Union[str, nn.Module] = Swish, no_stride=False):
    """Build the encoder around timm's ``tf_efficientnet_b0_ns`` feature extractor.

    Args:
        pretrained: Forwarded to the timm factory function.
        layers: Forwarded to the base-class constructor together with the
            channel and stride lists (presumably the indices of the feature
            maps to expose; confirm against the base class). ``None`` (the
            default) is equivalent to ``[1, 2, 3, 4]``.
        act_layer: Activation passed to timm. A string is first resolved
            through ``get_activation_block``; anything else is passed as is.
        no_stride: When true, the depthwise convolution of the first
            sub-block in ``blocks[3]`` and ``blocks[5]`` is switched to
            stride 1 with dilation 2, and the strides reported for the last
            two feature maps become 8 instead of 16 and 32.
    """
    from timm.models.efficientnet import tf_efficientnet_b0_ns

    # A list literal as a default value would be a single object shared by
    # every call (and by whatever the base class stores), so build a fresh
    # list per call instead.
    if layers is None:
        layers = [1, 2, 3, 4]

    if isinstance(act_layer, str):
        act_layer = get_activation_block(act_layer)

    encoder = tf_efficientnet_b0_ns(
        pretrained=pretrained,
        features_only=True,
        act_layer=act_layer,
        drop_path_rate=0.05,
    )

    strides = [2, 4, 8, 16, 32]
    if no_stride:
        # Replace the stride-2 depthwise convs with dilated stride-1 ones.
        encoder.blocks[5][0].conv_dw.stride = (1, 1)
        encoder.blocks[5][0].conv_dw.dilation = (2, 2)

        encoder.blocks[3][0].conv_dw.stride = (1, 1)
        encoder.blocks[3][0].conv_dw.dilation = (2, 2)

        # Keep the reported strides in sync with the modified convolutions.
        strides[3] = 8
        strides[4] = 8

    super().__init__([16, 24, 40, 112, 320], strides, layers)
    # Assigned after super().__init__() so the parent is initialised before
    # the encoder is attached to it.
    self.encoder = encoder
def __init__(self, pretrained=True, layers=None, act_layer: Union[str, nn.Module] = Swish):
    """Build the encoder around timm's ``mixnet_xl`` feature extractor.

    Args:
        pretrained: Forwarded to the timm factory function.
        layers: Forwarded to the base-class constructor together with the
            channel and stride lists (presumably the indices of the feature
            maps to expose; confirm against the base class). ``None`` (the
            default) is equivalent to ``[1, 2, 3, 4]``.
        act_layer: Activation passed to timm. A string is first resolved
            through ``get_activation_block``; anything else is passed as is.
    """
    from timm.models.efficientnet import mixnet_xl

    # A list literal as a default value would be a single object shared by
    # every call (and by whatever the base class stores), so build a fresh
    # list per call instead.
    if layers is None:
        layers = [1, 2, 3, 4]

    if isinstance(act_layer, str):
        act_layer = get_activation_block(act_layer)

    encoder = mixnet_xl(
        pretrained=pretrained,
        features_only=True,
        act_layer=act_layer,
        drop_path_rate=0.2,
    )

    super().__init__([40, 48, 64, 192, 320], [2, 4, 8, 16, 32], layers)
    # Assigned after super().__init__() so the parent is initialised before
    # the encoder is attached to it.
    self.encoder = encoder
def __init__(
    self,
    input_channels: int = 3,
    stack_level: int = 8,
    depth: int = 4,
    features: int = 256,
    activation=ACT_RELU,
    repeats=1,
    pooling_block=nn.MaxPool2d,
):
    """Assemble a stem followed by ``stack_level`` hourglass blocks.

    Args:
        input_channels: Channel count handed to the stem block.
        stack_level: Number of ``HGBlock`` / ``HGFeaturesBlock`` pairs.
        depth: First positional argument of every ``HGBlock``.
        features: Channel width used by the stem output, every block and
            every 1x1 merge convolution.
        activation: Resolved through ``get_activation_block`` and passed to
            the stem, the blocks and the feature blocks.
        repeats: Forwarded to every ``HGBlock``.
        pooling_block: Forwarded to every ``HGBlock``.
    """
    # The base class receives one entry for the stem plus one per stack,
    # all with the same channel width and a stride of 4.
    channel_spec = [features] + [features] * stack_level
    stride_spec = [4] + [4] * stack_level
    layer_spec = list(range(0, stack_level + 1))
    super().__init__(channels=channel_spec, strides=stride_spec, layers=layer_spec)

    self.stack_level = stack_level
    self.depth_level = depth
    self.num_features = features

    act = get_activation_block(activation)
    self.stem = HGStemBlock(input_channels, features, activation=act)

    # Every block maps `features` channels to `features` channels.
    hourglasses = [
        HGBlock(
            depth,
            features,
            features,
            increase=0,
            activation=act,
            repeats=repeats,
            pooling_block=pooling_block,
        )
        for _ in range(stack_level)
    ]
    self.num_blocks = len(hourglasses)
    self.blocks = nn.ModuleList(hourglasses)

    self.features = nn.ModuleList(
        [HGFeaturesBlock(features, blocks=4, activation=act) for _ in range(stack_level)]
    )

    # One 1x1 convolution fewer than there are stacks.
    self.merge_features = nn.ModuleList(
        [nn.Conv2d(features, features, kernel_size=1) for _ in range(stack_level - 1)]
    )