コード例 #1
0
def GeResult():
    """Evaluate one MultiModalNet checkpoint on the '1_val' split.

    Runs ``MultiModalNet('se_resnet152', 'DPN26', 0.5)`` over every sample of
    the ``MM_BDXJTU2019`` '1_val' split (batch size 1). It prints a running
    accuracy every 1000 samples and the final accuracy. It then plots raw
    and normalized confusion matrices.

    Relies on module-level names defined elsewhere in the file: ``device``,
    ``data``, ``MM_BDXJTU2019``, ``MultiModalNet``, ``accuracy_score``,
    ``confusion_matrix``, ``np``, ``plt`` and ``plot_confusion_matrix``.
    """
    # Dataset. Accuracy and the confusion matrix do not depend on order, so
    # iterate deterministically.
    Dataset_val = MM_BDXJTU2019(root='/home/dell/Desktop/2019BaiduXJTU/data',
                                mode='1_val')
    Dataloader_val = data.DataLoader(Dataset_val,
                                     batch_size=1,
                                     num_workers=4,
                                     shuffle=False,
                                     pin_memory=True)

    class_names = [
        '001', '002', '003', '004', '005', '006', '007', '008', '009'
    ]
    # Fixed label set so the confusion matrix is always
    # len(class_names) x len(class_names), matching the plot's tick labels.
    # Assumes annotations are class indices 0..8, as the accuracy comparison
    # against the argmax index implies.
    label_ids = list(range(len(class_names)))

    # construct network
    weights_path = '/home/dell/Desktop/2019BaiduXJTU/weights/se_resnet152_se_resnext50_32x4d_resample_pretrained_80w_1/inception_005.pth'
    net = MultiModalNet('se_resnet152', 'DPN26', 0.5)
    net.to(device)
    # map_location lets a checkpoint saved on another device load on `device`.
    net.load_state_dict(torch.load(weights_path, map_location=device))
    # Report the file that was actually loaded. The old message printed a
    # hard-coded epoch number that did not match this checkpoint.
    print('load model from ' + weights_path)
    net.eval()

    results = []
    results_anno = []

    # Pure inference: skip building an autograd graph on every forward pass.
    with torch.no_grad():
        for i, (Input_img, Input_vis, Anno) in enumerate(Dataloader_val):
            Input_img = Input_img.to(device)
            Input_vis = Input_vis.to(device)

            ConfTensor = net(Input_img, Input_vis)
            _, pred = ConfTensor.topk(1, 1, True, False)

            results.append(pred.item())
            results_anno.append(Anno)  # append annotation results

            if (i + 1) % 1000 == 0:
                print(i + 1)
                print('Accuracy of Orignal Input: %0.6f' %
                      (accuracy_score(results, results_anno, normalize=True)))

    # print accuracy of different input
    print('Accuracy of Orignal Input: %0.6f' %
          (accuracy_score(results, results_anno, normalize=True)))

    cnf_matrix = confusion_matrix(results_anno, results, labels=label_ids)
    # Every sample is counted, so trace / dataset size equals the accuracy.
    cnf_tr = np.trace(cnf_matrix).astype('float')
    print(cnf_tr / len(Dataset_val))

    plt.figure()
    plot_confusion_matrix(cnf_matrix,
                          classes=class_names,
                          title='Confusion matrix, without normalization')
    plt.figure()
    plot_confusion_matrix(cnf_matrix,
                          classes=class_names,
                          normalize=True,
                          title='Normalized confusion matrix')
    plt.show()
