Esempio n. 1
0
    def __init__(self,
                 subtype='wide_resnet50_2',
                 out_stages=[2, 3, 4],
                 backbone_path=None):
        """Build a Wide-ResNet backbone split into a stem and four stages.

        Args:
            subtype: ``'wide_resnet50_2'`` or ``'wide_resnet101_2'``.
            out_stages: Stage indices whose channel counts are kept in
                ``self.out_channels`` (index 0 is the stem, 1-4 are
                layer1-layer4). The sequence is copied, so the shared default
                list is never aliased by an instance.
            backbone_path: Optional path to a state dict for the torchvision
                model. When given, the model is built with
                ``pretrained=False`` and this file is loaded instead;
                otherwise ``pretrained=True`` is used and ``init_weights()``
                is called.

        Raises:
            NotImplementedError: If ``subtype`` is not supported.
        """
        super(WideResNet, self).__init__()
        # Copy so instances never share (or mutate) the default list object.
        self.out_stages = list(out_stages)
        self.backbone_path = backbone_path

        if subtype == 'wide_resnet50_2':
            # This branch previously built resnext50_32x4d (copy/paste slip
            # from the ResNeXt backbone), so a "WideResNet" was a ResNeXt.
            from torchvision.models.resnet import wide_resnet50_2
            backbone = wide_resnet50_2(pretrained=not self.backbone_path)
            self.out_channels = [64, 256, 512, 1024, 2048]
        elif subtype == 'wide_resnet101_2':
            from torchvision.models.resnet import wide_resnet101_2
            backbone = wide_resnet101_2(pretrained=not self.backbone_path)
            self.out_channels = [64, 256, 512, 1024, 2048]
        else:
            raise NotImplementedError

        # Keep only the channel counts of the requested contiguous stage range.
        self.out_channels = self.out_channels[self.out_stages[0]:
                                              self.out_stages[-1] + 1]

        children = list(backbone.children())
        self.conv1 = nn.Sequential(*children[0:3])
        self.maxpool = nn.Sequential(children[3])
        self.layer1 = nn.Sequential(children[4])
        self.layer2 = nn.Sequential(children[5])
        self.layer3 = nn.Sequential(children[6])
        self.layer4 = nn.Sequential(children[7])

        if self.backbone_path:
            # There is no ``self.backbone`` attribute (the old code raised
            # AttributeError here). Load into the local model instead: the
            # Sequentials above wrap the very same child module objects, so
            # the loaded weights reach conv1/maxpool/layer1-4.
            backbone.load_state_dict(
                torch.load(self.backbone_path, map_location='cpu'))
        else:
            # NOTE(review): the backbone was built with pretrained=True in
            # this case — confirm init_weights() does not overwrite those
            # weights.
            self.init_weights()
Esempio n. 2
0
    def __init__(self,
                 arch="resnet18",
                 num_classes=1000,
                 torchvision_pretrained=False,
                 pretrained_num_classes=1000,
                 fix_bn=False,
                 partial_bn=False,
                 zero_init_residual=False):
        """Wrap a torchvision ResNet/ResNeXt classifier.

        Args:
            arch: One of ``'resnet18'``, ``'resnet50'``, ``'resnext50_32x4d'``.
            num_classes: Target class count, forwarded to ``init_weights``.
            torchvision_pretrained: Forwarded as ``pretrained`` to the
                torchvision constructor.
            pretrained_num_classes: Class count the wrapped model is built
                with (``num_classes`` of the constructor).
            fix_bn: Stored on the instance for use by other methods.
            partial_bn: Stored on the instance for use by other methods.
            zero_init_residual: Forwarded to the torchvision constructor.

        Raises:
            ValueError: If ``arch`` is not one of the supported names.
        """
        super(TorchvisionResNet, self).__init__()

        self.num_classes = num_classes
        self.fix_bn = fix_bn
        self.partial_bn = partial_bn

        # Resolve the constructor first; all three share the same kwargs.
        if arch == 'resnet18':
            build = resnet18
        elif arch == 'resnet50':
            build = resnet50
        elif arch == 'resnext50_32x4d':
            build = resnext50_32x4d
        else:
            raise ValueError('no such value')

        self.model = build(pretrained=torchvision_pretrained,
                           num_classes=pretrained_num_classes,
                           zero_init_residual=zero_init_residual)

        # init_weights is defined elsewhere on the class.
        self.init_weights(num_classes, pretrained_num_classes)
