コード例 #1
0
def main():
    """Fine-tune the selected network from a saved checkpoint using Adam.

    Builds the network named by ``args.basenet``, warm-starts it from
    ``model3/BDXJTU2019_SGD_21_6000.pth`` and calls ``train`` once for every
    epoch in ``range(args.start_epoch, args.epochs)``. Validation and
    checkpoint saving are currently disabled in this variant.

    Raises:
        ValueError: if ``args.basenet`` is not one of the supported names.
    """
    # Make CUDA float tensors the default and run everything on GPU 0.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.cuda.set_device(0)

    # Create the model. An unknown name fails here with a clear message
    # instead of relying on an `assert` (stripped under `python -O`).
    if args.basenet == 'MultiModal':
        model = MultiModalNet('se_resnet50', 'DPN26', 0.5)
    elif args.basenet == 'oct_resnet101':
        model = oct_resnet101()
    else:
        raise ValueError('Unsupported basenet: {!r}'.format(args.basenet))
    model = model.cuda()
    cudnn.benchmark = True

    # Dataset
    Aug = Augmentation()
    Dataset_train = MM_BDXJTU2019(root=args.dataset_root,
                                  mode='MM_cleaned_train',
                                  transform=Aug)
    Dataloader_train = data.DataLoader(Dataset_train, args.batch_size,
                                       num_workers=args.num_workers,
                                       shuffle=True, pin_memory=True)

    # The validation loader is built but unused while validation is disabled.
    Dataset_val = MM_BDXJTU2019(root=args.dataset_root, mode='val')
    Dataloader_val = data.DataLoader(Dataset_val, batch_size=8,
                                     num_workers=args.num_workers,
                                     shuffle=True, pin_memory=True)

    # NOTE(review): `weights` is not defined inside this function; it has to
    # be provided at module level, otherwise this line raises NameError.
    criterion = nn.CrossEntropyLoss(weight=weights).cuda()

    print("info:",args.batch_size)
    # NOTE(review): the log line mentions "adam" while the checkpoint file
    # name says SGD -- confirm which checkpoint is intended.
    print("load pretrained model from model3 _____21_6000___adam")
    state_dict1 = torch.load('model3/BDXJTU2019_SGD_21_6000.pth')
    # strict=False: keys missing from / unexpected in the checkpoint are
    # ignored rather than raising.
    model.load_state_dict(state_dict1, strict=False)

    # Only parameters that still require gradients are optimised.
    Optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=0.001)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(Optimizer, epoch)

        # train for one epoch
        train(Dataloader_train, model, criterion, Optimizer, epoch)

        # Validation (`validate(Dataloader_val, model, criterion)`) and
        # checkpoint saving are intentionally disabled in this variant.