コード例 #2
0
def main():
    """Train the model selected by ``args.basenet`` and checkpoint each epoch.

    Steps:

    1. Build the network for ``args.basenet``.
    2. Wrap it in ``nn.DataParallel`` when more than one GPU is visible.
    3. Train with SGD and cross-entropy on the 'MM_1_train' split, for
       epochs ``args.start_epoch`` to ``args.epochs - 1``.
    4. Save the weights after every epoch under
       ``weights/<basenet>_se_resnext50_32x4d_resample_pretrained_80w_1/``.

    Relies on module-level names defined elsewhere in the file: ``args``,
    ``device``, ``nn``, ``optim``, ``data``, the model/dataset classes,
    ``Augmentation``, ``adjust_learning_rate`` and ``train``.

    Raises:
        ValueError: if ``args.basenet`` is not a supported backbone.
    """
    import os

    # create model
    if args.basenet == 'se_resnet152':
        model = MultiModalNet('se_resnet152', 'DPN26', 0.5)
    elif args.basenet == 'se_resnext50_32x4d':
        model = MultiModalNet1('se_resnext50_32x4d', 'DPN26', 0.5)
    elif args.basenet == 'se_resnet50':
        model = MultiModalNet1('se_resnet50', 'DPN26', 0.5)
    elif args.basenet == 'densenet201':
        model = MultiModalNet2('densenet201', 'DPN26', 0.5)
    elif args.basenet == 'oct_resnet101':
        model = oct_resnet101()
    else:
        # Without this, `model` stays unbound and the code below fails with a
        # confusing UnboundLocalError.
        raise ValueError('Unsupported basenet: %r' % (args.basenet,))

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # DataParallel splits each batch along dim 0 across the visible GPUs.
        model = nn.DataParallel(model)
    model.to(device)

    # Dataset
    Aug = Augmentation()
    Dataset_train = MM_BDXJTU2019(root='/home/dell/Desktop/2019BaiduXJTU/data',
                                  mode='MM_1_train',
                                  transform=Aug)
    Dataloader_train = data.DataLoader(Dataset_train,
                                       128,
                                       num_workers=4,
                                       shuffle=True,
                                       pin_memory=True)

    # The validation loader is built but not used while per-epoch validation
    # is disabled.
    Dataset_val = MM_BDXJTU2019(root='/home/dell/Desktop/2019BaiduXJTU/data',
                                mode='val')
    Dataloader_val = data.DataLoader(Dataset_val,
                                     batch_size=32,
                                     num_workers=4,
                                     shuffle=True,
                                     pin_memory=True)

    criterion = nn.CrossEntropyLoss().to(device)
    Optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    save_dir = ('weights/' + args.basenet +
                '_se_resnext50_32x4d_resample_pretrained_80w_1/')
    # torch.save does not create missing directories.
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(Optimizer, epoch)

        # train for one epoch
        train(Dataloader_train, model, criterion, Optimizer, epoch)

        # Save the bare network's weights, without DataParallel's 'module.'
        # prefix. `model.module` only exists when the model was wrapped
        # above, so fall back to the model itself on 1 GPU or CPU.
        if isinstance(model, nn.DataParallel):
            net_to_save = model.module
        else:
            net_to_save = model
        torch.save(net_to_save.state_dict(),
                   save_dir + 'BDXJTU2019_SGD_' + repr(epoch) + '.pth')