Esempio n. 3
0
def get_model(weight):
    """Build a 54-class ResNeXt on CUDA and load weights from ``weight``.

    A path containing ``"101"`` selects ``resnext101_32x8d``, otherwise
    ``resnext50_32x4d``.

    If the checkpoint has a ``"state_dict"`` entry (a training checkpoint),
    its keys are remapped by stripping every ``"model."`` wrapper prefix and
    renaming ``last_linear.`` to ``fc.``. The result is loaded non-strictly,
    and the converted weights are re-saved to a fixed path under
    ``../ilya/assets/``. Otherwise the checkpoint is loaded strictly as a
    plain state dict.

    Args:
        weight: Path (str or path-like) to the checkpoint file.

    Returns:
        The model in eval mode.
    """
    weight = str(weight)
    is_101 = "101" in weight
    if is_101:
        model = resnext101_32x8d(num_classes=54)
    else:
        model = resnext50_32x4d(num_classes=54)
    model = model.cuda()

    print(weight)
    checkpoint = torch.load(weight, map_location="cpu")
    if "state_dict" in checkpoint:
        sanitized = {}
        for key, value in checkpoint["state_dict"].items():
            # Produce exactly one key per entry. The old code wrote two keys
            # per entry (most of them junk hidden by strict=False) and never
            # produced "fc.*" for nested "model.model.last_linear.*" keys, so
            # the classifier head was silently left uninitialised.
            new_key = key.replace("model.", "").replace("last_linear.", "fc.")
            sanitized[new_key] = value

        result = model.load_state_dict(sanitized, strict=False)
        # strict=False is kept, but report what did not load instead of
        # hiding it.
        if result.missing_keys:
            print("missing keys after remapping:", result.missing_keys)

        if is_101:
            save_name = "../ilya/assets/resnext101_w1_epoch4.pth"
        else:
            save_name = "../ilya/assets/resnext50_w8_epoch0.pth"
        torch.save(model.state_dict(), save_name)
    else:
        model.load_state_dict(checkpoint)

    del checkpoint
    model.eval()
    return model
Esempio n. 4
0
 def __init__(self, mode_name='wide_resnet50_2'):
     """Wrap a ``resnet`` module constructor as a five-stage backbone.

     Args:
         mode_name: Name of a constructor in the ``resnet`` module; one of
             resnet18/34/50/101/152, resnext50_32x4d, resnext101_32x8d,
             wide_resnet50_2, wide_resnet101_2.

     Raises:
         AssertionError: If ``mode_name`` is not supported.
     """
     super(Backbone, self).__init__()
     supported = ('resnet18', 'resnet34', 'resnet50', 'resnet101',
                  'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
                  'wide_resnet50_2', 'wide_resnet101_2')
     # Explicit check rather than ``assert``: asserts are stripped under
     # ``python -O``, and the old if/elif chain then silently built
     # wide_resnet101_2 for any unknown name. AssertionError is kept so
     # callers see the same exception type as before.
     if mode_name not in supported:
         raise AssertionError(
             'unsupported mode_name: {!r}'.format(mode_name))
     # Each supported name is also the constructor's name in ``resnet``.
     self.res_back_bone = getattr(resnet, mode_name)()
     self.backbone = nn.Module()
     # layer0 = the first four children (presumably conv1/bn1/relu/maxpool,
     # as in torchvision ResNet — confirm for this ``resnet`` module).
     layer0 = nn.Sequential(*list(self.res_back_bone.children())[:4])
     self.backbone.add_module('layer0', layer0)
     self.backbone.add_module('layer1', self.res_back_bone.layer1)
     self.backbone.add_module('layer2', self.res_back_bone.layer2)
     self.backbone.add_module('layer3', self.res_back_bone.layer3)
     self.backbone.add_module('layer4', self.res_back_bone.layer4)
