def __init__(self, backbone1, backbone2, drop, pretrained=True):
        """Build the visit branch, the chosen image encoder and the fusion head.

        Args:
            backbone1: name of the image backbone; must match one of the
                branches below, otherwise ``ValueError`` is raised.
            backbone2: not used in this constructor.
            drop: not used in this constructor (dropout is fixed at 0.5).
            pretrained: not used in this constructor; the
                'se_resnext50_32x4d' and 'se_resnet50' branches always load
                their checkpoint.

        Raises:
            ValueError: if ``backbone1`` is not a supported backbone name.
        """
        super(MultiModalNet, self).__init__()

        def _load_encoder_weights(path):
            # Load a checkpoint into self.img_encoder without its
            # 'last_linear' head, so the encoder keeps the head it was
            # constructed with; strict=False because those keys are gone.
            print("load pretrained model from " + path)
            state_dict = torch.load(path)
            state_dict.pop('last_linear.bias')
            state_dict.pop('last_linear.weight')
            self.img_encoder.load_state_dict(state_dict, strict=False)

        self.visit_model = DPN26()
        if backbone1 == 'se_resnext101_32x4d':
            self.img_encoder = se_resnext101_32x4d(9, None)
            self.img_fc = nn.Linear(2048, 256)

        elif backbone1 == 'se_resnext50_32x4d':
            self.img_encoder = se_resnext50_32x4d(9, None)
            _load_encoder_weights(
                './pretrained_seresnet/se_resnext50_32x4d-a260b3a4.pth')
            self.img_fc = nn.Linear(2048, 256)

        elif backbone1 == 'se_resnext26_32x4d':
            self.img_encoder = se_resnext26_32x4d(9, None)
            self.img_fc = nn.Linear(2048, 256)

        elif backbone1 == 'multiscale_se_resnext':
            self.img_encoder = multiscale_se_resnext(9)
            self.img_fc = nn.Linear(2048, 256)

        elif backbone1 == 'multiscale_se_resnext_cat':
            self.img_encoder = multiscale_se_resnext_cat(9)
            self.img_fc = nn.Linear(1024, 256)

        elif backbone1 == 'multiscale_se_resnext_HR':
            self.img_encoder = multiscale_se_resnext_HR(9)
            self.img_fc = nn.Linear(2048, 256)

        elif backbone1 == 'se_resnet50':
            self.img_encoder = se_resnet50(9, None)
            _load_encoder_weights(
                './pretrained_seresnet/se_resnet50-ce0d4300.pth')
            self.img_fc = nn.Linear(2048, 256)

        else:
            # Previously an unknown name silently left img_encoder/img_fc
            # undefined, deferring the failure to the first forward pass.
            raise ValueError('unsupported backbone1: %r' % (backbone1,))

        self.dropout = nn.Dropout(0.5)
        self.cls = nn.Linear(512, 9)
Пример #2
0
def GeResult():
    """Ensemble three trained classifiers and write one prediction per sample.

    Loads ResNeXt101_64x4d, pnasnet5large and se_resnet50 checkpoints, runs
    every batch of ``TiangongResultMerge(root='data')`` through all three,
    amplifies non-positive scores, sums the three score tensors and writes
    ``<anos[0]>,<class name>`` lines (CRLF-terminated) to
    ``Rejection_se_resnet50_pnasnet_resnext.csv``, overwriting any previous
    file.
    """
    # Priors
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.cuda.set_device(0)
    cudnn.benchmark = True

    # Dataset
    Dataset = TiangongResultMerge(root='data')
    Dataloader = data.DataLoader(Dataset,
                                 1,
                                 num_workers=1,
                                 shuffle=True,
                                 pin_memory=True)

    # Networks. The DataParallel wrapper around Network is kept as in the
    # original, although forward passes go through Network directly.
    Network = ResNeXt101_64x4d(6)
    net = torch.nn.DataParallel(Network, device_ids=[0])
    Network.load_state_dict(
        torch.load('weights/aug_ResNeXt/_Tiangong_SGD_85.pth'))
    Network.eval()

    Network2 = pnasnet5large(6, None)
    net2 = torch.nn.DataParallel(Network2, device_ids=[0])
    Network2.load_state_dict(
        torch.load('weights/aug_fix1block_pnasnet/_Tiangong_SGD_85.pth'))
    net2.eval()

    Network3 = se_resnet50(6, None)
    net3 = torch.nn.DataParallel(Network3, device_ids=[0])
    Network3.load_state_dict(
        torch.load('weights/aug_se_resnet50/_Tiangong_SGD_95.pth'))
    net3.eval()

    def _amplify_non_positive(scores):
        # x -> 3x where x <= 0, unchanged elsewhere, so low-probability
        # predictions are pushed further down before the models are summed.
        return scores + 2 * torch.mul(scores, torch.le(scores, 0).float())

    filename = 'Rejection_se_resnet50_pnasnet_resnext.csv'
    # Mode 'w' creates or truncates the file, replacing the former
    # os.remove()/os.mknod() sequence (os.mknod is not portable); `with`
    # guarantees the file is closed. no_grad avoids building autograd graphs
    # during pure inference.
    with open(filename, 'w') as f, torch.no_grad():
        for imgs, img2, anos in Dataloader:
            imgs = imgs.cuda()
            pred1 = _amplify_non_positive(Network(imgs))
            pred2 = _amplify_non_positive(net2(img2))
            pred3 = _amplify_non_positive(net3(imgs))

            # The original used `preds.add(pred3)`, which returns a new
            # tensor and discarded it, so pred3 never contributed.
            preds = pred1 + pred2 + pred3
            _, pred = preds.topk(1, 1, True, True)
            f.write(anos[0] + ',' + CLASSES[pred[0][0].item()] + '\r\n')
