Exemplo n.º 1
0
def _load_pretrained(model_name, model, progress):
    """Download the checkpoint registered for ``model_name`` and load it.

    Args:
        model_name: Key looked up in the module-level ``_MODEL_URLS`` table.
        model: Object exposing ``load_state_dict``; it is updated in place.
        progress: Forwarded to ``load_state_dict_from_url`` as ``progress``.

    Raises:
        ValueError: If ``model_name`` is absent from ``_MODEL_URLS`` or its
            entry is ``None``.
    """
    # A missing key and an explicit ``None`` entry are treated the same way.
    checkpoint_url = _MODEL_URLS.get(model_name)
    if checkpoint_url is None:
        message = "No checkpoint is available for model type {}".format(
            model_name)
        raise ValueError(message)

    weights = load_state_dict_from_url(checkpoint_url, progress=progress)
    model.load_state_dict(weights)
Exemplo n.º 2
0
 def _initialize(self):
     """Populate the network weights.

     Only the pretrained path exists: the ``'alexnet'`` checkpoint listed in
     ``model_urls`` is fetched (progress bar disabled) and loaded into this
     module.

     Raises:
         NotImplementedError: If ``self.pretrained`` is falsy, because no
             from-scratch initialisation routine has been written.
     """
     if not self.pretrained:
         raise NotImplementedError('Need routine for init from scratch.')

     weights = load_state_dict_from_url(model_urls['alexnet'],
                                        progress=False)
     self.load_state_dict(weights)
Exemplo n.º 3
0
    def _initialize(self):
        """Initialise the network weights.

        When ``self.pretrained`` is truthy, the checkpoint registered for
        ``self.arch`` in ``model_urls`` is downloaded (progress bar disabled)
        and loaded. Otherwise the weights are initialised from scratch:

        * ``nn.Conv2d``: Kaiming-normal with ``mode='fan_out'`` and
          ``nonlinearity='relu'``.
        * ``nn.BatchNorm2d`` / ``nn.GroupNorm``: weight 1, bias 0, and, where
          the layer actually has them, running variance 1 and running mean 0.
        * If ``self.zero_init_residual`` is truthy, the last norm layer of
          every ``Bottleneck`` (``bn3``) / ``BasicBlock`` (``bn2``) gets a
          zero weight.
        """
        if self.pretrained:
            state_dict = load_state_dict_from_url(model_urls[self.arch],
                                                  progress=False)
            self.load_state_dict(state_dict)
            return

        # (attribute, fill value) pairs applied to every normalisation layer.
        norm_defaults = (
            ('weight', 1),
            ('bias', 0),
            ('running_var', 1),
            ('running_mean', 0),
        )
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # GroupNorm keeps no running statistics, and BatchNorm2d
                # stores None for them when track_running_stats=False (as
                # it does for weight/bias when affine=False). Only fill the
                # tensors that exist instead of raising.
                for attr_name, fill_value in norm_defaults:
                    tensor = getattr(m, attr_name, None)
                    if tensor is not None:
                        nn.init.constant_(tensor, fill_value)

        # Zero-initialize the last BN in each residual branch, so that the
        # residual branch starts with zeros and each residual block behaves
        # like an identity. This improves the model by 0.2~0.3% according to
        # https://arxiv.org/abs/1706.02677
        if self.zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