Esempio n. 5
0
    def __init__(self,
                 subtype='resnext50_32x4d',
                 out_stages=[2, 3, 4],
                 output_stride=32,
                 backbone_path=None):
        """Build a ResNeXt backbone, optionally dilated for dense prediction.

        Args:
            subtype: ``'resnext50_32x4d'`` or ``'resnext101_32x8d'``.
            out_stages: Stage indices (0 = stem, 1-4 = layer1-layer4) whose
                channel counts are kept in ``self.out_channels``. The sequence
                is copied, so the shared default list is never aliased.
            output_stride: 32 leaves the strides unchanged. 16 sets layer4 to
                stride 1 / dilation 2. 8 sets layer3 to stride 1 / dilation 2
                and layer4 to stride 1 / dilation 4.
            backbone_path: Optional state-dict path. When given, the model is
                built with ``pretrained=False`` and this file is loaded;
                otherwise ``pretrained=True`` is used and ``init_weights()``
                is called.

        Raises:
            NotImplementedError: If ``subtype`` is not supported.
            ValueError: If ``output_stride`` is not 8, 16 or 32 (previously
                such values were silently ignored).
        """
        super(ResNeXt, self).__init__()
        # Copy so instances never share (or mutate) the default list object.
        self.out_stages = list(out_stages)
        self.output_stride = output_stride  # 8, 16, 32
        self.backbone_path = backbone_path

        if subtype == 'resnext50_32x4d':
            backbone = resnext50_32x4d(pretrained=not self.backbone_path)
            self.out_channels = [64, 256, 512, 1024, 2048]
        elif subtype == 'resnext101_32x8d':
            backbone = resnext101_32x8d(pretrained=not self.backbone_path)
            self.out_channels = [64, 256, 512, 1024, 2048]
        else:
            raise NotImplementedError

        if self.output_stride not in (8, 16, 32):
            raise ValueError(
                'output_stride must be 8, 16 or 32, got {!r}'.format(
                    self.output_stride))

        # Keep only the channel counts of the requested contiguous stage range.
        self.out_channels = self.out_channels[self.out_stages[0]:
                                              self.out_stages[-1] + 1]

        self.conv1 = nn.Sequential(*list(backbone.children())[:4])
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        def dilate(layer, stride, dilation):
            # Trade stride for dilation on every ``conv2`` and the first
            # downsample module of ``layer``.
            for name, module in layer.named_modules():
                if 'conv2' in name:
                    module.dilation = (dilation, dilation)
                    module.padding = (dilation, dilation)
                    module.stride = (stride, stride)
                elif 'downsample.0' in name:
                    module.stride = (stride, stride)

        # The old loops also special-cased 'conv1' for resnet18/resnet34;
        # that branch could never match because subtype is restricted to the
        # ResNeXt variants above, so it is omitted.
        if self.output_stride == 8:
            dilate(self.layer3, stride=1, dilation=2)
            dilate(self.layer4, stride=1, dilation=4)
        elif self.output_stride == 16:
            dilate(self.layer4, stride=1, dilation=2)

        if self.backbone_path:
            # There is no ``self.backbone`` attribute (the old code raised
            # AttributeError here). The modules assigned above are the same
            # objects as the local model's children, so loading into it
            # updates them.
            backbone.load_state_dict(
                torch.load(self.backbone_path, map_location='cpu'))
        else:
            # NOTE(review): the backbone was built with pretrained=True in
            # this case — confirm init_weights() does not overwrite those
            # weights.
            self.init_weights()