Пример #3
0
    def __init__(self, backbone1, backbone2, drop, pretrained=True):
        """Build the visit branch, the chosen image encoder and the fusion head.

        Args:
            backbone1: name of the image backbone; must match one of the
                branches below, otherwise ``ValueError`` is raised.
            backbone2: not used in this constructor.
            drop: not used in this constructor (dropout is fixed at 0.5).
            pretrained: not used in this constructor; branches that load a
                checkpoint (inceptionv3, densenet121, nasnetalarge,
                nasnetamobile) always do so.

        Raises:
            ValueError: if ``backbone1`` is not a supported backbone name.
        """
        super(MultiModalNet, self).__init__()
        self.visit_model = DPN26()

        def _load_encoder_weights(path, head):
            # Load a checkpoint into self.img_encoder, discarding the
            # checkpoint's own '<head>.weight' / '<head>.bias' so the encoder
            # keeps the head it was constructed with; strict=False because
            # those keys are now missing.
            state_dict = torch.load(path)
            state_dict.pop(head + '.bias')
            state_dict.pop(head + '.weight')
            self.img_encoder.load_state_dict(state_dict, strict=False)

        if backbone1 == 'se_resnext101_32x4d':
            # Pretrained loading (pth/se_resnext101_32x4d-3b2fe3d8.pth) is
            # disabled for this backbone.
            self.img_encoder = se_resnext101_32x4d(9, None)
            self.img_fc = nn.Linear(2048, 256)
        elif backbone1 == 'densenet169':
            self.img_encoder = densenet169(1000, None)
            self.img_fc = nn.Linear(1000, 256)
        elif backbone1 == 'inceptionv3':
            self.img_encoder = inceptionv3(9, None)
            print("load pretrained model from pth inceptionv3")
            _load_encoder_weights('pth/inception_v3_google-1a9a5a14.pth', 'fc')
            self.img_fc = nn.Linear(1000, 256)
        elif backbone1 == 'densenet121':
            self.img_encoder = densenet121(9, None)
            print("load pretrained model from pth/densenet121-fbdb23505.pth")
            _load_encoder_weights('pth/densenet121-fbdb23505.pth', 'classifier')
            self.img_fc = nn.Linear(1000, 256)
        elif backbone1 == 'senet154':
            # Pretrained loading (pth/senet154-c7b49a05.pth) is disabled; it
            # was marked "not right" upstream.
            self.img_encoder = senet154(9, None)
            self.img_fc = nn.Linear(2048, 256)
        elif backbone1 == 'nasnetalarge':
            # NOTE(review): this checkpoint loading was marked "not right"
            # upstream but is still active -- confirm it is intended.
            self.img_encoder = nasnetalarge(2048, None)
            print(
                "load pretrained model from pth/nasnetalarge-a1897284.pth in multimodal.py"
            )
            _load_encoder_weights('pth/nasnetalarge-a1897284.pth', 'last_linear')
            self.img_fc = nn.Linear(2048, 256)
        elif backbone1 == 'nasnetamobile':
            # NOTE(review): this checkpoint loading was marked "not right"
            # upstream but is still active -- confirm it is intended.
            self.img_encoder = nasnetamobile(2048, None)
            print("load pretrained model from pth nasnetamobile")
            _load_encoder_weights('pth/nasnetamobile-7e03cead.pth', 'last_linear')
            self.img_fc = nn.Linear(2048, 256)
        elif backbone1 == 'ResNeXt101_64x4d':
            # NOTE(review): the option is named ResNeXt101_64x4d but builds
            # se_resnext101_64x4d -- confirm this is intended.
            self.img_encoder = se_resnext101_64x4d(9, None)
            self.img_fc = nn.Linear(2048, 256)
        elif backbone1 == 'se_resnext50_32x4d':
            # Pretrained loading (pth/se_resnext50_32x4d-a260b3a4.pth, or
            # splitting weights_82/BDXJTU2019_SGD_82.pth) is disabled.
            self.img_encoder = se_resnext50_32x4d(9, None)
            self.img_fc = nn.Linear(2048, 256)
        elif backbone1 == 'se_resnext26_32x4d':
            self.img_encoder = se_resnext26_32x4d(9, None)
            self.img_fc = nn.Linear(2048, 256)
        elif backbone1 == 'multiscale_se_resnext':
            self.img_encoder = multiscale_se_resnext(9)
            self.img_fc = nn.Linear(2048, 256)
        elif backbone1 == 'multiscale_se_resnext_cat':
            # Previously built multiscale_se_resnext(9) -- the same encoder as
            # the branch above, which feeds a 2048-wide img_fc -- while
            # pairing it with a 1024-wide img_fc. Use the concatenating
            # variant, matching the 1024 input size.
            self.img_encoder = multiscale_se_resnext_cat(9)
            self.img_fc = nn.Linear(1024, 256)
        elif backbone1 == 'multiscale_se_resnext_HR':
            self.img_encoder = multiscale_se_resnext_HR(9)
            self.img_fc = nn.Linear(2048, 256)
        elif backbone1 == 'se_resnet50':
            # Pretrained loading (pth/se_resnet50-ce0d4300.pth) is disabled.
            self.img_encoder = se_resnet50(9, None)
            self.img_fc = nn.Linear(2048, 256)
        else:
            # Previously an unknown name silently left img_encoder/img_fc
            # undefined, deferring the failure to the first forward pass.
            raise ValueError('unsupported backbone1: %r' % (backbone1,))

        self.dropout = nn.Dropout(0.5)
        self.cls = nn.Linear(512, 9)
