Example #1
def _linknet(
    arch: str,
    pretrained: bool,
    backbone_fn,
    fpn_layers: List[str],
    pretrained_backbone: bool = True,
    input_shape: Optional[Tuple[int, int, int]] = None,
    **kwargs: Any,
) -> LinkNet:

    pretrained_backbone = pretrained_backbone and not pretrained

    # Patch the config
    _cfg = deepcopy(default_cfgs[arch])
    _cfg["input_shape"] = input_shape or default_cfgs[arch]["input_shape"]

    # Feature extractor
    feat_extractor = IntermediateLayerGetter(
        backbone_fn(
            pretrained=pretrained_backbone,
            include_top=False,
            input_shape=_cfg["input_shape"],
        ),
        fpn_layers,
    )

    # Build the model
    model = LinkNet(feat_extractor, cfg=_cfg, **kwargs)
    # Load pretrained parameters
    if pretrained:
        load_pretrained_params(model, _cfg["url"])

    return model
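
A public constructor would typically wrap this helper. The sketch below is for illustration only: the architecture key, the resnet18 backbone constructor, and the FPN layer names are assumptions, and the only constraint taken from the example above is that backbone_fn must accept the pretrained, include_top and input_shape keyword arguments.

def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
    # Every name below is a placeholder, not taken from the example above
    return _linknet(
        "linknet_resnet18",  # assumed key in default_cfgs
        pretrained,
        resnet18,  # assumed backbone constructor exposing pretrained / include_top / input_shape
        ["block_1", "block_2", "block_3", "block_4"],  # assumed feature maps to tap for the FPN
        **kwargs,
    )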
Example #2
def test_load_pretrained_params(tmpdir_factory):

    model = Sequential([
        layers.Dense(8, activation="relu", input_shape=(4, )),
        layers.Dense(4)
    ])
    # Retrieve this URL
    url = "https://github.com/mindee/doctr/releases/download/v0.1-models/tmp_checkpoint-4a98e492.zip"
    # Temp cache dir
    cache_dir = tmpdir_factory.mktemp("cache")
    # Pass an incorrect hash
    with pytest.raises(ValueError):
        load_pretrained_params(model,
                               url,
                               "mywronghash",
                               cache_dir=str(cache_dir),
                               internal_name="")
    # Let it resolve the hash from the file name
    load_pretrained_params(model,
                           url,
                           cache_dir=str(cache_dir),
                           internal_name="")
    # Check that the file was downloaded & the archive extracted
    assert os.path.exists(
        cache_dir.join("models").join("tmp_checkpoint-4a98e492"))
    # Check that the archive is still present in the cache
    assert os.path.exists(
        cache_dir.join("models").join("tmp_checkpoint-4a98e492.zip"))
Example #3
def _db_resnet(
    arch: str,
    pretrained: bool,
    backbone_fn,
    fpn_layers: List[str],
    pretrained_backbone: bool = True,
    input_shape: Optional[Tuple[int, int, int]] = None,
    **kwargs: Any,
) -> DBNet:

    pretrained_backbone = pretrained_backbone and not pretrained

    # Patch the config
    _cfg = deepcopy(default_cfgs[arch])
    _cfg['input_shape'] = input_shape or _cfg['input_shape']

    # Feature extractor
    feat_extractor = IntermediateLayerGetter(
        backbone_fn(
            weights='imagenet' if pretrained_backbone else None,
            include_top=False,
            pooling=None,
            input_shape=_cfg['input_shape'],
        ),
        fpn_layers,
    )

    # Build the model
    model = DBNet(feat_extractor, cfg=_cfg, **kwargs)
    # Load pretrained parameters
    if pretrained:
        load_pretrained_params(model, _cfg['url'])

    return model
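
As with the LinkNet helper, a public constructor would wrap _db_resnet with a concrete backbone. The sketch below is illustrative: ResNet50 from tensorflow.keras.applications is chosen because its weights / include_top / pooling / input_shape keywords match the call above, while the arch key and FPN layer names are assumptions based on Keras' ResNet50 layer naming.

from tensorflow.keras.applications import ResNet50

def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
    # Arch key and layer names below are illustrative assumptions
    return _db_resnet(
        "db_resnet50",
        pretrained,
        ResNet50,
        ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"],
        **kwargs,
    )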
Example #4
def test_load_pretrained_params(tmpdir_factory):

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
    # Retrieve this URL
    url = "https://github.com/mindee/doctr/releases/download/v0.2.1/tmp_checkpoint-6f0ce0e6.pt"
    # Temp cache dir
    cache_dir = tmpdir_factory.mktemp("cache")
    # Pass an incorrect hash
    with pytest.raises(ValueError):
        load_pretrained_params(model,
                               url,
                               "mywronghash",
                               cache_dir=str(cache_dir))
    # Let it resolve the hash from the file name
    load_pretrained_params(model, url, cache_dir=str(cache_dir))
    # Check that the checkpoint file was downloaded
    assert os.path.exists(
        cache_dir.join('models').join(url.rpartition("/")[-1]))
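
For reference, here is a minimal sketch of what a PyTorch-side load_pretrained_params could do, assuming the URL points to a plain state_dict saved with torch.save. It illustrates the mechanism exercised by the test and is not doctr's actual implementation.

from pathlib import Path

import torch
from torch import nn
from torch.hub import download_url_to_file

def load_pretrained_params_sketch(model: nn.Module, url: str, cache_dir: str) -> None:
    # Cache the checkpoint under <cache_dir>/models/<file name>, mirroring the test's assertion
    target = Path(cache_dir) / "models" / url.rpartition("/")[-1]
    target.parent.mkdir(parents=True, exist_ok=True)
    if not target.is_file():
        download_url_to_file(url, str(target))
    # Assumes the checkpoint is a plain state_dict
    state_dict = torch.load(target, map_location="cpu")
    model.load_state_dict(state_dict)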