Example 1
    def __init__(self, num_class, pretrain=True):
        super(multiscale_se_resnext_cat, self).__init__()

        self.base_model1 = se_resnext50_32x4d(9, None)
        if pretrain:
            print(
                "load model1 from ./pretrained_seresnet/se_resnext50_32x4d-a260b3a4.pth"
            )
            state_dict = torch.load(
                './pretrained_seresnet/se_resnext50_32x4d-a260b3a4.pth')

            # Drop the ImageNet classification head so only backbone weights are
            # loaded; strict=False tolerates the now-missing last_linear keys.
            state_dict.pop('last_linear.bias')
            state_dict.pop('last_linear.weight')
            self.base_model1.load_state_dict(state_dict, strict=False)

        self.base_model2 = se_resnext50_32x4d(9, None)
        if pretrain:
            print(
                "load model2 from ./pretrained_seresnet/se_resnext50_32x4d-a260b3a4.pth"
            )
            state_dict = torch.load(
                './pretrained_seresnet/se_resnext50_32x4d-a260b3a4.pth')

            state_dict.pop('last_linear.bias')
            state_dict.pop('last_linear.weight')
            self.base_model2.load_state_dict(state_dict, strict=False)

        self.dropout = nn.Dropout(0.5)
        self.classifier = nn.Linear(4096, 1024)
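Note: forward() is not part of this snippet. The 4096-in classifier suggests the two 2048-d backbone features are concatenated (hence "_cat"); the sketch below is purely hypothetical, assuming import torch, import torch.nn.functional as F, and a Cadene-style features() method on the backbones.

    def forward(self, x1, x2):
        # Hypothetical sketch: pool each backbone's 2048-d feature map,
        # concatenate to 4096 dims, then project to 1024 via self.classifier.
        f1 = F.adaptive_avg_pool2d(self.base_model1.features(x1), 1).flatten(1)
        f2 = F.adaptive_avg_pool2d(self.base_model2.features(x2), 1).flatten(1)
        return self.classifier(self.dropout(torch.cat([f1, f2], dim=1)))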
Example 2

    def __init__(self, num_class):
        super(multiscale_se_resnext_cat, self).__init__()

        self.base_model1 = se_resnext50_32x4d(9, None)
        self.base_model2 = se_resnext50_32x4d(9, None)

        self.dropout = nn.Dropout(0.5)
        self.classifier = nn.Linear(4096, 1024)
Example 3

    def __init__(self, backbone1, backbone2, drop, pretrained=True):
        super(MultiModalNet, self).__init__()

        self.visit_model = DPN26()
        if backbone1 == 'se_resnext101_32x4d':
            self.img_encoder = se_resnext101_32x4d(9, None)
            self.img_fc = nn.Linear(2048, 256)

        elif backbone1 == 'se_resnext50_32x4d':
            self.img_encoder = se_resnext50_32x4d(9, None)

            print(
                "load pretrained model from ./pretrained_seresnet/se_resnext50_32x4d-a260b3a4.pth"
            )
            state_dict = torch.load(
                './pretrained_seresnet/se_resnext50_32x4d-a260b3a4.pth')

            state_dict.pop('last_linear.bias')
            state_dict.pop('last_linear.weight')
            self.img_encoder.load_state_dict(state_dict, strict=False)

            self.img_fc = nn.Linear(2048, 256)

        elif backbone1 == 'se_resnext26_32x4d':
            self.img_encoder = se_resnext26_32x4d(9, None)
            self.img_fc = nn.Linear(2048, 256)

        elif backbone1 == 'multiscale_se_resnext':
            self.img_encoder = multiscale_se_resnext(9)
            self.img_fc = nn.Linear(2048, 256)

        elif backbone1 == 'multiscale_se_resnext_cat':
            self.img_encoder = multiscale_se_resnext_cat(9)
            self.img_fc = nn.Linear(1024, 256)

        elif backbone1 == 'multiscale_se_resnext_HR':
            self.img_encoder = multiscale_se_resnext_HR(9)
            self.img_fc = nn.Linear(2048, 256)

        elif backbone1 == 'se_resnet50':
            self.img_encoder = se_resnet50(9, None)
            print(
                "load pretrained model from ./pretrained_seresnet/se_resnet50-ce0d4300.pth"
            )
            state_dict = torch.load(
                './pretrained_seresnet/se_resnet50-ce0d4300.pth')

            state_dict.pop('last_linear.bias')
            state_dict.pop('last_linear.weight')
            self.img_encoder.load_state_dict(state_dict, strict=False)

            self.img_fc = nn.Linear(2048, 256)

        self.dropout = nn.Dropout(0.5)
        self.cls = nn.Linear(512, 9)
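Note: the head sizes above imply the fusion scheme: img_fc maps the image feature to 256, the visit branch presumably also yields 256, and cls takes their 512-d concatenation. A hypothetical forward() sketch under those assumptions (it further assumes import torch as in the surrounding code and that img_encoder returns a 2048-d feature vector here):

    def forward(self, img, visit):
        # Hypothetical fusion inferred from the layer sizes in __init__.
        img_feat = self.img_fc(self.dropout(self.img_encoder(img)))  # (N, 256)
        vis_feat = self.visit_model(visit)                           # (N, 256), assumed
        return self.cls(torch.cat([img_feat, vis_feat], dim=1))      # (N, 512) -> 9 classes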
Example 4
    def __init__(self, num_class, pretrain=True):
        super(multiscale_se_resnext_HR, self).__init__()

        self.base_model = se_resnext50_32x4d(9, None)

        if pretrain:
            print("load model from pth/se_resnext50_32x4d-a260b3a4.pth")
            state_dict = torch.load('pth/se_resnext50_32x4d-a260b3a4.pth')

            state_dict.pop('last_linear.bias')
            state_dict.pop('last_linear.weight')
            self.base_model.load_state_dict(state_dict, strict=False)

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
Example 5
    def __init__(self, num_class):  # , pretrain=True):
        super(multiscale_se_resnext, self).__init__()

        self.base_model = se_resnext50_32x4d(9, None)

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
Example 6
import torch
import time
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torch.backends.cudnn as cudnn
import torch.nn.init as init
from torch.autograd import Variable
import math
from basenet.senet import se_resnet50, se_resnext101_32x4d, se_resnext50_32x4d

model = se_resnext50_32x4d(9, None)

state_dict = torch.load('se_resnext50_32x4d-a260b3a4.pth')
# Inspect the parameter names stored in the checkpoint.
for k in state_dict:
    print(k)
#print(state_dict)

# Remove the ImageNet classification head; the freshly built 9-class
# last_linear keeps its random init and strict=False ignores the missing keys.
state_dict.pop('last_linear.bias')
state_dict.pop('last_linear.weight')
model.load_state_dict(state_dict, strict=False)

print(model(torch.randn(16, 3, 100, 100).float()).size())
"""
init.xavier_uniform_(model.last_linear.weight.data)
model.last_linear.bias.data.zero_()
"""
Example 7
    def __init__(self, backbone1, backbone2, drop, pretrained=True):
        super(MultiModalNet, self).__init__()
        self.visit_model = DPN26()

        if backbone1 == 'se_resnext101_32x4d':
            self.img_encoder = se_resnext101_32x4d(9, None)

            # print("load pretrained model from pth/se_resnext101_32x4d-3b2fe3d8.pth")
            # state_dict = torch.load('pth/se_resnext101_32x4d-3b2fe3d8.pth')
            # state_dict.pop('last_linear.bias')
            # state_dict.pop('last_linear.weight')
            # self.img_encoder.load_state_dict(state_dict, strict=False)

            self.img_fc = nn.Linear(2048, 256)
        elif backbone1 == 'densenet169':
            self.img_encoder = densenet169(1000, None)
            self.img_fc = nn.Linear(1000, 256)
        elif backbone1 == 'inceptionv3':
            self.img_encoder = inceptionv3(9, None)

            print("load pretrained model from pth inceptionv3")
            state_dict = torch.load('pth/inception_v3_google-1a9a5a14.pth')
            state_dict.pop('fc.bias')
            state_dict.pop('fc.weight')
            self.img_encoder.load_state_dict(state_dict, strict=False)

            self.img_fc = nn.Linear(1000, 256)
        elif backbone1 == 'densenet121':
            self.img_encoder = densenet121(9, None)

            print("load pretrained model from pth/densenet121-fbdb23505.pth")
            state_dict = torch.load('pth/densenet121-fbdb23505.pth')
            state_dict.pop('classifier.bias')
            state_dict.pop('classifier.weight')
            self.img_encoder.load_state_dict(state_dict, strict=False)

            self.img_fc = nn.Linear(1000, 256)
        elif backbone1 == 'senet154':
            self.img_encoder = senet154(9, None)
            # not right
            # print("load pretrained model from pth/senet154-c7b49a05.pth")
            # state_dict = torch.load('pth/senet154-c7b49a05.pth')
            # state_dict.pop('last_linear.bias')
            # state_dict.pop('last_linear.weight')
            # self.img_encoder.load_state_dict(state_dict, strict=False)

            self.img_fc = nn.Linear(2048, 256)
        elif backbone1 == 'nasnetalarge':
            self.img_encoder = nasnetalarge(2048, None)

            #not right
            print(
                "load pretrained model from pth/nasnetalarge-a1897284.pth in multimodal.py"
            )
            state_dict = torch.load('pth/nasnetalarge-a1897284.pth')
            #print(state_dict.keys())
            state_dict.pop('last_linear.bias')
            state_dict.pop('last_linear.weight')
            self.img_encoder.load_state_dict(state_dict, strict=False)

            self.img_fc = nn.Linear(2048, 256)
        elif backbone1 == 'nasnetamobile':
            self.img_encoder = nasnetamobile(2048, None)
            # not right
            print("load pretrained model from pth nasnetamobile")
            state_dict = torch.load('pth/nasnetamobile-7e03cead.pth')
            # print(state_dict.keys())
            state_dict.pop('last_linear.bias')
            state_dict.pop('last_linear.weight')
            self.img_encoder.load_state_dict(state_dict, strict=False)

            self.img_fc = nn.Linear(2048, 256)
        elif backbone1 == 'ResNeXt101_64x4d':
            self.img_encoder = se_resnext101_64x4d(9, None)

            # print("load pretrained model from pth/resnext101_64x4d-e77a0586.pth")
            # state_dict = torch.load('pth/resnext101_64x4d-e77a0586.pth')
            # state_dict.pop('last_linear.bias')
            # state_dict.pop('last_linear.weight')
            # self.img_encoder.load_state_dict(state_dict, strict=False)

            self.img_fc = nn.Linear(2048, 256)
        elif backbone1 == 'se_resnext50_32x4d':
            self.img_encoder = se_resnext50_32x4d(9, None)
            # print("load pretrained model from pth/se_resnext50_32x4d-a260b3a4.pth")
            # state_dict = torch.load('pth/se_resnext50_32x4d-a260b3a4.pth')

            # print("load pretrained model from weights_82/BDXJTU2019_SGD_82.pth")
            # state_dict1 = torch.load('weights_82/BDXJTU2019_SGD_82.pth')
            #
            # key1=state_dict1.keys()
            # dict_img={}
            # dict_vis={}
            # dict_fc={}
            # dict_cls={}
            #
            # key_img=[]
            # key_vis=[]
            # key_fc=[]
            # key_cls=[]
            # for key in key1:
            #     if key.count("img_encoder")>0:
            #         key_img.append(key)
            #     elif key.count("visit_model")>0:
            #         key_vis.append(key)
            #     elif key.count("img_fc") > 0:
            #         key_fc.append(key)
            #     elif key.count("cls") > 0:
            #         key_cls.append(key)
            #     else:
            #         print(key)
            # dict_img.fromkeys(key_img)
            # dict_vis.fromkeys(key_vis)
            # dict_fc.fromkeys(key_fc)
            # dict_cls.fromkeys(key_cls)
            # for key in key1:
            #     if key.count("img_encoder")>0:
            #         dict_img[key]=state_dict1[key]
            #     elif key.count("visit_model")>0:
            #         dict_vis[key]=state_dict1[key]
            #     elif key.count("img_fc")>0:
            #         dict_fc[key]=state_dict1[key]
            #     elif key.count("cls") > 0:
            #         dict_cls[key]=state_dict1[key]
            #     else:
            #         print(key)

            # state_dict.pop('last_linear.bias')
            # state_dict.pop('last_linear.weight')

            #self.img_encoder.load_state_dict(dict_img, strict = False)
            self.img_fc = nn.Linear(2048, 256)

        elif backbone1 == 'se_resnext26_32x4d':
            self.img_encoder = se_resnext26_32x4d(9, None)
            self.img_fc = nn.Linear(2048, 256)

        elif backbone1 == 'multiscale_se_resnext':
            self.img_encoder = multiscale_se_resnext(9)
            self.img_fc = nn.Linear(2048, 256)

        elif backbone1 == 'multiscale_se_resnext_cat':
            self.img_encoder = multiscale_se_resnext_cat(9)
            self.img_fc = nn.Linear(1024, 256)

        elif backbone1 == 'multiscale_se_resnext_HR':
            self.img_encoder = multiscale_se_resnext_HR(9)
            self.img_fc = nn.Linear(2048, 256)

        elif backbone1 == 'se_resnet50':
            self.img_encoder = se_resnet50(9, None)
            # print("load pretrained model from pth/se_resnet50-ce0d4300.pth")
            # state_dict = torch.load('pth/se_resnet50-ce0d4300.pth')
            #
            # state_dict.pop('last_linear.bias')
            # state_dict.pop('last_linear.weight')
            # self.img_encoder.load_state_dict(state_dict, strict = False)

            self.img_fc = nn.Linear(2048, 256)

        self.dropout = nn.Dropout(0.5)
        self.cls = nn.Linear(512, 9)
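Note: the long commented-out block in the 'se_resnext50_32x4d' branch splits a saved MultiModalNet checkpoint into per-submodule dicts by key prefix (its dict.fromkeys(...) calls are no-ops, since their return values are discarded). A more compact sketch of the same idea is shown below; it additionally strips each prefix so the keys match the submodule's own parameter names, which the prefixed keys would not.

import torch

# Split a full MultiModalNet checkpoint by submodule prefix so each part can be
# loaded with load_state_dict(..., strict=False), as in the commented-out code.
state_dict = torch.load('weights_82/BDXJTU2019_SGD_82.pth')

prefixes = ('img_encoder', 'visit_model', 'img_fc', 'cls')
parts = {p: {} for p in prefixes}
for key, value in state_dict.items():
    for p in prefixes:
        if key.startswith(p + '.'):
            parts[p][key[len(p) + 1:]] = value  # drop the "<prefix>." part
            break
    else:
        print('unmatched key:', key)

# e.g. model.img_encoder.load_state_dict(parts['img_encoder'], strict=False)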