コード例 #2
0
def main():
    """Train the selected network with SGD and save a checkpoint every epoch.

    Builds the network named by ``args.basenet``, wraps it in
    ``nn.DataParallel`` when more than one GPU is visible, and calls ``train``
    once for every epoch in ``range(args.start_epoch, args.epochs)``. After
    each epoch the model's ``state_dict`` is written below ``weights/``.

    Raises:
        ValueError: if ``args.basenet`` is not one of the supported names.
    """
    import os  # local import: only needed to create the checkpoint directory

    # Create the model. Without the final `else`, an unknown name would leave
    # `model` unbound and fail later with an unrelated UnboundLocalError.
    if args.basenet == 'se_resnet152':
        model = MultiModalNet('se_resnet152', 'DPN26', 0.5)
    elif args.basenet == 'se_resnext50_32x4d':
        model = MultiModalNet1('se_resnext50_32x4d', 'DPN26', 0.5)
    elif args.basenet == 'se_resnet50':
        model = MultiModalNet1('se_resnet50', 'DPN26', 0.5)
    elif args.basenet == 'densenet201':
        model = MultiModalNet2('densenet201', 'DPN26', 0.5)
    elif args.basenet == 'oct_resnet101':
        model = oct_resnet101()
    else:
        raise ValueError('Unsupported basenet: {!r}'.format(args.basenet))

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # DataParallel splits each batch along dim 0 across the GPUs,
        # e.g. [30, ...] -> [10, ...], [10, ...], [10, ...] on 3 GPUs.
        model = nn.DataParallel(model)
    # NOTE(review): `device` is not defined in this function; it is expected
    # to exist at module level.
    model.to(device)

    # Dataset
    Aug = Augmentation()
    Dataset_train = MM_BDXJTU2019(root='/home/dell/Desktop/2019BaiduXJTU/data',
                                  mode='MM_1_train',
                                  transform=Aug)
    Dataloader_train = data.DataLoader(Dataset_train,
                                       128,
                                       num_workers=4,
                                       shuffle=True,
                                       pin_memory=True)

    # The validation loader is built but unused while validation is disabled.
    Dataset_val = MM_BDXJTU2019(root='/home/dell/Desktop/2019BaiduXJTU/data',
                                mode='val')
    Dataloader_val = data.DataLoader(Dataset_val,
                                     batch_size=32,
                                     num_workers=4,
                                     shuffle=True,
                                     pin_memory=True)

    criterion = nn.CrossEntropyLoss().to(device)
    Optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # Make sure the checkpoint directory exists before training starts, so a
    # finished epoch is never lost to a failing torch.save.
    save_dir = ('weights/' + args.basenet +
                '_se_resnext50_32x4d_resample_pretrained_80w_1/')
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(Optimizer, epoch)

        # train for one epoch
        train(Dataloader_train, model, criterion, Optimizer, epoch)

        # Validation (`validate(Dataloader_val, model, criterion)`) and
        # best-checkpoint tracking are currently disabled; every epoch is
        # saved instead.

        # `.module` only exists on the DataParallel wrapper; on a single GPU
        # the bare model is saved. Either way the checkpoint keys carry no
        # "module." prefix.
        net = model.module if isinstance(model, nn.DataParallel) else model
        torch.save(net.state_dict(),
                   save_dir + 'BDXJTU2019_SGD_' + repr(epoch) + '.pth')
コード例 #3
0
def main():
    """Train the selected network with SGD on GPU 0, saving every epoch.

    Builds the network named by ``args.basenet`` and calls ``train`` once for
    every epoch in ``range(args.start_epoch, args.epochs)``. After each epoch
    the model's ``state_dict`` is written below ``weights/``.

    Raises:
        ValueError: if ``args.basenet`` is not one of the supported names.
    """
    import os  # local import: only needed to create the checkpoint directory

    # Make CUDA float tensors the default and run everything on GPU 0.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.cuda.set_device(0)

    # Create the model. Without the final `else`, an unknown name would leave
    # `model` unbound and fail later with an unrelated UnboundLocalError.
    if args.basenet == 'MultiModal':
        model = MultiModalNet('se_resnext50_32x4d', 'DPN26', 0.5)
    elif args.basenet == 'oct_resnet101':
        model = oct_resnet101()
    else:
        raise ValueError('Unsupported basenet: {!r}'.format(args.basenet))

    model = model.cuda()
    cudnn.benchmark = True

    # Dataset
    Aug = Augmentation()
    Dataset_train = MM_BDXJTU2019(root=args.dataset_root,
                                  mode='MM_1_train',
                                  transform=Aug)
    Dataloader_train = data.DataLoader(Dataset_train,
                                       args.batch_size,
                                       num_workers=args.num_workers,
                                       shuffle=True,
                                       pin_memory=True)

    # The validation loader is built but unused while validation is disabled.
    Dataset_val = BDXJTU2019(root=args.dataset_root, mode='val')
    Dataloader_val = data.DataLoader(Dataset_val,
                                     batch_size=8,
                                     num_workers=args.num_workers,
                                     shuffle=True,
                                     pin_memory=True)

    # NOTE(review): `weights` is not defined inside this function; it has to
    # be provided at module level, otherwise this line raises NameError.
    criterion = nn.CrossEntropyLoss(weight=weights).cuda()

    # Only parameters that still require gradients are optimised.
    Optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # Make sure the checkpoint directory exists before training starts, so a
    # finished epoch is never lost to a failing torch.save.
    save_dir = ('weights/' + args.basenet +
                '_se_resnext50_32x4d_resample_pretrained_80w_1/')
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(Optimizer, epoch)

        # train for one epoch
        train(Dataloader_train, model, criterion, Optimizer, epoch)

        # Validation (`validate(Dataloader_val, model, criterion)`) and
        # best-checkpoint tracking are currently disabled; every epoch is
        # saved instead.
        torch.save(model.state_dict(),
                   save_dir + 'BDXJTU2019_SGD_' + repr(epoch) + '.pth')
