예제 #1
0
def _linknet(
    arch: str,
    pretrained: bool,
    backbone_fn,
    fpn_layers: List[str],
    pretrained_backbone: bool = True,
    input_shape: Optional[Tuple[int, int, int]] = None,
    **kwargs: Any,
) -> LinkNet:
    """Assemble a LinkNet model from the default configuration of `arch`.

    Args:
        arch: key used to look up the base configuration in `default_cfgs`
        pretrained: when truthy, `load_pretrained_params` is called on the model with the
            configuration's "url" entry
        backbone_fn: callable producing the backbone; it is invoked with the keyword arguments
            `pretrained`, `include_top` and `input_shape`
        fpn_layers: layer names handed to `IntermediateLayerGetter` together with the backbone
        pretrained_backbone: asks for a pretrained backbone; it has no effect when `pretrained`
            is truthy, since the backbone is then requested without pretrained weights
        input_shape: replaces the configuration's "input_shape" entry when truthy
        **kwargs: extra keyword arguments forwarded to the `LinkNet` constructor

    Returns:
        the constructed LinkNet instance
    """
    # Work on a private copy of the configuration so the shared defaults stay untouched
    cfg = deepcopy(default_cfgs[arch])
    cfg["input_shape"] = input_shape if input_shape else default_cfgs[arch]["input_shape"]

    # Backbone: pretrained weights are only requested when the full checkpoint is not loaded below
    backbone = backbone_fn(
        pretrained=pretrained_backbone and not pretrained,
        include_top=False,
        input_shape=cfg["input_shape"],
    )
    extractor = IntermediateLayerGetter(backbone, fpn_layers)

    # Model assembly
    net = LinkNet(extractor, cfg=cfg, **kwargs)

    # Optional loading of the pretrained parameters
    if pretrained:
        load_pretrained_params(net, cfg["url"])

    return net
예제 #2
0
def _db_resnet(
    arch: str,
    pretrained: bool,
    backbone_fn,
    fpn_layers: List[str],
    pretrained_backbone: bool = True,
    input_shape: Optional[Tuple[int, int, int]] = None,
    **kwargs: Any,
) -> DBNet:
    """Assemble a DBNet model from the default configuration of `arch`.

    Args:
        arch: key used to look up the base configuration in `default_cfgs`
        pretrained: when truthy, `load_pretrained_params` is called on the model with the
            configuration's 'url' entry
        backbone_fn: callable producing the backbone; it is invoked with the keyword arguments
            `weights`, `include_top`, `pooling` and `input_shape`
        fpn_layers: layer names handed to `IntermediateLayerGetter` together with the backbone
        pretrained_backbone: asks for 'imagenet' backbone weights; it has no effect when
            `pretrained` is truthy, since `weights=None` is then passed to `backbone_fn`
        input_shape: replaces the configuration's 'input_shape' entry when truthy
        **kwargs: extra keyword arguments forwarded to the `DBNet` constructor

    Returns:
        the constructed DBNet instance
    """
    # Work on a private copy of the configuration so the shared defaults stay untouched
    cfg = deepcopy(default_cfgs[arch])
    if input_shape:
        cfg['input_shape'] = input_shape

    # 'imagenet' weights are only requested when the full checkpoint is not loaded below
    backbone_weights = None
    if pretrained_backbone and not pretrained:
        backbone_weights = 'imagenet'

    # Feature extractor
    backbone = backbone_fn(
        weights=backbone_weights,
        include_top=False,
        pooling=None,
        input_shape=cfg['input_shape'],
    )
    extractor = IntermediateLayerGetter(backbone, fpn_layers)

    # Model assembly
    net = DBNet(extractor, cfg=cfg, **kwargs)

    # Optional loading of the pretrained parameters
    if pretrained:
        load_pretrained_params(net, cfg['url'])

    return net
예제 #3
0
def test_intermediate_layer_getter():
    """Check the number of outputs and the repr of an IntermediateLayerGetter wrapping ResNet50."""
    requested_layers = ["conv2_block3_out", "conv3_block4_out"]
    resnet = ResNet50(include_top=False, weights=None, pooling=None)
    getter = IntermediateLayerGetter(resnet, requested_layers)

    # Two layer names were requested, so two outputs are expected
    dummy_batch = tf.random.uniform(shape=[1, 224, 224, 3], minval=0, maxval=1)
    outputs = getter(dummy_batch)
    assert len(outputs) == 2

    # String representation
    expected_repr = "IntermediateLayerGetter()"
    assert repr(getter) == expected_repr