def _linknet(
    arch: str,
    pretrained: bool,
    backbone_fn,
    fpn_layers: List[str],
    pretrained_backbone: bool = True,
    input_shape: Optional[Tuple[int, int, int]] = None,
    **kwargs: Any,
) -> LinkNet:
    """Build a LinkNet model on top of intermediate layers of a backbone.

    Args:
        arch: key into ``default_cfgs`` selecting the base configuration.
        pretrained: when True, load full-model weights from the config's ``"url"``.
        backbone_fn: callable building the backbone; it is called with the keyword
            arguments ``pretrained``, ``include_top=False`` and ``input_shape``.
        fpn_layers: backbone layer names passed to ``IntermediateLayerGetter``.
        pretrained_backbone: request pretrained backbone weights; forced off when
            ``pretrained`` is True.
        input_shape: overrides the configuration's ``"input_shape"`` when truthy.
        **kwargs: forwarded unchanged to the ``LinkNet`` constructor.

    Returns:
        The assembled ``LinkNet`` instance.
    """
    # Only ask for backbone weights when the full model is not loaded afterwards
    # (presumably because the full checkpoint would overwrite them anyway).
    backbone_weights = pretrained_backbone and not pretrained

    # Work on a private copy so the shared default configuration is never mutated.
    defaults = default_cfgs[arch]
    config = deepcopy(defaults)
    config["input_shape"] = input_shape or defaults["input_shape"]

    backbone = backbone_fn(
        pretrained=backbone_weights,
        include_top=False,
        input_shape=config["input_shape"],
    )
    # Expose the requested intermediate layers of the backbone to LinkNet.
    extractor = IntermediateLayerGetter(backbone, fpn_layers)
    model = LinkNet(extractor, cfg=config, **kwargs)

    if pretrained:
        load_pretrained_params(model, config["url"])
    return model
def test_load_pretrained_params(tmpdir_factory):
    """Download a zipped checkpoint, verify hash checking, and check the cache layout."""
    model = Sequential([
        layers.Dense(8, activation="relu", input_shape=(4,)),
        layers.Dense(4),
    ])
    # Retrieve this URL
    url = "https://github.com/mindee/doctr/releases/download/v0.1-models/tmp_checkpoint-4a98e492.zip"
    # Derive the cached names from the URL instead of repeating literals:
    # "tmp_checkpoint-4a98e492.zip" and its extracted folder "tmp_checkpoint-4a98e492".
    archive_name = url.rpartition("/")[-1]
    extracted_name = archive_name.rpartition(".")[0]
    # Temp cache dir
    cache_dir = tmpdir_factory.mktemp("cache")
    # An incorrect hash must be rejected
    with pytest.raises(ValueError):
        load_pretrained_params(model, url, "mywronghash", cache_dir=str(cache_dir), internal_name="")
    # Let it resolve the hash from the file name
    load_pretrained_params(model, url, cache_dir=str(cache_dir), internal_name="")
    models_dir = cache_dir.join("models")
    # Check that the file was downloaded & the archive extracted
    assert os.path.exists(models_dir.join(extracted_name))
    # Check that the downloaded archive is kept in the cache alongside the extracted folder
    assert os.path.exists(models_dir.join(archive_name))
def _db_resnet(
    arch: str,
    pretrained: bool,
    backbone_fn,
    fpn_layers: List[str],
    pretrained_backbone: bool = True,
    input_shape: Optional[Tuple[int, int, int]] = None,
    **kwargs: Any,
) -> DBNet:
    """Build a DBNet model on top of intermediate layers of a backbone.

    Args:
        arch: key into ``default_cfgs`` selecting the base configuration.
        pretrained: when True, load full-model weights from the config's ``'url'``.
        backbone_fn: callable building the backbone; it is called with the keyword
            arguments ``weights`` (``'imagenet'`` or ``None``), ``include_top=False``,
            ``pooling=None`` and ``input_shape``.
        fpn_layers: backbone layer names passed to ``IntermediateLayerGetter``.
        pretrained_backbone: request ``'imagenet'`` backbone weights; forced off when
            ``pretrained`` is True.
        input_shape: overrides the configuration's ``'input_shape'`` when truthy.
        **kwargs: forwarded unchanged to the ``DBNet`` constructor.

    Returns:
        The assembled ``DBNet`` instance.
    """
    # Only ask for backbone weights when the full model is not loaded afterwards
    # (presumably because the full checkpoint would overwrite them anyway).
    use_imagenet = pretrained_backbone and not pretrained

    # Private copy of the defaults so the shared configuration is never mutated.
    config = deepcopy(default_cfgs[arch])
    config['input_shape'] = input_shape or config['input_shape']

    backbone_weights = 'imagenet' if use_imagenet else None
    backbone = backbone_fn(
        weights=backbone_weights,
        include_top=False,
        pooling=None,
        input_shape=config['input_shape'],
    )
    # Expose the requested intermediate layers of the backbone to DBNet.
    extractor = IntermediateLayerGetter(backbone, fpn_layers)
    model = DBNet(extractor, cfg=config, **kwargs)

    if pretrained:
        load_pretrained_params(model, config['url'])
    return model
def test_load_pretrained_params(tmpdir_factory):
    """Download a .pt checkpoint, verify hash checking, and check it lands in the cache."""
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
    # Retrieve this URL
    url = "https://github.com/mindee/doctr/releases/download/v0.2.1/tmp_checkpoint-6f0ce0e6.pt"
    # Temp cache dir
    cache_dir = tmpdir_factory.mktemp("cache")
    # An incorrect hash must be rejected
    with pytest.raises(ValueError):
        load_pretrained_params(model, url, "mywronghash", cache_dir=str(cache_dir))
    # Let it resolve the hash from the file name
    load_pretrained_params(model, url, cache_dir=str(cache_dir))
    # Check that the checkpoint file itself was downloaded into the cache
    # (the URL points at a plain .pt file; no extraction step is asserted here)
    checkpoint_name = url.rpartition("/")[-1]
    assert os.path.exists(cache_dir.join("models").join(checkpoint_name))