Exemplo n.º 4
0
def inception_v3(siamese, version, pretrained=False, progress=True, **kwargs):
    r"""Inception v3 model architecture from
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

    .. note::
        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Args:
        siamese: Forwarded unchanged to ``Inception3`` as ``siamese``.
        version: Forwarded unchanged to ``Inception3`` as ``version``.
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        aux_logits (bool): If True, add an auxiliary branch that can improve training.
            Default: *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False*

    Returns:
        The constructed ``Inception3`` instance, with the
        ``'inception_v3_google'`` checkpoint loaded when ``pretrained`` is
        truthy.
    """
    # The third parameter used to be misspelled ``retrained`` while the body
    # read ``pretrained`` (a NameError on every call). Keep accepting the old
    # spelling as a keyword so no caller signature breaks.
    if 'retrained' in kwargs:
        pretrained = kwargs.pop('retrained')

    if not pretrained:
        return Inception3(siamese=siamese, version=version, **kwargs)

    kwargs.setdefault('transform_input', True)

    # The checkpoint contains the auxiliary classifier, so the model must be
    # built with it for load_state_dict to succeed; remember what the caller
    # asked for and strip the branch afterwards if it was not wanted.
    if 'aux_logits' in kwargs:
        original_aux_logits = kwargs['aux_logits']
        kwargs['aux_logits'] = True
    else:
        original_aux_logits = True

    model = Inception3(siamese=siamese, version=version, **kwargs)
    state_dict = load_state_dict_from_url(
        model_urls['inception_v3_google'], progress=progress)
    model.load_state_dict(state_dict)

    if not original_aux_logits:
        model.aux_logits = False
        del model.AuxLogits
    return model
Exemplo n.º 5
0
def _resnet(arch, block, layers, pretrained, progress, siamese, **kwargs):
    """Build a ``ResNet`` and, if requested, load its checkpoint.

    Args:
        arch: Architecture name; passed to ``ResNet`` and used as the key
            into ``model_urls`` when ``pretrained`` is truthy.
        block: Passed to ``ResNet`` as ``block``.
        layers: Passed to ``ResNet`` as ``layers``.
        pretrained: Passed to ``ResNet``; when truthy the checkpoint for
            ``arch`` is also downloaded and loaded here.
        progress: Forwarded to ``load_state_dict_from_url`` as ``progress``.
        siamese: Passed to ``ResNet`` as ``siamese``.
        **kwargs: Extra keyword arguments for ``ResNet``.

    Returns:
        The constructed ``ResNet`` instance.
    """
    net = ResNet(arch=arch,
                 pretrained=pretrained,
                 block=block,
                 layers=layers,
                 siamese=siamese,
                 **kwargs)
    if not pretrained:
        return net

    # The checkpoint is loaded as downloaded; no filtering of batch-norm
    # statistics is applied.
    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    net.load_state_dict(weights)
    return net
Exemplo n.º 6
0
 def _initialize(self):
     """Set up the network weights.

     With ``self.pretrained`` truthy, the checkpoint registered for
     ``self.arch`` in ``model_urls`` is downloaded (no progress bar) and
     loaded. Otherwise every sub-module is visited and initialised:
     Kaiming-normal weights for ``nn.Conv2d``, weight 1 / bias 0 for
     ``nn.BatchNorm2d``, and bias 0 for ``nn.Linear``.
     """
     if self.pretrained:
         weights = load_state_dict_from_url(model_urls[self.arch],
                                            progress=False)
         self.load_state_dict(weights)
         return

     for layer in self.modules():
         if isinstance(layer, nn.Conv2d):
             nn.init.kaiming_normal_(layer.weight)
             continue
         if isinstance(layer, nn.BatchNorm2d):
             nn.init.constant_(layer.weight, 1)
             nn.init.constant_(layer.bias, 0)
             continue
         if isinstance(layer, nn.Linear):
             # Only the bias is touched; the Linear weight keeps whatever
             # value it had on construction.
             nn.init.constant_(layer.bias, 0)
0
def alexnet(siamese, version, pretrained=False, progress=True, **kwargs):
    r"""Build the AlexNet architecture described in the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    Args:
        siamese: Forwarded unchanged to ``AlexNet`` as ``siamese``.
        version: Forwarded unchanged to ``AlexNet`` as ``version``.
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments for ``AlexNet``.

    Returns:
        The constructed ``AlexNet`` instance.
    """
    net = AlexNet(siamese=siamese, version=version, **kwargs)
    if not pretrained:
        return net

    weights = load_state_dict_from_url(model_urls['alexnet'],
                                       progress=progress)
    net.load_state_dict(weights)
    return net
