Example #1
def mobilenet_v2(pretrained=False, checkpoints=None, progress=True, **kwargs):
    """
    Constructs a MobileNetV2 architecture from
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        checkpoints (str, optional): Path of a local checkpoint to load instead of downloading
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = MobileNetV2(**kwargs)
    if pretrained:
        if checkpoints is not None:
            model.load_state_dict(torch.load(checkpoints), True)
            return model
        state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
                                              progress=progress)
        model.load_state_dict(state_dict)
        return model
    return model
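The pattern above prefers a user-supplied local checkpoint and only downloads the published weights as a fallback. The sketch below is not part of the example; it reproduces the same logic with torchvision's stock MobileNetV2 and the weight URL that also appears later in Example #28, and the local checkpoint path is a placeholder.

# Minimal sketch of the "local checkpoint first, otherwise download" pattern.
# Assumes torchvision is installed; `checkpoints` is a placeholder path.
import torch
from torch.hub import load_state_dict_from_url
from torchvision.models import MobileNetV2

MOBILENET_V2_URL = 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth'

def load_mobilenet_v2(checkpoints=None, progress=True, **kwargs):
    model = MobileNetV2(**kwargs)
    if checkpoints is not None:
        model.load_state_dict(torch.load(checkpoints), strict=True)
    else:
        state_dict = load_state_dict_from_url(MOBILENET_V2_URL,
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model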
Example #2
def _vgg(arch,
         cfg,
         batch_norm,
         pretrained=False,
         checkpoints=None,
         progress=True,
         **kwargs):
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if pretrained:
        if checkpoints is not None:
            model.load_state_dict(torch.load(checkpoints), True)
            return model
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
        return model
    return model
Example #3
def vgg(depth, batch_norm, num_classes, pretrained):

    model = VGG(make_layers(cfgs[depth], batch_norm=batch_norm),
                num_classes,
                init_weights=True)
    arch = 'vgg' + str(depth)
    if batch_norm:
        arch += '_bn'

    if pretrained and (num_classes == 1000) and (arch
                                                 in pretrained_model_urls):
        state_dict = load_state_dict_from_url(pretrained_model_urls[arch],
                                              progress=True)
        model.load_state_dict(state_dict)
    elif pretrained:
        raise ValueError(
            'No pretrained vgg{} model is available with {} classes'.format(
                depth, num_classes))

    return model
Example #4
def dpn131(pretrained=False, test_time_pool=False, **kwargs):
    """Constructs a DPN-131 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet-1K
        test_time_pool (bool): If True, pools features for input resolution beyond
            standard 224x224 input with avg+max at inference/validation time
        **kwargs : Keyword args passed to model __init__
            num_classes (int): Number of classes for classifier linear layer, default=1000
    """
    model = DPN(num_init_features=128,
                k_r=160,
                groups=40,
                k_sec=(4, 8, 28, 3),
                inc_sec=(16, 32, 32, 128),
                test_time_pool=test_time_pool,
                **kwargs)
    if pretrained:
        model.load_state_dict(load_state_dict_from_url(model_urls['dpn131']))
    return model
Example #5
    def _build_vgg_loss(self, avg_pool, feature_norm, weights, device):
        self.content_losses, self.style_losses = {}, {}
        self.vgg_loss = nn.Sequential()
        vgg = models.vgg19(pretrained=False).features
        if weights in ('original', 'normalized'):
            state_dict = load_state_dict_from_url(
                'https://storage.googleapis'
                f'.com/prism-weights/vgg19-{weights}.pth')
        else:
            state_dict = torch.load(weights)
        vgg.load_state_dict(state_dict)
        vgg = vgg.eval()
        for param in vgg.parameters():
            param.requires_grad_(False)
        i_pool, i_conv = 1, 0
        for layer in vgg.children():
            if isinstance(layer, nn.Conv2d):
                i_conv += 1
                name = f'conv_{i_pool}_{i_conv}'
            elif isinstance(layer, nn.ReLU):
                name = f'relu_{i_pool}_{i_conv}'
                layer = nn.ReLU(inplace=False)
            elif isinstance(layer, nn.MaxPool2d):
                name = f'pool_{i_pool}'
                if avg_pool:
                    layer = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
                i_pool += 1
                i_conv = 0
            self.vgg_loss.add_module(name, layer)
            if name in self.content_weights:
                content_loss = ContentLoss('none')
                self.vgg_loss.add_module(f'content_loss_{i_pool}_{i_conv}',
                                         content_loss)
                self.content_losses[name] = content_loss
            if name in self.style_weights:
                style_loss = StyleLoss('none', feature_norm)
                self.vgg_loss.add_module(f'style_loss_{i_pool}_{i_conv}',
                                         style_loss)
                self.style_losses[name] = style_loss
            if (len(self.style_weights) == len(self.style_losses)
                    and len(self.content_weights) == len(self.content_losses)):
                break
        self.vgg_loss.to(device)
Example #6
File: mobilenet.py Project: baabp/SaTorch
def mobilenet_v2(num_classes,
                 in_channels=3,
                 pretrained=False,
                 progress=True,
                 **kwargs):
    """
    Constructs a MobileNetV2 architecture from
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # if pretrained and in_channels != 3:
    #     raise ValueError('ImageNet pretrained models only support 3 input channels, but got {}'.format(in_channels))

    if pretrained:
        model = MobileNetV2(**kwargs)
        state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
                                              progress=progress)
        model.load_state_dict(state_dict)
        conv0 = model.features[0][0]
        model.features[0][0] = nn.Conv2d(in_channels=in_channels,
                                         out_channels=conv0.out_channels,
                                         kernel_size=conv0.kernel_size,
                                         stride=conv0.stride,
                                         padding=conv0.padding,
                                         bias=conv0.bias is not None)
        if in_channels <= 3:
            model.features[0][0].weight.data = conv0.weight[:, :in_channels, :, :]
        else:
            # tile the pretrained 3-channel kernel, then pad with its leading channels
            multi = in_channels // 3
            last = in_channels % 3
            weight = torch.cat([conv0.weight for _ in range(multi)], dim=1)
            if last > 0:
                weight = torch.cat([weight, conv0.weight[:, :last, :, :]],
                                   dim=1)
            model.features[0][0].weight.data = weight
        model.classifier[1] = nn.Linear(model.classifier[1].in_features,
                                        num_classes)
    else:
        model = MobileNetV2(num_classes=num_classes,
                            in_channels=in_channels,
                            **kwargs)
    return model
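The else branch above tiles the pretrained 3-channel kernel along the input dimension and pads it with its leading channels. A standalone check (purely illustrative, shapes follow MobileNetV2's first conv) for in_channels=5:

# Illustrative check of the weight-tiling logic: 32 output channels, 3x3 kernel.
import torch

conv_weight = torch.randn(32, 3, 3, 3)            # stands in for conv0.weight
in_channels = 5
multi, last = in_channels // 3, in_channels % 3   # 1, 2

weight = torch.cat([conv_weight for _ in range(multi)], dim=1)
if last > 0:
    weight = torch.cat([weight, conv_weight[:, :last, :, :]], dim=1)

assert weight.shape == (32, in_channels, 3, 3)    # one full copy + 2 extra channels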
Example #7
def AlexNetDANN(pretrained=False, progress=True, **kwargs):
    r"""AlexNet model architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = DANN(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['alexnet'],
                                              progress=progress)
        
        # removing unused params
        state_dict.pop("classifier.6.bias")
        state_dict.pop("classifier.6.weight")
        model.load_state_dict(state_dict, strict=False)
        model.update_weigth()

    return model
Example #8
    def __init__(self, num_classes=1000, pretrained=False, **kwargs):

        super().__init__(block=models.resnet.BasicBlock,
                         layers=[2, 2, 2, 2],
                         num_classes=num_classes,
                         **kwargs)
        if pretrained:
            state_dict = load_state_dict_from_url(
                models.resnet.model_urls["resnet18"], progress=True)
            self.load_state_dict(state_dict)

        self.avgpool = nn.AvgPool2d((7, 7))
        self.last_conv = torch.nn.Conv2d(in_channels=self.fc.in_features,
                                         out_channels=num_classes,
                                         kernel_size=1)

        self.last_conv.weight.data.copy_(
            self.fc.weight.data.view(*self.fc.weight.data.shape, 1, 1))
        self.last_conv.bias.data.copy_(self.fc.bias.data)
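The last three lines turn the pretrained fully-connected head into an equivalent 1x1 convolution. A quick standalone sanity check (the 512/1000 dimensions are ResNet-18 defaults, used here only for illustration) shows both layers produce the same logits on a 1x1 feature map:

# Sketch: Linear(512, 1000) and Conv2d(512, 1000, 1) with copied weights agree
# on a 1x1 spatial input, which is why the conversion above preserves outputs.
import torch
import torch.nn as nn

fc = nn.Linear(512, 1000)
conv = nn.Conv2d(512, 1000, kernel_size=1)
conv.weight.data.copy_(fc.weight.data.view(1000, 512, 1, 1))
conv.bias.data.copy_(fc.bias.data)

x = torch.randn(2, 512, 1, 1)
assert torch.allclose(fc(x.flatten(1)), conv(x).flatten(1), atol=1e-5)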
Example #9
def _darknet(arch: str, pretrained: bool, progress: bool, *args: Any,
             **kwargs: Any) -> DarkNet:
    """
    Constructs a DarkNet architecture from
    # TODO

    """
    model = DarkNet(*args, **kwargs)

    if pretrained:
        model_url = model_urls[arch]
        if model_url is None:
            raise NotImplementedError(
                'pretrained {} is not supported as of now'.format(arch))
        else:
            state_dict = load_state_dict_from_url(model_url, progress=progress)
            model.load_state_dict(state_dict)

    return model
Example #10
def _get_model_by_name(model_name, classes=1000, pretrained=False):
    block_args_list, global_params = get_efficientnet_params(model_name, override_params={'num_classes': classes})
    model = EfficientNet(block_args_list, global_params)
    try:
        if pretrained:
            pretrained_state_dict = load_state_dict_from_url(IMAGENET_WEIGHTS[model_name])

            if classes != 1000:
                random_state_dict = model.state_dict()
                pretrained_state_dict['_fc.weight'] = random_state_dict['_fc.weight']
                pretrained_state_dict['_fc.bias'] = random_state_dict['_fc.bias']

            model.load_state_dict(pretrained_state_dict)

    except KeyError as e:
        print(f"NOTE: Currently model {e} doesn't have pretrained weights, therefore a model with randomly initialized"
              " weights is returned.")

    return model
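The same "keep the pretrained backbone, take the classification head from the freshly initialized model" trick is architecture-agnostic. A hedged sketch with torchvision's ResNet-18 and 10 classes, assuming a torchvision version that still exposes models.resnet.model_urls (as Example #8 does):

# Sketch of the head-reset pattern; not taken from the EfficientNet project above.
from torch.hub import load_state_dict_from_url
from torchvision import models

model = models.resnet18(num_classes=10)
state_dict = load_state_dict_from_url(models.resnet.model_urls['resnet18'],
                                      progress=True)
# the downloaded fc has 1000 outputs; keep the randomly initialized 10-way head
state_dict['fc.weight'] = model.state_dict()['fc.weight']
state_dict['fc.bias'] = model.state_dict()['fc.bias']
model.load_state_dict(state_dict)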
Example #11
def ResNextFPN50(is_pretrained=True, use_se=False):
    fpn = ResFPN(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4, dilation=1, use_se=use_se)

    if is_pretrained is False:
        for m in fpn.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, mean=0, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    else:
        print('Loading ResNeXt50 model pretrained on ImageNet...')
        state_dict = load_state_dict_from_url(model_urls['resnext50_32x4d'], progress=True)
        del state_dict['fc.weight']
        del state_dict['fc.bias']
        missing_keys = fpn.load_state_dict(state_dict, strict=False)
        print(missing_keys)

    return fpn
Example #12
def _efficientdet(arch, pretrained=None, **kwargs):
    cfgs = deepcopy(CFGS)
    cfg_settings = cfgs[arch]["default"]
    cfg_params = cfg_settings.pop("params")
    kwargs.update(cfg_params)
    model = EfficientDet(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(cfgs[arch][pretrained]["url"])
        kwargs_cls = kwargs.get("num_classes", None)
        if kwargs_cls and kwargs_cls != cfg_settings["num_classes"]:
            logging.warning(
                f"Using model pretrained for {cfg_settings['num_classes']} classes with {kwargs_cls} classes. Last layer is initialized randomly"
            )
            state_dict["cls_head_conv.1.weight"] = model.state_dict()[f"cls_head_conv.1.weight"]
            state_dict["cls_head_conv.1.bias"] = model.state_dict()["cls_head_conv.1.bias"]
        # strict=False to avoid error on extra bias in BiFPN
        model.load_state_dict(state_dict, strict=False)
    setattr(model, "pretrained_settings", cfg_settings)
    return model
Example #13
def _load_state_dict(model, model_url, progress):
    # '.'s are no longer allowed in module names, but previous _DenseLayer
    # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
    # They are also in the checkpoints in model_urls. This pattern is used
    # to find such keys.
    pattern = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$'
    )

    state_dict = load_state_dict_from_url(model_url, progress=progress)
    # print(state_dict)
    # exit(-1)
    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict)
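The regular expression only renames the legacy DenseNet keys described in the comment by dropping the extra dot before the layer index. On a single (illustrative) key:

# Illustration of the key rewrite performed above on one legacy DenseNet key.
import re

pattern = re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$'
)

old_key = 'features.denseblock1.denselayer1.norm.1.weight'
res = pattern.match(old_key)
new_key = res.group(1) + res.group(2)
assert new_key == 'features.denseblock1.denselayer1.norm1.weight'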
Example #14
def inception_v3(pretrained=False, checkpoints=None, progress=True, **kwargs):
    r"""Inception v3 model architecture from
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

    .. note::
        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        checkpoints (str, optional): Path of a local checkpoint to load instead of downloading
        progress (bool): If True, displays a progress bar of the download to stderr
        aux_logits (bool): If True, add an auxiliary branch that can improve training.
            Default: *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False*
    """
    if pretrained:
        if 'transform_input' not in kwargs:
            kwargs['transform_input'] = True
        if 'aux_logits' in kwargs:
            original_aux_logits = kwargs['aux_logits']
            kwargs['aux_logits'] = True
        else:
            original_aux_logits = True
        # we are loading weights from a pretrained model
        kwargs['init_weights'] = False
        model = Inception3(**kwargs)
        if checkpoints is not None:
            model.load_state_dict(torch.load(checkpoints), True)
            if not original_aux_logits:
                model.aux_logits = False
                del model.AuxLogits
            return model
        state_dict = load_state_dict_from_url(
            model_urls['inception_v3_google'], progress=progress)
        model.load_state_dict(state_dict)
        if not original_aux_logits:
            model.aux_logits = False
            del model.AuxLogits
        return model

    return Inception3(**kwargs)
Example #15
    def load_pretrained_weights(self):

        vgg16_weights = hub.load_state_dict_from_url(
            "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth")

        count_vgg = 0
        count_this = 0

        vggkeys = list(vgg16_weights.keys())
        thiskeys = list(self.state_dict().keys())

        corresp_map = []

        while True:
            vggkey = vggkeys[count_vgg]
            thiskey = thiskeys[count_this]

            if "classifier" in vggkey:
                break

            while vggkey.split(".")[-1] not in thiskey:
                count_this += 1
                thiskey = thiskeys[count_this]

            corresp_map.append([vggkey, thiskey])
            count_vgg += 1
            count_this += 1

        mapped_weights = self.state_dict()
        for k_vgg, k_segnet in corresp_map:
            if (self.in_channels !=
                    3) and "features" in k_vgg and "conv1_1." not in k_segnet:
                mapped_weights[k_segnet] = vgg16_weights[k_vgg]
            elif (self.in_channels == 3) and "features" in k_vgg:
                mapped_weights[k_segnet] = vgg16_weights[k_vgg]

        try:
            self.load_state_dict(mapped_weights)
            print("Loaded VGG-16 weights in Segnet !")
        except:
            print("Error VGG-16 weights in Segnet !")
            raise
Example #16
def fetch_model(name, map_location=None, **kwargs):
    """Fetch model from URL.

    Loads model or state dict from URL.

    Args:
        name: Model name hosted on `celldetection.org` or url. Urls must start with 'http'.
        map_location: A function, `torch.device`, string or a dict specifying how to remap storage locations.
        **kwargs: Keyword arguments forwarded to `torch.hub.load_state_dict_from_url`; see its documentation.

    """
    url = name if name.startswith(
        'http') else f'https://celldetection.org/torch/models/{name}.pt'
    m = load_state_dict_from_url(url, map_location=map_location, **kwargs)
    if isinstance(m, dict) and 'cd.models' in m.keys():
        from .. import models
        conf = m['cd.models']
        m = getattr(models, conf['model'])(*conf['a'], **conf['kw'])
        m.load_state_dict(conf['state_dict'])
    return m
Example #17
def create(arch, image_size, channels, pretrained, progress):
    """ Creates a specified GAN model

    Args:
        arch (str): Arch name of model.
        image_size (int): Number of image size.
        channels (int): Number of input channels.
        pretrained (bool): Load pretrained weights into the model.
        progress (bool): Show progress bar when downloading weights.

    Returns:
        PyTorch model.
    """
    model = Generator(image_size, channels)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress,
                                              map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
    return model
Example #18
def _efficient_x3d(
    pretrained: bool = False,
    progress: bool = True,
    checkpoint_path: str = None,
    # Model params
    expansion: str = "XS",
    **kwargs: Any,
) -> nn.Module:

    model = create_x3d(
        expansion=expansion,
        **kwargs,
    )

    if pretrained and checkpoint_path is not None:
        state_dict = load_state_dict_from_url(checkpoint_path,
                                              progress=progress)
        model.load_state_dict(state_dict, strict=True)

    return model
Example #19
File: btresnet.py Project: riverfjs/recode
def _resnext(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
    # new_state_dict = {}
    new_state_dict = model.state_dict()
    print("11111", len(state_dict))
    for k, v in state_dict.items():
        # print(k)
        # do not load the fully-connected layer or layer4 weights
        if k in new_state_dict and not (k.startswith('fc')
                                        or k.startswith('layer4')):
            new_state_dict[k] = v
    #     # print("{}: {}".format(k,v.requires_grad))
    #   # elif k.startswith('fc') or k.startswith('layer4'):
    #   #   print("yes, ",k)
    # new_state_dict = model.state_dict()
    # new_state_dict.update(state_dict)
    print("22222", len(new_state_dict))
    model.load_state_dict(new_state_dict)
    return model
Example #20
def resnext(depth, num_classes, pretrained):

    model = _resnext(mode=cfgs[depth][0],
                     block=cfgs[depth][1],
                     cardinality=cfgs[depth][2],
                     layers=cfgs[depth][3],
                     num_classes=num_classes)
    arch = 'resnext' + str(depth)

    if pretrained and (num_classes == 1000) and (arch
                                                 in pretrained_model_urls):
        state_dict = load_state_dict_from_url(pretrained_model_urls[arch],
                                              progress=True)
        model.load_state_dict(state_dict)
    elif pretrained:
        raise ValueError(
            'No pretrained resnext{} model is available with {} classes'.format(
                depth, num_classes))

    return model
Example #21
def alexnetDANN(pretrained=False, progress=True, **kwargs):
    """
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = DANNModel(num_classes=1000, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['alexnet'],
                                              progress=progress)
        model.load_state_dict(state_dict, strict=False)

    # we copy the pretrained weights of the classifier to the domain classifier
    model.discriminator[1].weight.data = model.classifier[1].weight.data.clone()
    model.discriminator[1].bias.data = model.classifier[1].bias.data.clone()

    model.discriminator[4].weight.data = model.classifier[4].weight.data.clone()
    model.discriminator[4].bias.data = model.classifier[4].bias.data.clone()

    return model
Example #22
def dann_net(pretrained=False, progress=True, **kwargs):
    r"""AlexNet model architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    net = DannNet(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['alexnet'],
                                              progress=progress)
        net.load_state_dict(state_dict, strict=False)
        net.dann_classifier[1].weight.data = net.classifier[
            1].weight.data.clone()
        net.dann_classifier[1].bias.data = net.classifier[1].bias.data.clone()
        net.dann_classifier[4].weight.data = net.classifier[
            4].weight.data.clone()
        net.dann_classifier[4].bias.data = net.classifier[4].bias.data.clone()

    return net
Example #23
def _vgg(architecture, config, batch_norm, pretrained, progress, **kwargs):
    if pretrained:
        kwargs['init_weights'] = False

    model = VGG(make_layers(cfgs[config], batch_norm=batch_norm), **kwargs)

    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[architecture],
                                              progress=progress)

        del state_dict["classifier.0.weight"]
        del state_dict["classifier.0.bias"]
        del state_dict["classifier.3.weight"]
        del state_dict["classifier.3.bias"]
        del state_dict["classifier.6.weight"]
        del state_dict["classifier.6.bias"]

        model.load_state_dict(state_dict)

    return model
Example #24
def _hrnet(pretrained, checkpoints, progress, **kwargs):
    update_config(config, model_cfg[kwargs['model_name']])
    model = get_cls_net(config)
    if pretrained:
        if checkpoints is not None:
            model.load_state_dict(torch.load(checkpoints), True)
            return model
        state_dict = load_state_dict_from_url(model_cfg[kwargs['model_name'] +
                                                        '_url'])
        model.load_state_dict(state_dict)
        return model
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight,
                                    mode='fan_out',
                                    nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
    return model
Example #25
def _resnet(arch,
            block,
            layers,
            pretrained,
            progress,
            use_cbam=False,
            use_mixpool=False,
            **kwargs):
    model = ResNet(block,
                   layers,
                   use_cbam=use_cbam,
                   use_mixpool=use_mixpool,
                   **kwargs)
    if pretrained and use_cbam:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        new_state_dict = model.state_dict()
        new_state_dict.update(state_dict)
        model.load_state_dict(new_state_dict)
    return model
Example #26
def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    custom_class_num: bool = False,
    **kwargs: Any
) -> ResNet:

    model = ResNet(block, layers, custom_class_num=custom_class_num, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    
    if custom_class_num:
        return ResNetFlex(model)
    
    return model
Example #27
def mobilenet_v2(pretrained=False, progress=True, **kwargs):
    """
    Constructs a MobileNetV2 architecture from
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = MobileNetV2(**kwargs)
    if pretrained:
        org_state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
                                                  progress=progress)

        state_dict = {}
        for k1, k2 in zip(org_state_dict.keys(), model.state_dict().keys()):
            state_dict[k2] = org_state_dict[k1]
        model.load_state_dict(state_dict)

    return model
Example #28
def mobilenet_v2(pretrained=False, progress=True, **kwargs):
    """
    Constructs a MobileNetV2 architecture from
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = MobileNetV2(**kwargs)

    if pretrained:
        try:
            from torch.hub import load_state_dict_from_url
        except ImportError:
            from torch.utils.model_zoo import load_url as load_state_dict_from_url
        state_dict = load_state_dict_from_url(
            'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
            progress=True)
        model.load_state_dict(state_dict)
    return model
Example #29
def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
    if pretrained:
        kwargs["init_weights"] = False
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)

        snn_state_dict = {}

        # compared to the ANN parameters, the modules are lifted
        # we modify the state dict accordingly here
        for key in state_dict:
            l = key.split(".")
            l.insert(-1, "lifted_module")
            new_key = ".".join(l)
            snn_state_dict[new_key] = state_dict[key]

        model.load_state_dict(snn_state_dict)
    return model
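The state-dict surgery above only inserts a 'lifted_module' segment in front of the parameter name (the wrapped-module path is specific to the SNN model in this example, not a torchvision name). On a single VGG key:

# Illustration of the key lifting performed above for one VGG parameter key.
key = 'features.0.weight'
parts = key.split('.')
parts.insert(-1, 'lifted_module')
assert '.'.join(parts) == 'features.0.lifted_module.weight'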
Example #30
def _cvt(arch, pretrained, progress,
         num_layers, num_heads, mlp_ratio, embedding_dim,
         kernel_size=4, positional_embedding='learnable',
         *args, **kwargs):
    model = CVT(num_layers=num_layers,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                embedding_dim=embedding_dim,
                kernel_size=kernel_size,
                *args, **kwargs)

    if pretrained and arch in model_urls:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        if positional_embedding == 'learnable':
            state_dict = pe_check(model, state_dict)
        elif positional_embedding == 'sine':
            state_dict['classifier.positional_emb'] = model.state_dict()['classifier.positional_emb']
        model.load_state_dict(state_dict)
    return model