Esempio n. 6
0
def get_arch(model_name, n_classes=3, pretrained=False):
    """Build a ResNet/ResNeXt classifier with an ``n_classes`` output head.

    Supported ``model_name`` values: 'resnet18', 'resnet50', 'resnext50',
    'resnext101' (from ``resnet_imagenet``), and 'resnet50_sws',
    'resnext50_sws' (semi-weakly supervised models loaded via
    ``torch.hub``; ``pretrained`` is not forwarded to ``torch.hub.load``).

    Returns:
        ``(model, mean, std)``. ``mean``/``std`` are the ImageNet
        normalisation statistics when ``pretrained`` is True. For the
        ``*_sws`` models with ``pretrained=False`` they are a fixed set of
        dataset-specific statistics. Otherwise they are None.

    Exits the process via ``sys.exit`` on an unknown ``model_name``.
    """
    hub_repo = 'facebookresearch/semi-supervised-ImageNet1K-models'
    hub_models = {
        'resnet50_sws': lambda: torch.hub.load(hub_repo, 'resnet50_swsl'),
        'resnext50_sws': lambda: torch.hub.load(hub_repo,
                                                'resnext50_32x4d_swsl'),
    }
    local_models = {
        'resnet18':
            lambda: resnet_imagenet.resnet18(pretrained=pretrained),
        'resnet50':
            lambda: resnet_imagenet.resnet50(pretrained=pretrained),
        'resnext50':
            lambda: resnet_imagenet.resnext50_32x4d(pretrained=pretrained),
        'resnext101':
            lambda: resnet_imagenet.resnext101_32x8d(pretrained=pretrained),
    }

    if model_name in local_models:
        model = local_models[model_name]()
    elif model_name in hub_models:
        model = hub_models[model_name]()
    else:
        sys.exit('not a valid model_name, check models.get_model.py')

    if pretrained:
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    elif model_name in hub_models:
        mean, std = [0.4310, 0.3012, 0.2162], [0.2748, 0.2021, 0.1691]
    else:
        mean, std = None, None

    # Replace the classification head with one sized for n_classes.
    model.fc = torch.nn.Linear(model.fc.in_features, n_classes)

    return model, mean, std
    def __init__(self, extractor_model=None):
        """RGB-to-RGB autoencoder: a CNN encoder plus a transposed-conv decoder.

        Args:
            extractor_model: Network whose children, minus the last two, form
                the encoder. Defaults to a fresh ``resnext50_32x4d(False)``
                built per instance. The old default expression was evaluated
                once when the ``def`` ran, so every default-constructed
                autoencoder shared the same encoder modules and weights.
        """
        super(RGB2RGBAutoencoder, self).__init__()

        if extractor_model is None:
            extractor_model = resnext50_32x4d(False)

        # Drop the last two children (presumably avgpool and fc, as in
        # torchvision ResNet) so the encoder yields a spatial feature map.
        self.encoder = torch.nn.Sequential(
            *(list(extractor_model.children())[:-2]))

        # Maps 2048 channels down to 3 (RGB). A custom extractor must
        # therefore output 2048 channels.
        self.decoder = torch.nn.Sequential(
            nn.ConvTranspose2d(2048, 512, 5, stride=1, padding=0, bias=False),
            nn.ConvTranspose2d(512, 256, 5, stride=2, padding=0, bias=False),
            nn.ConvTranspose2d(256, 128, 5, stride=2, padding=0, bias=False),
            nn.ConvTranspose2d(128, 64, 5, stride=2, padding=0, bias=False),
            nn.ConvTranspose2d(64, 32, 5, stride=2, padding=0, bias=False),
            nn.ConvTranspose2d(32, 3, 4, stride=1, padding=0, bias=False),
        )
Esempio n. 8
0
def _make_backbone(name='r50', inplanes_backbone=256, pretrained=False):
    """Split a ``resnet`` model into encoder stages and build an FPN.

    Args:
        name: Alias or full constructor name: r50/resnet50, r101/resnet101,
            r152/resnet152, rx50/resnext50_32x4d, rx101/resnext101_32x8d,
            r18/resnet18, r34/resnet34.
        inplanes_backbone: Channel base passed to ``FPN``. It is divided by 4
            for the resnet18/resnet34 family.
        pretrained: Forwarded to the ``resnet`` constructor.

    Returns:
        ``(encoder0, encoder1, encoder2, encoder3, encoder4, fpn, channels)``.
        encoder0 is conv1/bn1/relu/maxpool and encoder1-4 are layer1-4.
        ``channels`` is ``inplanes_backbone * 8``, computed after any division
        by 4.

    Raises:
        AssertionError: If ``name`` is not recognised.
    """
    # alias -> constructor name on the ``resnet`` module
    full_width = {
        'r50': 'resnet50', 'resnet50': 'resnet50',
        'r101': 'resnet101', 'resnet101': 'resnet101',
        'r152': 'resnet152', 'resnet152': 'resnet152',
        'rx50': 'resnext50_32x4d', 'resnext50_32x4d': 'resnext50_32x4d',
        'rx101': 'resnext101_32x8d', 'resnext101_32x8d': 'resnext101_32x8d',
    }
    quarter_width = {
        'r18': 'resnet18', 'resnet18': 'resnet18',
        'r34': 'resnet34', 'resnet34': 'resnet34',
    }

    if name in full_width:
        arch = full_width[name]
    elif name in quarter_width:
        arch = quarter_width[name]
        inplanes_backbone //= 4
    else:
        # Explicit raise: ``assert False`` is stripped under ``python -O``,
        # which left ``backbone`` unbound (UnboundLocalError). The
        # AssertionError type is kept for existing callers.
        raise AssertionError("No BackBone: {}".format(name))

    backbone = getattr(resnet, arch)(pretrained=pretrained)

    encoder0 = torch.nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu,
                                   backbone.maxpool)
    encoder1 = backbone.layer1
    encoder2 = backbone.layer2
    encoder3 = backbone.layer3
    encoder4 = backbone.layer4
    fpn = FPN(inplanes=inplanes_backbone)

    return (encoder0, encoder1, encoder2, encoder3, encoder4, fpn,
            inplanes_backbone * (2**3))