Exemplo n.º 8
0
def _vgg(arch, cfg, batch_norm, pretrained, progress, siamese, version,
         **kwargs):
    """Build a ``VGG`` model and, if requested, load its checkpoint.

    Args:
        arch: Architecture name; passed to ``VGG`` and used as the key into
            ``model_urls`` when ``pretrained`` is truthy.
        cfg: Key into the module-level ``cfgs`` table selecting the layer
            configuration handed to ``make_layers``.
        batch_norm: Forwarded to ``make_layers`` as ``batch_norm``.
        pretrained: Passed to ``VGG``; when truthy, ``init_weights`` is
            forced to ``False`` and the checkpoint is loaded.
        progress: Forwarded to ``load_state_dict_from_url`` as ``progress``.
        siamese: Passed to ``VGG`` as ``siamese``.
        version: Passed to ``VGG`` as ``version``.
        **kwargs: Extra keyword arguments for ``VGG``.

    Returns:
        The constructed ``VGG`` instance.
    """
    if pretrained:
        # Overrides any caller-supplied value on the pretrained path.
        kwargs['init_weights'] = False

    features = make_layers(cfgs[cfg], batch_norm=batch_norm)
    net = VGG(features,
              siamese=siamese,
              version=version,
              arch=arch,
              pretrained=pretrained,
              **kwargs)
    if not pretrained:
        return net

    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    net.load_state_dict(weights)
    return net
Exemplo n.º 9
0
def _load_state_dict(model, model_url, progress):
    """Download a checkpoint, rename its legacy keys, and load it.

    Module names may no longer contain '.', but earlier ``_DenseLayer``
    checkpoints (including those referenced by ``model_urls``) use keys such
    as 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. Each such
    key is rewritten with the separating dot removed (e.g. ``norm.1.weight``
    becomes ``norm1.weight``) before the state dict is loaded.

    Args:
        model: Object exposing ``load_state_dict``; updated in place.
        model_url: URL of the checkpoint to download.
        progress: Forwarded to ``load_state_dict_from_url`` as ``progress``.
    """
    legacy_key = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

    state_dict = load_state_dict_from_url(model_url, progress=progress)

    # Iterate over a snapshot of the keys because entries are renamed
    # while walking the mapping.
    for old_key in tuple(state_dict):
        matched = legacy_key.match(old_key)
        if matched is None:
            continue
        prefix, suffix = matched.groups()
        state_dict[prefix + suffix] = state_dict.pop(old_key)

    model.load_state_dict(state_dict)
Exemplo n.º 10
0
 def _initialize(self):
     """Initialise the network weights.

     When ``self.pretrained`` is truthy, the checkpoint registered for
     ``self.arch`` in ``model_urls`` is downloaded (progress bar disabled)
     and loaded. Otherwise the weights are initialised from scratch:

     * ``nn.Conv2d``: Kaiming-normal (``mode="fan_out"``,
       ``nonlinearity="relu"``) weights, zero bias when a bias exists.
     * ``nn.BatchNorm2d``: weight 1, bias 0.
     * ``nn.Linear``: weights drawn from N(0, 0.01**2), zero bias when a
       bias exists.
     """
     if self.pretrained:
         state_dict = load_state_dict_from_url(model_urls[self.arch],
                                               progress=False)
         self.load_state_dict(state_dict)
         return

     for m in self.modules():
         if isinstance(m, nn.Conv2d):
             nn.init.kaiming_normal_(m.weight,
                                     mode="fan_out",
                                     nonlinearity="relu")
             if m.bias is not None:
                 nn.init.zeros_(m.bias)
         elif isinstance(m, nn.BatchNorm2d):
             nn.init.ones_(m.weight)
             nn.init.zeros_(m.bias)
         elif isinstance(m, nn.Linear):
             # normal_ takes (tensor, mean, std). The previous positional
             # call passed 0.01 as the mean and left std at its default of
             # 1.0; the intended initialisation is mean 0, std 0.01.
             nn.init.normal_(m.weight, mean=0.0, std=0.01)
             if m.bias is not None:
                 nn.init.zeros_(m.bias)