Пример #4
0
def _resume_or_load_pretrained(model, pretrained_path):
    """Initialise *model* from ``args.resume`` or from a pretrained checkpoint.

    If ``args.resume`` is set, the full state dict at that path is loaded
    (strict). Otherwise the checkpoint at *pretrained_path* is loaded without
    its ``last_linear`` layer, and ``model.last_linear`` is re-initialised
    with Xavier-uniform weights and a zero bias.
    """
    if args.resume:
        model.load_state_dict(torch.load(args.resume))
        return
    state_dict = torch.load(pretrained_path)
    # Drop the checkpoint's classifier; presumably its output size differs
    # from args.class_num, so the model keeps and re-inits its own head.
    state_dict.pop('last_linear.bias')
    state_dict.pop('last_linear.weight')
    model.load_state_dict(state_dict, strict=False)
    init.xavier_uniform_(model.last_linear.weight.data)
    model.last_linear.bias.data.zero_()


def main():
    """Train the network selected by ``args.basenet`` on the sen2 H5 data.

    Builds training and validation loaders, constructs the model (resuming
    or loading pretrained weights where supported), saves an initial
    checkpoint and calls ``train`` for each epoch in
    ``range(args.start_epoch, args.epochs)``.

    Raises:
        ValueError: if ``args.basenet`` is not a supported network name.
    """
    import os

    #create model
    best_prec1 = 0
    # Dataset
    # NOTE(review): training data uses the hard-coded root 'data' while
    # validation uses args.dataset_root -- confirm this is intended.
    Dataset_train = H5Dataset(root='data', mode='training', DataType='sen2')
    Dataloader_train = data.DataLoader(Dataset_train,
                                       args.batch_size,
                                       num_workers=args.num_workers,
                                       shuffle=True,
                                       pin_memory=True)
    Dataset_validation = H5Dataset(root=args.dataset_root,
                                   mode='validation',
                                   DataType='sen2')
    Dataloader_validation = data.DataLoader(Dataset_validation,
                                            batch_size=1,
                                            num_workers=args.num_workers,
                                            shuffle=True,
                                            pin_memory=True)

    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.cuda.set_device(0)

    if args.basenet == 'ResNeXt':
        model = Sen2ResNeXt(num_classes=args.class_num,
                            depth=11,
                            cardinality=16)
    elif args.basenet == 'SimpleNetSen2':
        model = SimpleNetSen2(args.class_num)
    elif args.basenet == 'se_resnet50':
        model = se_resnet50(args.class_num, None)
    elif args.basenet == 'se_resnet50_shallow':
        model = se_resnet50_shallow(args.class_num, None)
    elif args.basenet == 'nasnetamobile':
        model = nasnetamobile(args.class_num, None)
    elif args.basenet == 'pnasnet':
        model = pnasnet5large(args.class_num, None)
        _resume_or_load_pretrained(model, 'pnasnet5large-bf079911.pth')
    elif args.basenet == 'se_resnet101':
        model = se_resnet101(args.class_num, None)
        _resume_or_load_pretrained(model, 'se_resnet101-7e38fcc6.pth')
    elif args.basenet == 'se_resnext101_32x4d':
        model = se_resnext101_32x4d(args.class_num, None)
        _resume_or_load_pretrained(model, 'se_resnext101_32x4d-3b2fe3d8.pth')
    else:
        # Previously an unknown name left `model` unbound and crashed with
        # UnboundLocalError at model.cuda().
        raise ValueError('unsupported basenet: %r' % (args.basenet,))

    model = model.cuda()
    cudnn.benchmark = True

    criterion = nn.CrossEntropyLoss().cuda()
    Optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay,
                          nesterov=True)

    checkpoint_path = ('weights/lr8e-3_bs8_sen2_' + args.basenet + '/' +
                       'LCZ42_SGD' + '.pth')
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    torch.save(model.state_dict(), checkpoint_path)
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(Dataloader_train, model, criterion, Optimizer, epoch,
              Dataloader_validation)