def __init__(self, backbone1, backbone2, drop, pretrained=True): super(MultiModalNet, self).__init__() self.visit_model = DPN26() if backbone1 == 'se_resnext101_32x4d': self.img_encoder = se_resnext101_32x4d(9, None) self.img_fc = nn.Linear(2048, 256) elif backbone1 == 'se_resnext50_32x4d': self.img_encoder = se_resnext50_32x4d(9, None) print( "load pretrained model from ./pretrained_seresnet/se_resnext50_32x4d-a260b3a4.pth" ) state_dict = torch.load( './pretrained_seresnet/se_resnext50_32x4d-a260b3a4.pth') state_dict.pop('last_linear.bias') state_dict.pop('last_linear.weight') self.img_encoder.load_state_dict(state_dict, strict=False) self.img_fc = nn.Linear(2048, 256) elif backbone1 == 'se_resnext26_32x4d': self.img_encoder = se_resnext26_32x4d(9, None) self.img_fc = nn.Linear(2048, 256) elif backbone1 == 'multiscale_se_resnext': self.img_encoder = multiscale_se_resnext(9) self.img_fc = nn.Linear(2048, 256) elif backbone1 == 'multiscale_se_resnext_cat': self.img_encoder = multiscale_se_resnext_cat(9) self.img_fc = nn.Linear(1024, 256) elif backbone1 == 'multiscale_se_resnext_HR': self.img_encoder = multiscale_se_resnext_HR(9) self.img_fc = nn.Linear(2048, 256) elif backbone1 == 'se_resnet50': self.img_encoder = se_resnet50(9, None) print( "load pretrained model from ./pretrained_seresnet/se_resnet50-ce0d4300.pth" ) state_dict = torch.load( './pretrained_seresnet/se_resnet50-ce0d4300.pth') state_dict.pop('last_linear.bias') state_dict.pop('last_linear.weight') self.img_encoder.load_state_dict(state_dict, strict=False) self.img_fc = nn.Linear(2048, 256) self.dropout = nn.Dropout(0.5) self.cls = nn.Linear(512, 9)
def __init__(self, backbone1, backbone2, drop, pretrained=True): super(MultiModalNet, self).__init__() self.visit_model = DPN26() if backbone1 == 'se_resnext101_32x4d': self.img_encoder = se_resnext101_32x4d(9, None) # print("load pretrained model from pth/se_resnext101_32x4d-3b2fe3d8.pth") # state_dict = torch.load('pth/se_resnext101_32x4d-3b2fe3d8.pth') # state_dict.pop('last_linear.bias') # state_dict.pop('last_linear.weight') # self.img_encoder.load_state_dict(state_dict, strict=False) self.img_fc = nn.Linear(2048, 256) elif backbone1 == 'densenet169': self.img_encoder = densenet169(1000, None) self.img_fc = nn.Linear(1000, 256) elif backbone1 == 'inceptionv3': self.img_encoder = inceptionv3(9, None) print("load pretrained model from pth inceptionv3") state_dict = torch.load('pth/inception_v3_google-1a9a5a14.pth') state_dict.pop('fc.bias') state_dict.pop('fc.weight') self.img_encoder.load_state_dict(state_dict, strict=False) self.img_fc = nn.Linear(1000, 256) elif backbone1 == 'densenet121': self.img_encoder = densenet121(9, None) print("load pretrained model from pth/densenet121-fbdb23505.pth") state_dict = torch.load('pth/densenet121-fbdb23505.pth') state_dict.pop('classifier.bias') state_dict.pop('classifier.weight') self.img_encoder.load_state_dict(state_dict, strict=False) self.img_fc = nn.Linear(1000, 256) elif backbone1 == 'senet154': self.img_encoder = senet154(9, None) # not right # print("load pretrained model from pth/senet154-c7b49a05.pth") # state_dict = torch.load('pth/senet154-c7b49a05.pth') # state_dict.pop('last_linear.bias') # state_dict.pop('last_linear.weight') # self.img_encoder.load_state_dict(state_dict, strict=False) self.img_fc = nn.Linear(2048, 256) elif backbone1 == 'nasnetalarge': self.img_encoder = nasnetalarge(2048, None) #not right print( "load pretrained model from pth/nasnetalarge-a1897284.pth in multimodal.py" ) state_dict = torch.load('pth/nasnetalarge-a1897284.pth') #print(state_dict.keys()) state_dict.pop('last_linear.bias') state_dict.pop('last_linear.weight') self.img_encoder.load_state_dict(state_dict, strict=False) self.img_fc = nn.Linear(2048, 256) elif backbone1 == 'nasnetamobile': self.img_encoder = nasnetamobile(2048, None) # not right print("load pretrained model from pth nasnetamobile") state_dict = torch.load('pth/nasnetamobile-7e03cead.pth') # print(state_dict.keys()) state_dict.pop('last_linear.bias') state_dict.pop('last_linear.weight') self.img_encoder.load_state_dict(state_dict, strict=False) self.img_fc = nn.Linear(2048, 256) elif backbone1 == 'ResNeXt101_64x4d': self.img_encoder = se_resnext101_64x4d(9, None) # print("load pretrained model from pth/resnext101_64x4d-e77a0586.pth") # state_dict = torch.load('pth/resnext101_64x4d-e77a0586.pth') # state_dict.pop('last_linear.bias') # state_dict.pop('last_linear.weight') # self.img_encoder.load_state_dict(state_dict, strict=False) self.img_fc = nn.Linear(2048, 256) elif backbone1 == 'se_resnext50_32x4d': self.img_encoder = se_resnext50_32x4d(9, None) # print("load pretrained model from pth/se_resnext50_32x4d-a260b3a4.pth") # state_dict = torch.load('pth/se_resnext50_32x4d-a260b3a4.pth') # print("load pretrained model from weights_82/BDXJTU2019_SGD_82.pth") # state_dict1 = torch.load('weights_82/BDXJTU2019_SGD_82.pth') # # key1=state_dict1.keys() # dict_img={} # dict_vis={} # dict_fc={} # dict_cls={} # # key_img=[] # key_vis=[] # key_fc=[] # key_cls=[] # for key in key1: # if key.count("img_encoder")>0: # key_img.append(key) # elif key.count("visit_model")>0: # key_vis.append(key) # elif key.count("img_fc") > 0: # key_fc.append(key) # elif key.count("cls") > 0: # key_cls.append(key) # else: # print(key) # dict_img.fromkeys(key_img) # dict_vis.fromkeys(key_vis) # dict_fc.fromkeys(key_fc) # dict_cls.fromkeys(key_cls) # for key in key1: # if key.count("img_encoder")>0: # dict_img[key]=state_dict1[key] # elif key.count("visit_model")>0: # dict_vis[key]=state_dict1[key] # elif key.count("img_fc")>0: # dict_fc[key]=state_dict1[key] # elif key.count("cls") > 0: # dict_cls[key]=state_dict1[key] # else: # print(key) # state_dict.pop('last_linear.bias') # state_dict.pop('last_linear.weight') #self.img_encoder.load_state_dict(dict_img, strict = False) self.img_fc = nn.Linear(2048, 256) elif backbone1 == 'se_resnext26_32x4d': self.img_encoder = se_resnext26_32x4d(9, None) self.img_fc = nn.Linear(2048, 256) elif backbone1 == 'multiscale_se_resnext': self.img_encoder = multiscale_se_resnext(9) self.img_fc = nn.Linear(2048, 256) elif backbone1 == 'multiscale_se_resnext_cat': self.img_encoder = multiscale_se_resnext(9) self.img_fc = nn.Linear(1024, 256) elif backbone1 == 'multiscale_se_resnext_HR': self.img_encoder = multiscale_se_resnext_HR(9) self.img_fc = nn.Linear(2048, 256) elif backbone1 == 'se_resnet50': self.img_encoder = se_resnet50(9, None) # print("load pretrained model from pth/se_resnet50-ce0d4300.pth") # state_dict = torch.load('pth/se_resnet50-ce0d4300.pth') # # state_dict.pop('last_linear.bias') # state_dict.pop('last_linear.weight') # self.img_encoder.load_state_dict(state_dict, strict = False) self.img_fc = nn.Linear(2048, 256) self.dropout = nn.Dropout(0.5) self.cls = nn.Linear(512, 9)