コード例 #3
0
def GeResult():
    """Predict the test set with a weighted 7-checkpoint ensemble.

    Inputs and loading:

    - Each batch (size 1) from ``BDXJTU2019_test`` provides an image
      ``Input_O``, a second image ``Input_H`` (presumably an augmented view
      for test-time augmentation; TODO confirm in ``BDXJTU2019_test``), a
      visit tensor, and ``anos``, whose first entry is a file name.
    - Each checkpoint is loaded onto ``device`` and put in eval mode.

    Scoring and output:

    - Ten forward passes are combined as a weighted sum of L2-normalized
      outputs.
    - For each sample, one line is written to ``MM_ensemble4_TTA.txt`` and
      echoed to stdout. The line holds the file name minus its last four
      characters, a tab, and the top-1 class name from ``CLASSES``.

    Relies on module-level names defined elsewhere in the file: ``device``,
    ``data``, ``BDXJTU2019_test``, ``MultiModalNet``, ``MultiModalNet1``,
    ``MultiModalNet2`` and ``CLASSES``.
    """
    # Dataset
    test_set = BDXJTU2019_test(root='/home/dell/Desktop/2019BaiduXJTU/data')
    Dataloader = data.DataLoader(test_set,
                                 1,
                                 num_workers=1,
                                 shuffle=False,
                                 pin_memory=True)

    weights_root = '/home/dell/Desktop/2019BaiduXJTU/weights/'

    def _load(net, rel_path):
        """Load weights into *net*, move it to ``device``, set eval mode."""
        # map_location lets GPU-saved checkpoints load on any `device`.
        net.load_state_dict(
            torch.load(weights_root + rel_path, map_location=device))
        net.to(device)
        net.eval()
        return net

    net1 = _load(MultiModalNet1('se_resnet50', 'DPN26', 0.5),
                 'se_resnet50_se_resnext50_32x4d_resample_pretrained_80w_1/BDXJTU2019_SGD_9.pth')
    net2 = _load(MultiModalNet('se_resnet152', 'DPN26', 0.5),
                 'se_resnet152_se_resnext50_32x4d_resample_pretrained_80w_1/BDXJTU2019_SGD_6.pth')
    net3 = _load(MultiModalNet2('densenet201', 'DPN26', 0.5),
                 'densenet201_se_resnext50_32x4d_resample_pretrained_80w_1/BDXJTU2019_SGD_3.pth')
    net4 = _load(MultiModalNet2('densenet201', 'DPN26', 0.5),
                 'densenet201_se_resnext50_32x4d_resample_pretrained_80w_1/BDXJTU2019_SGD_10.pth')
    net5 = _load(MultiModalNet1('multiscale_se_resnext', 'DPN26', 0.5),
                 'multiscale_se_resnext_se_resnext50_32x4d_resample_pretrained_80w_1/BDXJTU2019_SGD_11.pth')
    net6 = _load(MultiModalNet1('multiscale_resnet', 'DPN26', 0.5),
                 'multiscale_resnet_se_resnext50_32x4d_resample_pretrained_80w_1/BDXJTU2019_SGD_10.pth')
    net7 = _load(MultiModalNet2('densenet201', 'DPN26', 0.5),
                 'densenet201_se_resnext50_32x4d_resample_pretrained_80w_1/BDXJTU2019_SGD_4.pth')

    # Ensemble terms as (weight, network, which image), in the original
    # summation order. 'O' selects Input_O and 'H' selects Input_H.
    # NOTE(review): net2 is fed Input_O even though its old variable name
    # (ConfTensor_H) suggested Input_H; behavior is kept as it was.
    ensemble = [
        (1, net1, 'O'),
        (1, net2, 'O'),
        (2, net3, 'O'),
        (1, net3, 'H'),
        (1, net4, 'O'),
        (1, net4, 'H'),
        (2, net5, 'O'),
        (1, net5, 'H'),
        (1, net6, 'O'),
        (2, net7, 'O'),
    ]
    normalize = torch.nn.functional.normalize

    filename = 'MM_ensemble4_TTA.txt'

    # `with` closes the file even if inference fails part-way. no_grad skips
    # autograd bookkeeping, since this is inference only.
    with open(filename, 'w') as f, torch.no_grad():
        for Input_O, Input_H, visit_tensor, anos in Dataloader:
            # Move each input to the device once, not once per forward pass.
            images = {'O': Input_O.to(device), 'H': Input_H.to(device)}
            visit = visit_tensor.to(device)

            preds = sum(weight * normalize(net(images[view], visit))
                        for weight, net, view in ensemble)
            _, pred = preds.topk(1, 1, True, True)

            line = anos[0][:-4] + '\t' + CLASSES[pred[0][0].item()] + '\n'
            print(line)
            f.write(line)