Esempio n. 9
0
def loss_selector(loss_net):
    #base
    if loss_net == "vgg16":
        from torchvision.models.vgg import vgg16
        net = vgg16(pretrained=True)
        loss_network = nn.Sequential(*list(net.features)[:31]).eval()
        return loss_network
    elif loss_net == "vgg16_bn":
        from torchvision.models.vgg import vgg16_bn
        net = vgg16_bn(pretrained=True)
        loss_network = nn.Sequential(*list(net.features)[:44]).eval()
        return loss_network
    elif loss_net == "resnet50":
        from torchvision.models.resnet import resnet50
        net=resnet50(pretrained=True)
        loss_network=nn.Sequential(*[child_module for child_module in net.children()][:-2]).eval()
        return loss_network
    elif loss_net == "resnet101":
        from torchvision.models.resnet import resnet101
        net=resnet101(pretrained=True)
        loss_network=nn.Sequential(*[child_module for child_module in net.children()][:-2]).eval()
        return loss_network
    elif loss_net == "resnet152":
        from torchvision.models.resnet import resnet152
        net=resnet152(pretrained=True)
        loss_network=nn.Sequential(*[child_module for child_module in net.children()][:-2]).eval()
        return loss_network
    elif loss_net == "squeezenet1_1":
        from torchvision.models.squeezenet import squeezenet1_1
        net=squeezenet1_1(pretrained=True)
        classifier=[item for item in net.classifier.modules()][1:-1]
        loss_network=nn.Sequential(*[net.features,*classifier]).eval()
        return loss_network
    elif loss_net == "densenet121":
        from torchvision.models.densenet import densenet121
        net=densenet121(pretrained=True)
        loss_network=nn.Sequential(*[net.features,nn.ReLU()]).eval()
        return loss_network
    elif loss_net == "densenet169":
        from torchvision.models.densenet import densenet169
        net=densenet169(pretrained=True)
        loss_network=nn.Sequential(*[net.features,nn.ReLU()]).eval()
        return loss_network
    elif loss_net == "densenet201":
        from torchvision.models.densenet import densenet201
        net=densenet201(pretrained=True)
        loss_network=nn.Sequential(*[net.features,nn.ReLU()]).eval()
        return loss_network        
    elif loss_net == "mobilenet_v2":
        from torchvision.models.mobilenet import mobilenet_v2
        net=mobilenet_v2(pretrained=True)
        loss_network=nn.Sequential(*[net.features]).eval()
        return loss_network                
    elif loss_net == "resnext50_32x4d":
        from torchvision.models.resnet import resnext50_32x4d
        net=resnext50_32x4d(pretrained=True)
        loss_network=nn.Sequential(*[child_module for child_module in net.children()][:-2]).eval()
        return loss_network      
    elif loss_net == "resnext101_32x8d":
        from torchvision.models.resnet import resnext101_32x8d
        net=resnext101_32x8d(pretrained=True)
        loss_network=nn.Sequential(*[child_module for child_module in net.children()][:-2]).eval()
        return loss_network
    elif loss_net == "wide_resnet50_2":
        from torchvision.models.resnet import wide_resnet50_2
        net=wide_resnet50_2(pretrained=True)
        loss_network=nn.Sequential(*[child_module for child_module in net.children()][:-2]).eval()
        return loss_network
    elif loss_net == "wide_resnet101_2":
        from torchvision.models.resnet import wide_resnet101_2
        net=wide_resnet101_2(pretrained=True)
        loss_network=nn.Sequential(*[child_module for child_module in net.children()][:-2]).eval()
        return loss_network
    elif loss_net == "inception_v3":