def _load_pretrained(model_name, model, progress):
    """Download the checkpoint registered for ``model_name`` and load it into ``model``.

    Args:
        model_name: Key looked up in ``_MODEL_URLS``.
        model: Object whose ``load_state_dict`` receives the downloaded weights.
        progress: Forwarded to ``load_state_dict_from_url`` as ``progress``.

    Raises:
        ValueError: If ``model_name`` has no entry in ``_MODEL_URLS`` or the
            entry is ``None``.
    """
    # A missing key and an explicit ``None`` entry both mean "no checkpoint".
    url = _MODEL_URLS.get(model_name)
    if url is None:
        raise ValueError(
            "No checkpoint is available for model type {}".format(model_name))
    model.load_state_dict(load_state_dict_from_url(url, progress=progress))
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ``ResNet`` and optionally load the checkpoint registered for ``arch``.

    Args:
        arch: Key into ``model_urls`` used when ``pretrained`` is truthy.
        block: First positional argument passed to ``ResNet``.
        layers: Second positional argument passed to ``ResNet``.
        pretrained: If truthy, download and load the checkpoint weights.
        progress: Forwarded to ``load_state_dict_from_url`` as ``progress``.
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.

    Returns:
        The constructed ``ResNet`` instance.
    """
    net = ResNet(block, layers, **kwargs)
    if not pretrained:
        return net

    checkpoint_url = model_urls[arch]
    weights = load_state_dict_from_url(checkpoint_url, progress=progress)
    net.load_state_dict(weights)
    return net
def googlenet(pretrained=False, progress=True, **kwargs):
    r"""GoogLeNet (Inception v1) model architecture from
    `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.

    Args:
        pretrained (bool): If True, returns a model with the checkpoint from
            ``model_urls['googlenet']`` loaded.
        progress (bool): If True, displays a progress bar of the download to stderr
        aux_logits (bool): If True, keeps the two auxiliary branches. When
            ``pretrained`` is True and this is not given, it is treated as
            False; otherwise it is passed through to ``GoogLeNet`` untouched.
        transform_input (bool): Passed through to ``GoogLeNet``. When
            ``pretrained`` is True and this is not given, it is set to True.

    Returns:
        The constructed ``GoogLeNet`` instance.
    """
    if not pretrained:
        return GoogLeNet(**kwargs)

    kwargs.setdefault('transform_input', True)
    keep_aux_heads = kwargs.setdefault('aux_logits', False)
    if keep_aux_heads:
        warnings.warn(
            'auxiliary heads in the pretrained googlenet model are NOT pretrained, '
            'so make sure to train them')

    # Always build with the auxiliary heads while loading the checkpoint;
    # they are stripped again below if the caller did not ask for them.
    kwargs['aux_logits'] = True
    kwargs['init_weights'] = False
    net = GoogLeNet(**kwargs)

    weights = load_state_dict_from_url(model_urls['googlenet'],
                                       progress=progress)
    net.load_state_dict(weights)

    if not keep_aux_heads:
        net.aux_logits = False
        net.aux1 = None
        net.aux2 = None
    return net
def _squeezenet(version, pretrained, progress, **kwargs):
    """Build a ``SqueezeNet`` and optionally load its registered checkpoint.

    Args:
        version: Passed to ``SqueezeNet`` and appended to ``'squeezenet'`` to
            form the ``model_urls`` key.
        pretrained: If truthy, download and load the checkpoint weights.
        progress: Forwarded to ``load_state_dict_from_url`` as ``progress``.
        **kwargs: Extra keyword arguments forwarded to ``SqueezeNet``.

    Returns:
        The constructed ``SqueezeNet`` instance.
    """
    net = SqueezeNet(version, **kwargs)
    if not pretrained:
        return net

    url_key = 'squeezenet' + version
    weights = load_state_dict_from_url(model_urls[url_key], progress=progress)
    net.load_state_dict(weights)
    return net
def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
    """Build a ``VGG`` model and optionally load the checkpoint registered for ``arch``.

    Args:
        arch: Key into ``model_urls`` used when ``pretrained`` is truthy.
        cfg: Key into ``cfgs`` selecting the layer configuration.
        batch_norm: Forwarded to ``make_layers`` as ``batch_norm``.
        pretrained: If truthy, sets ``init_weights`` to False in ``kwargs``
            and loads the checkpoint weights after construction.
        progress: Forwarded to ``load_state_dict_from_url`` as ``progress``.
        **kwargs: Extra keyword arguments forwarded to ``VGG``.

    Returns:
        The constructed ``VGG`` instance.
    """
    if pretrained:
        # Overrides any caller-supplied value; the checkpoint is loaded below.
        kwargs['init_weights'] = False

    feature_layers = make_layers(cfgs[cfg], batch_norm=batch_norm)
    net = VGG(feature_layers, **kwargs)
    if not pretrained:
        return net

    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    net.load_state_dict(weights)
    return net
def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
    """Build a ``ShuffleNetV2`` and optionally load the checkpoint registered for ``arch``.

    Args:
        arch: Key into ``model_urls`` used when ``pretrained`` is truthy.
        pretrained: If truthy, download and load the checkpoint weights.
        progress: Forwarded to ``load_state_dict_from_url`` as ``progress``.
        *args: Positional arguments forwarded to ``ShuffleNetV2``.
        **kwargs: Keyword arguments forwarded to ``ShuffleNetV2``.

    Returns:
        The constructed ``ShuffleNetV2`` instance.

    Raises:
        NotImplementedError: If ``pretrained`` is truthy but the
            ``model_urls`` entry for ``arch`` is ``None``.
    """
    net = ShuffleNetV2(*args, **kwargs)
    if not pretrained:
        return net

    checkpoint_url = model_urls[arch]
    if checkpoint_url is None:
        raise NotImplementedError(
            'pretrained {} is not supported as of now'.format(arch))

    weights = load_state_dict_from_url(checkpoint_url, progress=progress)
    net.load_state_dict(weights)
    return net
def alexnet(pretrained=False, progress=True, **kwargs):
    r"""AlexNet model architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    Args:
        pretrained (bool): If True, returns a model with the checkpoint from
            ``model_urls['alexnet']`` loaded.
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``AlexNet``.

    Returns:
        The constructed ``AlexNet`` instance.
    """
    net = AlexNet(**kwargs)
    if not pretrained:
        return net

    weights = load_state_dict_from_url(model_urls['alexnet'],
                                       progress=progress)
    net.load_state_dict(weights)
    return net
def mobilenet_v2(pretrained=False, progress=True, **kwargs):
    """
    Constructs a MobileNetV2 architecture from
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.

    Args:
        pretrained (bool): If True, returns a model with the checkpoint from
            ``model_urls['mobilenet_v2']`` loaded.
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``MobileNetV2``.

    Returns:
        The constructed ``MobileNetV2`` instance.
    """
    net = MobileNetV2(**kwargs)
    if not pretrained:
        return net

    weights = load_state_dict_from_url(model_urls['mobilenet_v2'],
                                       progress=progress)
    net.load_state_dict(weights)
    return net
def _load_state_dict(model, model_url, progress):
    """Download a checkpoint, rename legacy dense-layer keys, and load it into ``model``.

    Args:
        model: Object whose ``load_state_dict`` receives the renamed weights.
        model_url: URL passed to ``load_state_dict_from_url``.
        progress: Forwarded to ``load_state_dict_from_url`` as ``progress``.
    """
    # Module names may no longer contain '.', but the earlier _DenseLayer used
    # keys such as 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2',
    # and the checkpoints in model_urls still carry them. Group 1 captures
    # everything up to the layer kind, group 2 the index plus parameter name;
    # joining them drops the separating dot (e.g. 'norm.1.weight' ->
    # 'norm1.weight').
    legacy_key = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$'
    )

    weights = load_state_dict_from_url(model_url, progress=progress)

    # Iterate over a snapshot of the keys because the mapping is mutated.
    for old_name in list(weights.keys()):
        hit = legacy_key.match(old_name)
        if not hit:
            continue
        head, tail = hit.groups()
        weights[head + tail] = weights.pop(old_name)

    model.load_state_dict(weights)
def inception_v3(pretrained=False, progress=True, **kwargs):
    r"""Inception v3 model architecture from
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

    .. note::
        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Args:
        pretrained (bool): If True, returns a model with the checkpoint from
            ``model_urls['inception_v3_google']`` loaded.
        progress (bool): If True, displays a progress bar of the download to stderr
        aux_logits (bool): If True, keeps the auxiliary branch. When
            ``pretrained`` is True and this is not given, it is treated as
            True; otherwise it is passed through to ``Inception3`` untouched.
        transform_input (bool): Passed through to ``Inception3``. When
            ``pretrained`` is True and this is not given, it is set to True.

    Returns:
        The constructed ``Inception3`` instance.
    """
    if not pretrained:
        return Inception3(**kwargs)

    kwargs.setdefault('transform_input', True)

    # Remember what the caller asked for, then force the auxiliary branch on
    # for construction only if the caller passed the option explicitly.
    keep_aux_branch = True
    if 'aux_logits' in kwargs:
        keep_aux_branch = kwargs['aux_logits']
        kwargs['aux_logits'] = True

    # we are loading weights from a pretrained model
    kwargs['init_weights'] = False
    net = Inception3(**kwargs)

    weights = load_state_dict_from_url(model_urls['inception_v3_google'],
                                       progress=progress)
    net.load_state_dict(weights)

    if not keep_aux_branch:
        net.aux_logits = False
        del net.AuxLogits
    return net