コード例 #4
0
def GeResult():
    """Evaluate a saved oct_resnet101 checkpoint on the validation split.

    Every sample from ``BDXJTU2019_TTA`` yields three input tensors
    (``Input_O``, ``Input_H``, ``Input_V`` -- presumably the original image
    and two augmented views; confirm against the dataset class) plus its
    annotation. The network output for each view is normalised with
    ``torch.nn.functional.normalize`` and the three are summed; the top-1
    class of the sum is the prediction.

    Prints the running accuracy every 1000 samples, the final accuracy, and
    the confusion-matrix trace divided by the dataset size, then plots the
    raw and the normalised confusion matrices.
    """
    # Dataset
    Dataset_val = BDXJTU2019_TTA(root='data', mode='val')
    Dataloader_val = data.DataLoader(Dataset_val,
                                     batch_size=1,
                                     num_workers=4,
                                     shuffle=True,
                                     pin_memory=True)

    # Make CUDA float tensors the default and run everything on GPU 0.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.cuda.set_device(0)

    class_names = [
        '001', '002', '003', '004', '005', '006', '007', '008', '009'
    ]

    # construct network
    net = oct_resnet101()
    net.load_state_dict(
        torch.load(
            '/home/zxw/2019BaiduXJTU/weights/oct_resnet101_nor/BDXJTU2019_SGD_20.pth'
        ))

    # Switch the network to evaluation mode.
    net.eval()

    results = []       # predicted class index per sample
    results_anno = []  # annotated class index per sample

    # Inference only: disable autograd so no graph is built or kept alive
    # for the three forward passes per sample.
    with torch.no_grad():
        for i, (Input_O, Input_H, Input_V, Anno) in enumerate(Dataloader_val):

            # Call the module itself rather than `.forward` so that any
            # registered hooks run.
            ConfTensor_O = net(Input_O.cuda())
            ConfTensor_H = net(Input_H.cuda())
            ConfTensor_V = net(Input_V.cuda())

            # Fuse the three views by summing their normalised outputs.
            ConfTensor = (torch.nn.functional.normalize(ConfTensor_O) +
                          torch.nn.functional.normalize(ConfTensor_H) +
                          torch.nn.functional.normalize(ConfTensor_V))
            _, pred = ConfTensor.topk(1, 1, True, False)

            # batch_size is 1, so both tensors hold exactly one element;
            # store plain Python numbers instead of tensors.
            results.append(pred.item())
            results_anno.append(Anno.item())

            if i % 1000 == 0:
                print(i)
                print('Accuracy of Orignal Input: %0.6f' %
                      (accuracy_score(results, results_anno, normalize=True)))

    # print accuracy of different input
    print('Accuracy of Orignal Input: %0.6f' %
          (accuracy_score(results, results_anno, normalize=True)))

    cnf_matrix = confusion_matrix(results_anno, results)
    # The trace counts the correctly classified samples.
    cnf_tr = np.trace(cnf_matrix)
    cnf_tr = cnf_tr.astype('float')
    print(cnf_tr / len(Dataset_val))
    plt.figure()
    plot_confusion_matrix(cnf_matrix,
                          classes=class_names,
                          title='Confusion matrix, without normalization')
    plt.figure()
    plot_confusion_matrix(cnf_matrix,
                          classes=class_names,
                          normalize=True,
                          title='Normalized confusion matrix')
    plt.show()