コード例 #4
0
def GeResult():
    """Evaluate a 3-model ensemble on the 'val' split, scoring only
    confident predictions.

    How predictions are made and filtered:

    - Three checkpoints (se_resnet50, se_resnet152 and densenet201
      backbones) run on every validation sample.
    - Their L2-normalized outputs are averaged, and the top-1 class is
      taken.
    - A sample is kept only when its top-1 averaged score exceeds 0.85.

    Accuracy and the confusion matrices are computed over the kept samples
    only.

    Relies on module-level names defined elsewhere in the file: ``device``,
    ``data``, ``MM_BDXJTU2019``, ``MultiModalNet``, ``MultiModalNet1``,
    ``MultiModalNet2``, ``accuracy_score``, ``confusion_matrix``, ``np``,
    ``plt`` and ``plot_confusion_matrix``.
    """
    # Keep a sample only if its ensemble top-1 score exceeds this value.
    # NOTE(review): the score averages L2-normalized outputs rather than
    # softmax probabilities, so it is not a calibrated confidence.
    score_threshold = 0.85

    # Dataset. Metrics do not depend on order, so iterate deterministically.
    Dataset_val = MM_BDXJTU2019(root='/home/dell/Desktop/2019BaiduXJTU/data',
                                mode='val')
    Dataloader_val = data.DataLoader(Dataset_val,
                                     batch_size=1,
                                     num_workers=2,
                                     shuffle=False,
                                     pin_memory=True)

    class_names = [
        '001', '002', '003', '004', '005', '006', '007', '008', '009'
    ]
    # Fixed label set so the confusion matrix stays square and matches
    # class_names even when a class has no confident samples. Assumes
    # annotations are class indices 0..8, as the comparison against the
    # argmax index implies.
    label_ids = list(range(len(class_names)))

    def _load(net, path):
        """Load *path* into *net*, move it to ``device``, set eval mode."""
        # map_location lets GPU-saved checkpoints load on any `device`.
        net.load_state_dict(torch.load(path, map_location=device))
        net.to(device)
        net.eval()
        return net

    net1 = _load(MultiModalNet1('se_resnet50', 'DPN26', 0.5),
                 '/home/dell/Desktop/2019BaiduXJTU/weights/se_resnet50_se_resnext50_32x4d_resample_pretrained_80w_1/BDXJTU2019_SGD_9.pth')
    net2 = _load(MultiModalNet('se_resnet152', 'DPN26', 0.5),
                 '/home/dell/Desktop/2019BaiduXJTU/weights/se_resnet152_se_resnext50_32x4d_resample_pretrained_80w_1/BDXJTU2019_SGD_4.pth')
    net3 = _load(MultiModalNet2('densenet201', 'DPN26', 0.5),
                 '/home/dell/Desktop/2019BaiduXJTU/weights/densenet201_se_resnext50_32x4d_resample_pretrained_80w_1/BDXJTU2019_SGD_3.pth')

    results = []
    results_anno = []

    def _print_accuracy():
        """Print accuracy over the kept samples; skip if none are kept yet."""
        if results:
            print('Accuracy of Orignal Input: %0.6f' %
                  (accuracy_score(results, results_anno, normalize=True)))
        else:
            print('Accuracy of Orignal Input: n/a (no sample kept)')

    normalize = torch.nn.functional.normalize

    # Pure inference: skip building autograd graphs for three networks.
    with torch.no_grad():
        for i, (Input_img, Input_vis, Anno) in enumerate(Dataloader_val):
            Input_img = Input_img.to(device)
            Input_vis = Input_vis.to(device)

            ConfTensor1 = net1(Input_img, Input_vis)
            ConfTensor2 = net2(Input_img, Input_vis)
            ConfTensor3 = net3(Input_img, Input_vis)

            ConfTensor = (normalize(ConfTensor1) +
                          normalize(ConfTensor2) +
                          normalize(ConfTensor3)) / 3

            score, pred = ConfTensor.topk(1, 1, True, False)
            if score.item() > score_threshold:
                results.append(pred.item())
                results_anno.append(Anno)  # append annotation results

            if (i + 1) % 2000 == 0:
                print(i + 1)
                print(len(results))
                _print_accuracy()

    # print accuracy of different input
    _print_accuracy()

    cnf_matrix = confusion_matrix(results_anno, results, labels=label_ids)
    # Trace divided by the FULL dataset size gives the fraction of all
    # validation samples that were both kept and correct. This is not the
    # accuracy over the kept samples printed above.
    cnf_tr = np.trace(cnf_matrix).astype('float')
    print(cnf_tr / len(Dataset_val))

    plt.figure()
    plot_confusion_matrix(cnf_matrix,
                          classes=class_names,
                          title='Confusion matrix, without normalization')
    plt.figure()
    plot_confusion_matrix(cnf_matrix,
                          classes=class_names,
                          normalize=True,
                          title='Normalized confusion matrix')
    plt.show()